# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interactions: Self-consistent D4 Dispersion
===========================================
Self-consistent D4 dispersion correction.
"""
from __future__ import annotations
from typing import Any
import tad_dftd4 as d4
import tad_dftd4.defaults as d4_defaults
import torch
from tad_mctc.data import PAULING
from tad_mctc.exceptions import DeviceError
from tad_mctc.math import einsum
from tad_mctc.ncoord import coordination_number, erf_count
from tad_mctc.typing import (
DD,
CountingFunction,
Tensor,
TensorLike,
get_default_dtype,
override,
)
from dxtb import IndexHelper
from dxtb._src.param import Param, ParamModule
from dxtb._src.typing import Slicers
from ..base import Interaction, InteractionCache
LABEL_DISPERSIOND4SC = "DispersionD4SC"
"""Label for the :class:`.DispersionD4SC` interaction, coinciding with the class name."""

# Public API of this module.
__all__ = [
    "DispersionD4SC",
    "LABEL_DISPERSIOND4SC",
    "new_d4sc",
]


class DispersionD4SCCache(InteractionCache, TensorLike):
    """
    Restart data for the :class:`.DispersionD4SC` interaction.

    Holds the coordination numbers and the dispersion matrix computed once
    per geometry, plus a reference to the D4 model that is used to compute
    the charge-dependent reference weights.

    Note
    ----
    The dispersion parameters (a1, a2, ...) are given in the dispersion
    class constructor.
    """

    __store: Store | None
    """Storage for cache (required for culling)."""

    cn: Tensor
    """Coordination number of every atom."""

    dispmat: Tensor
    """
    Dispersion matrix. This quantity is almost equal to the dispersion energy,
    except for multiplication with C6 and C8.
    """

    model: d4.model.D4Model
    """
    Model for the D4 dispersion correction.
    Same object as in the `.DispersionD4SC` class.
    """

    __slots__ = ["__store", "cn", "dispmat", "model"]

    def __init__(
        self,
        cn: Tensor,
        dispmat: Tensor,
        model: d4.model.D4Model,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        cn : Tensor
            Coordination number of every atom.
        dispmat : Tensor
            Dispersion matrix (see class attribute ``dispmat``).
        model : d4.model.D4Model
            D4 model; stored by reference, not copied.
        device : torch.device | None, optional
            Device of the cache. Defaults to the device of ``cn``.
        dtype : torch.dtype | None, optional
            Floating point dtype of the cache. Defaults to the dtype of ``cn``.
        """
        # Only fall back to the tensor's device/dtype when nothing was given
        # explicitly, so the cache reports where its tensors actually live.
        super().__init__(
            device=cn.device if device is None else device,
            dtype=cn.dtype if dtype is None else dtype,
        )
        self.cn = cn
        self.dispmat = dispmat

        # Model is the same object as in the DispersionD4SC class
        self.model = model

        self.__store = None

    class Store:
        """
        Storage container for cache containing ``__slots__`` before culling.
        """

        cn: Tensor
        """Coordination number of every atom."""

        dispmat: Tensor
        """
        Dispersion matrix. This quantity is almost equal to the dispersion
        energy, except for multiplication with C6 and C8.
        """

        def __init__(
            self, cn: Tensor, dispmat: Tensor, model: d4.model.D4Model
        ) -> None:
            self.cn = cn
            self.dispmat = dispmat

            # Only store the numbers tensor (not the shared model object).
            # `cull` rebinds `model.numbers` to a new sliced tensor, so this
            # reference keeps the unculled numbers alive for `restore`.
            self.numbers = model.numbers

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        """
        Remove converged systems from the cache.

        On the first call, the unculled data is saved so that `restore`
        can bring it back.

        Parameters
        ----------
        conv : Tensor
            Boolean mask of converged systems (``True`` entries are dropped).
        slicers : Slicers
            Slicers for the remaining (non-converged) systems; the ``"atom"``
            entry is used here.
        """
        if self.__store is None:
            self.__store = self.Store(self.cn, self.dispmat, self.model)

        keep = ~conv
        slicer = slicers["atom"]
        self.cn = self.cn[tuple([keep, *slicer])]
        self.dispmat = self.dispmat[tuple([keep, *slicer, *slicer])]

        # The model is shared with the interaction, so this also changes the
        # numbers seen by `DispersionD4SC.model`.
        self.model.numbers = self.model.numbers[tuple([keep, *slicer])]

    def restore(self) -> None:
        """
        Restore the data saved by the first call to `cull`.

        Raises
        ------
        RuntimeError
            If `cull` was never called, i.e. nothing is stored.
        """
        if self.__store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")

        self.cn = self.__store.cn
        self.dispmat = self.__store.dispmat
        self.model.numbers = self.__store.numbers


class DispersionD4SC(Interaction):
    """
    Self-consistent D4 dispersion correction (:class:`.DispersionD4SC`).

    :meth:`get_cache` computes the geometry-dependent quantities once: the
    coordination numbers, and the dispersion matrix multiplied with the
    reference C6 coefficients. The energy and potential methods then only
    re-evaluate the charge-dependent reference weights. They use the cached
    coordination numbers and the atomic charges passed in.
    """

    param: d4.Param
    """Dispersion parameters."""

    model: d4.model.D4Model
    """Model for the D4 dispersion correction."""

    rcov: Tensor
    """Covalent radii of all atoms."""

    r4r2: Tensor
    """R4/R2 ratio of all atoms."""

    cutoff: d4.cutoff.Cutoff
    """Real-space cutoff for the D4 dispersion correction."""

    counting_function: CountingFunction
    """
    Counting function for the coordination number.

    :default: :func:`tad_mctc.ncoord.erf_count`
    """

    damping_function: d4.damping.Damping
    """
    Damping function for the dispersion correction.

    :default: :func:`d4.damping.RationalDamping`
    """

    __slots__ = [
        "param",
        "model",
        "rcov",
        "r4r2",
        "cutoff",
        "counting_function",
        "damping_function",
    ]

    def __init__(
        self,
        param: d4.Param,
        model: d4.model.D4Model,
        rcov: Tensor,
        r4r2: Tensor,
        cutoff: d4.cutoff.Cutoff,
        counting_function: CountingFunction = erf_count,
        damping_function: d4.damping.Damping = d4.damping.RationalDamping(),
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        param : d4.Param
            Dispersion parameters.
        model : d4.model.D4Model
            D4 model. It is shared with (not copied into) every cache created
            by :meth:`get_cache`.
        rcov : Tensor
            Covalent radii of all atoms.
        r4r2 : Tensor
            R4/R2 ratio of all atoms.
        cutoff : d4.cutoff.Cutoff
            Real-space cutoffs.
        counting_function : CountingFunction, optional
            Counting function for the coordination number.
        damping_function : d4.damping.Damping, optional
            Damping function. The default instance is created once, when the
            class is defined. All instances that do not pass their own
            damping function share it.
        device : torch.device | None, optional
            Device of the interaction.
        dtype : torch.dtype | None, optional
            Floating point dtype of the interaction.
        """
        super().__init__(device, dtype)

        self.param = param
        self.model = model
        self.rcov = rcov
        self.r4r2 = r4r2
        self.cutoff = cutoff
        self.counting_function = counting_function
        self.damping_function = damping_function

    # pylint: disable=unused-argument
    @override
    def get_cache(
        self,
        *,
        numbers: Tensor | None = None,
        positions: Tensor | None = None,
        ihelp: IndexHelper | None = None,
        **_,
    ) -> DispersionD4SCCache:
        """
        Create restart data for individual interactions.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Positions of all atoms in the system.
        ihelp : IndexHelper
            Index mapping for the basis set. Unused here; it is accepted so
            that the signature matches the other interactions (also see
            :meth:`dxtb.components.Interaction.get_cache`).

        Returns
        -------
        DispersionD4SCCache
            Restart data for the interaction.

        Raises
        ------
        ValueError
            If ``numbers`` or ``positions`` is not given.
        TypeError
            If the stored cache is up to date but not a
            :class:`DispersionD4SCCache` (only possible after manual
            manipulation of the cache).
        """
        if numbers is None:
            raise ValueError(
                "Atomic numbers are required for DispersionD4SC cache."
            )
        if positions is None:
            raise ValueError(
                "Positions are required for DispersionD4SC cache."
            )

        cachevars = (numbers.detach().clone(), positions.detach().clone())

        if self.cache_is_latest(cachevars) is True:
            if not isinstance(self.cache, DispersionD4SCCache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}"
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        # if the cache is built, store the cachevar for validation
        self._cachevars = cachevars

        # Pair weights of the coordination number, built from the absolute
        # Pauling electronegativity differences and the D4 constants K4-K6.
        en = PAULING(**self.dd)[numbers]
        endiff = torch.abs(en.unsqueeze(-2) - en.unsqueeze(-1))
        weight = d4_defaults.D4_K4 * torch.exp(
            -((endiff + d4_defaults.D4_K5) ** 2.0) / d4_defaults.D4_K6
        )

        cn = coordination_number(
            numbers,
            positions,
            counting_function=self.counting_function,
            rcov=self.rcov,
            cutoff=self.cutoff.cn,
            pair_weight=weight,
        )

        # tblite: disp/d4.f90::get_dispersion_matrix
        # Instead of multiplying with C6 (from `get_atomic_c6`), we multiply
        # with the reference C6 coefficients that have not been multiplied with
        # the Gaussian weights yet. Correspondingly, we have to set the C6
        # argument of `dispersion2` to 1.
        # NOTE(review): `self.damping_function` and `self.cutoff.disp2` are
        # not passed to `dispersion2` -- confirm this is intended.
        edisp = d4.dispersion.dispersion2(
            numbers,
            positions,
            self.param,
            torch.ones((*numbers.shape, numbers.shape[-1]), **self.dd),
            self.r4r2,
            as_matrix=True,
        )

        # Append two reference axes (a: references of atom i, b: references
        # of atom j). The energy/potential methods contract them with the
        # reference weights.
        dispmat = edisp.unsqueeze(-1).unsqueeze(-1) * self.model.rc6

        self.cache = DispersionD4SCCache(cn, dispmat, self.model)
        return self.cache

    @override
    def get_monopole_atom_energy(
        self, cache: InteractionCache, qat: Tensor, **_: Any
    ) -> Tensor:
        """
        Calculate the D4 dispersion correction energy.

        Parameters
        ----------
        cache : DispersionD4SCCache
            Restart data for the interaction.
        qat : Tensor
            Atomic charges of all atoms.

        Returns
        -------
        Tensor
            Atomwise D4 dispersion correction energies.

        Raises
        ------
        TypeError
            If ``cache`` is not a :class:`DispersionD4SCCache`.
        """
        if not isinstance(cache, DispersionD4SCCache):
            raise TypeError(
                f"Cache in {self.label} is not of type 'DispersionD4SCCache'."
            )

        # `numbers` in model are updated in cache (for culling)
        weights = self.model.weight_references(cache.cn, qat)

        # E_j = 1/2 * sum_i sum_ab dispmat_ijab * w_ia * w_jb
        return 0.5 * einsum(
            "...ijab,...ia,...jb->...j",
            *(cache.dispmat, weights, weights),
            optimize=[(0, 1), (0, 1)],
        )

    @override
    def get_monopole_atom_potential(
        self, cache: InteractionCache, qat: Tensor, *_: Any, **__: Any
    ) -> Tensor:
        """
        Calculate the D4 dispersion correction potential.

        Parameters
        ----------
        cache : DispersionD4SCCache
            Restart data for the interaction.
        qat : Tensor
            Atomic charges of all atoms.

        Returns
        -------
        Tensor
            Atomwise dispersion correction potential.

        Raises
        ------
        TypeError
            If ``cache`` is not a :class:`DispersionD4SCCache`.
        """
        if not isinstance(cache, DispersionD4SCCache):
            raise TypeError(
                f"Cache in {self.label} is not of type 'DispersionD4SCCache'."
            )

        # `dgwdq` is presumably the charge derivative of the reference
        # weights (requested via `with_dgwdq=True`).
        weights, dgwdq = self.model.weight_references(
            cache.cn, qat, with_dgwdq=True
        )

        # V_i = sum_j sum_ab dispmat_ijab * w_jb * dgwdq_ia
        return einsum(
            "...ijab,...jb,...ia->...i", cache.dispmat, weights, dgwdq
        )


def new_d4sc(
    numbers: Tensor,
    par: Param | ParamModule,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> DispersionD4SC | None:
    """
    Create new instance of :class:`.DispersionD4SC`.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
    par : Param | ParamModule
        Representation of an extended tight-binding model. A plain
        :class:`Param` is wrapped into a :class:`ParamModule`.
    device : torch.device | None, optional
        Device of the created tensors. If given, it must match the device of
        ``numbers``. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Floating point dtype of the created tensors. Defaults to ``None``,
        which selects the default dtype.

    Returns
    -------
    DispersionD4SC | None
        Instance of the :class:`.DispersionD4SC` class or ``None`` if
        no :class:`.DispersionD4SC` is used.

    Raises
    ------
    DeviceError
        If ``device`` is given and differs from the device of ``numbers``.
    """
    dd: DD = {
        "device": device,
        "dtype": dtype if dtype is not None else get_default_dtype(),
    }

    # compatibility with previous version based on `Param`
    if not isinstance(par, ParamModule):
        par = ParamModule(par, **dd)

    # No self-consistent D4 without a D4 dispersion section, or if its `sc`
    # entry is flagged as false.
    if "dispersion" not in par or par.is_none("dispersion"):
        return None
    if par.is_none("dispersion.d4"):
        return None
    if par.is_false("dispersion.d4.sc"):
        return None

    if device is not None and device != numbers.device:
        raise DeviceError(
            f"Passed device ({device}) and device of `numbers` tensor "
            f"({numbers.device}) do not match."
        )

    param = d4.Param(
        a1=par.get("dispersion.d4.a1"),
        a2=par.get("dispersion.d4.a2"),
        s6=par.get("dispersion.d4.s6"),
        s8=par.get("dispersion.d4.s8"),
        s9=par.get("dispersion.d4.s9"),
        s10=par.get("dispersion.d4.s10"),
    )

    rcov = d4.data.COV_D3(**dd)[numbers]
    r4r2 = d4.data.R4R2(**dd)[numbers]
    model = d4.model.D4Model(numbers, ref_charges="gfn2", **dd)

    # Fixed real-space cutoffs for the two- and three-body terms.
    cutoff = d4.cutoff.Cutoff(disp2=50.0, disp3=25.0, **dd)

    return DispersionD4SC(
        param, model=model, rcov=rcov, r4r2=r4r2, cutoff=cutoff, **dd
    )