# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dispersion: D3
==============
The DFT-D3 dispersion model.
"""
from __future__ import annotations
import tad_dftd3 as d3
import torch
from tad_mctc.data import radii
from tad_mctc.ncoord import coordination_number, exp_count
from dxtb import IndexHelper
from dxtb._src.typing import Any, CountingFunction, Tensor, override
from ..base import ClassicalCache, ComponentCache
from .base import Dispersion
__all__ = ["DispersionD3", "DispersionD3Cache"]
class DispersionD3Cache(ClassicalCache):
    """
    Cache for the D3 dispersion settings.

    Stores the reference data, the atomic radii/expectation-value tensors, the
    coordination-number cutoff and the counting, weighting and damping
    functions used when evaluating the D3 dispersion energy.

    Note
    ----
    The damping parameters (a1, a2, ...) are not stored in this cache; they
    are given in the constructor of the dispersion component.
    """

    __slots__ = [
        "ref",
        "rcov",
        "rvdw",
        "r4r2",
        "cutoff",
        "counting_function",
        "weighting_function",
        "damping_function",
    ]

    def __init__(
        self,
        ref: d3.reference.Reference,
        rcov: Tensor,
        rvdw: Tensor,
        r4r2: Tensor,
        cutoff: Tensor,
        counting_function: CountingFunction,
        weighting_function: d3.typing.WeightingFunction,
        damping_function: d3.typing.DampingFunction,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        ref : d3.reference.Reference
            D3 reference data.
        rcov : Tensor
            Covalent radii used for the coordination number.
        rvdw : Tensor
            Pairwise van-der-Waals radii.
        r4r2 : Tensor
            R4/R2 expectation values.
        cutoff : Tensor
            Real-space cutoff for the coordination number.
        counting_function : CountingFunction
            Counting function for the coordination number.
        weighting_function : d3.typing.WeightingFunction
            Function weighting the reference systems.
        damping_function : d3.typing.DampingFunction
            Damping function for the dispersion energy.
        device : torch.device | None, optional
            Device of the cache. Falls back to the device of ``rcov``.
        dtype : torch.dtype | None, optional
            Dtype of the cache. Falls back to the dtype of ``rcov``.
        """
        # Use ``rcov`` only as a fallback when no explicit device/dtype is
        # requested; an explicitly given value must take precedence.
        super().__init__(
            device=rcov.device if device is None else device,
            dtype=rcov.dtype if dtype is None else dtype,
        )
        self.ref = ref
        self.rcov = rcov
        self.rvdw = rvdw
        self.r4r2 = r4r2
        self.cutoff = cutoff
        self.counting_function = counting_function
        self.weighting_function = weighting_function
        self.damping_function = damping_function
class DispersionD3(Dispersion):
    """
    Representation of the DFT-D3(BJ) dispersion correction
    (:class:`.DispersionD3`).
    """

    @override
    def get_cache(
        self, numbers: Tensor, ihelp: IndexHelper | None = None, **kwargs: Any
    ) -> DispersionD3Cache:
        """
        Obtain cache for storage of settings.

        Settings can be passed as ``kwargs``. The available optional parameters
        are the same as in :func:`tad_dftd3.dftd3`, i.e., "ref", "rcov",
        "rvdw", and "r4r2". Additionally, "cutoff", "counting_function",
        "weighting_function", and "damping_function" are accepted.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        ihelp : IndexHelper | None, optional
            Not used by the D3 dispersion.
        **kwargs : Any
            Optional overrides of the settings listed above. Tensor-valued
            settings are moved to the component's device and dtype.

        Returns
        -------
        DispersionD3Cache
            Cache for the D3 dispersion.

        Raises
        ------
        TypeError
            If an up-to-date cache exists but is not a
            :class:`DispersionD3Cache`.
        """
        cachevars = (numbers.detach().clone(),)

        # Reuse the stored cache if it was built for the same atomic numbers.
        if self.cache_is_latest(cachevars) is True:
            if not isinstance(self.cache, DispersionD3Cache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}"
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        self._cachevars = cachevars

        def setting(key: str, make_default: Any) -> Any:
            # Only build the default when the caller did not supply a value;
            # defaults such as the D3 reference or the pairwise vdW radii
            # table may be costly to construct.
            return kwargs.pop(key) if key in kwargs else make_default()

        ref = setting("ref", d3.reference.Reference).to(**self.dd)
        rcov = setting(
            "rcov", lambda: radii.COV_D3(**self.dd)[numbers]
        ).to(**self.dd)
        rvdw = setting(
            "rvdw",
            lambda: radii.VDW_PAIRWISE(**self.dd)[
                numbers.unsqueeze(-1), numbers.unsqueeze(-2)
            ],
        ).to(**self.dd)
        r4r2 = setting(
            "r4r2", lambda: d3.data.R4R2(**self.dd)[numbers]
        ).to(**self.dd)
        cutoff = setting(
            "cutoff",
            lambda: torch.tensor(d3.defaults.D3_CN_CUTOFF, **self.dd),
        ).to(**self.dd)

        cf = kwargs.pop("counting_function", exp_count)
        wf = kwargs.pop("weighting_function", d3.model.gaussian_weight)
        df = kwargs.pop("damping_function", d3.damping.rational_damping)

        # Pass device/dtype explicitly; ``rcov`` was already moved to the
        # component's device and dtype above.
        self.cache = DispersionD3Cache(
            ref,
            rcov,
            rvdw,
            r4r2,
            cutoff,
            cf,
            wf,
            df,
            device=rcov.device,
            dtype=rcov.dtype,
        )
        return self.cache

    @override
    def get_energy(
        self, positions: Tensor, cache: ComponentCache, **kwargs: Any
    ) -> Tensor:
        """
        Get D3 dispersion energy.

        Parameters
        ----------
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        cache : ComponentCache
            Dispersion cache containing settings.
        **kwargs : Any
            Optional ``chunk_size``, forwarded to
            :func:`tad_dftd3.model.atomic_c6`.

        Returns
        -------
        Tensor
            Atom-resolved D3 dispersion energy.

        Raises
        ------
        TypeError
            If ``cache`` is not a :class:`DispersionD3Cache`.
        """
        if not isinstance(cache, DispersionD3Cache):
            raise TypeError(
                f"Cache in {self.label} is not of type 'DispersionD3Cache'."
            )

        # Coordination numbers from the cached radii, cutoff and counting
        # function.
        cn = coordination_number(
            self.numbers,
            positions,
            counting_function=cache.counting_function,
            rcov=cache.rcov,
            cutoff=cache.cutoff,
        )

        # CN-dependent weights of the reference systems.
        weights = d3.model.weight_references(
            self.numbers, cn, cache.ref, cache.weighting_function
        )

        chunk_size = kwargs.pop("chunk_size", None)
        c6 = d3.model.atomic_c6(
            self.numbers, weights, cache.ref, chunk_size=chunk_size
        )

        return d3.disp.dispersion(
            self.numbers,
            positions,
            self.param,
            c6,
            cache.rvdw,
            cache.r4r2,
            cache.damping_function,
        )