# Source code for dxtb._src.components.interactions.dispersion.d4sc

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Interactions: Self-consistent D4 Dispersion
===========================================

Self-consistent D4 dispersion correction.
"""

from __future__ import annotations

from typing import Any

import tad_dftd4 as d4
import tad_dftd4.defaults as d4_defaults
import torch
from tad_mctc.data import PAULING
from tad_mctc.exceptions import DeviceError
from tad_mctc.math import einsum
from tad_mctc.ncoord import coordination_number, erf_count
from tad_mctc.typing import (
    DD,
    CountingFunction,
    Tensor,
    TensorLike,
    get_default_dtype,
    override,
)

from dxtb import IndexHelper
from dxtb._src.param import Param, ParamModule
from dxtb._src.typing import Slicers

from ..base import Interaction, InteractionCache

__all__ = ["DispersionD4SC", "LABEL_DISPERSIOND4SC", "new_d4sc"]


#: Label for the :class:`.DispersionD4SC` interaction, coinciding with the
#: class name.
LABEL_DISPERSIOND4SC = "DispersionD4SC"


class DispersionD4SCCache(InteractionCache, TensorLike):
    """
    Restart data for the :class:`.DispersionD4SC` interaction.

    Note
    ----
    The dispersion parameters (a1, a2, ...) are given in the dispersion
    class constructor.
    """

    __store: Store | None
    """Storage for cache (required for culling)."""

    cn: Tensor
    """Coordination number of every atom."""

    dispmat: Tensor
    """
    Dispersion matrix. This quantity is almost equal to the dispersion energy,
    except for multiplication with C6 and C8.
    """

    model: d4.model.D4Model
    """
    Model for the D4 dispersion correction.
    Same object as in the `.DispersionD4SC` class.
    """

    __slots__ = ["__store", "cn", "dispmat", "model"]

    def __init__(
        self,
        cn: Tensor,
        dispmat: Tensor,
        model: d4.model.D4Model,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        cn : Tensor
            Coordination number of every atom.
        dispmat : Tensor
            Dispersion matrix (see class attribute ``dispmat``).
        model : d4.model.D4Model
            D4 model. It is stored by reference (not copied), so culling
            reassigns ``model.numbers`` on the shared object.
        device : torch.device | None, optional
            Device of the cache. Defaults to the device of ``cn``.
        dtype : torch.dtype | None, optional
            Floating point dtype of the cache. Defaults to the dtype of ``cn``.
        """
        # Fall back to the device/dtype of the stored tensors so that the
        # cache reports where its data actually lives.
        super().__init__(
            device=cn.device if device is None else device,
            dtype=cn.dtype if dtype is None else dtype,
        )

        self.cn = cn
        self.dispmat = dispmat

        # Model is the same object as in the DispersionD4SC class
        self.model = model

        self.__store = None

    class Store:
        """
        Storage container for cache containing ``__slots__`` before culling.
        """

        cn: Tensor
        """Coordination number of every atom."""

        dispmat: Tensor
        """
        Dispersion matrix. This quantity is almost equal to the dispersion
        energy, except for multiplication with C6 and C8.
        """

        numbers: Tensor
        """Atomic numbers of the (shared) model before culling."""

        def __init__(
            self, cn: Tensor, dispmat: Tensor, model: d4.model.D4Model
        ) -> None:
            self.cn = cn
            self.dispmat = dispmat

            # Only store the numbers: the model object itself is shared and
            # gets its `numbers` attribute reassigned during culling.
            self.numbers = model.numbers

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        """
        Remove converged systems from the cache.

        If nothing has been stored yet, the current (unculled) data is saved
        first so that :meth:`restore` can undo the culling.

        Parameters
        ----------
        conv : Tensor
            Mask of converged systems. It is used via ``~conv`` as an index,
            so presumably boolean; entries where ``conv`` is set are dropped.
        slicers : Slicers
            Slicers for the non-batch dimensions; only ``"atom"`` is used.
        """
        if self.__store is None:
            self.__store = self.Store(self.cn, self.dispmat, self.model)

        slicer = slicers["atom"]
        self.cn = self.cn[tuple([~conv, *slicer])]
        self.dispmat = self.dispmat[tuple([~conv, *slicer, *slicer])]

        # The model is shared with the interaction, so this also changes the
        # numbers seen by `DispersionD4SC.model`.
        self.model.numbers = self.model.numbers[tuple([~conv, *slicer])]

    def restore(self) -> None:
        """
        Restore the data saved by :meth:`cull`.

        Raises
        ------
        RuntimeError
            If nothing has been stored, i.e., :meth:`cull` was never called.
        """
        if self.__store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")

        self.cn = self.__store.cn
        self.dispmat = self.__store.dispmat
        self.model.numbers = self.__store.numbers


class DispersionD4SC(Interaction):
    """
    Self-consistent D4 dispersion correction (:class:`.DispersionD4SC`).

    The energy depends on the atomic charges through the D4 reference
    weights, so it is evaluated as an atom-resolved (monopole) interaction.
    """

    param: d4.Param
    """Dispersion parameters."""

    model: d4.model.D4Model
    """Model for the D4 dispersion correction."""

    rcov: Tensor
    """Covalent radii of all atoms."""

    r4r2: Tensor
    """R4/R2 ratio of all atoms."""

    cutoff: d4.cutoff.Cutoff
    """Real-space cutoff for the D4 dispersion correction."""

    counting_function: CountingFunction
    """
    Counting function for the coordination number.

    :default: :func:`tad_mctc.ncoord.erf_count`
    """

    damping_function: d4.damping.Damping
    """
    Damping function for the dispersion correction.

    :default: :func:`d4.damping.RationalDamping`
    """

    __slots__ = [
        "param",
        "model",
        "rcov",
        "r4r2",
        "cutoff",
        "counting_function",
        "damping_function",
    ]

    def __init__(
        self,
        param: d4.Param,
        model: d4.model.D4Model,
        rcov: Tensor,
        r4r2: Tensor,
        cutoff: d4.cutoff.Cutoff,
        counting_function: CountingFunction = erf_count,
        damping_function: d4.damping.Damping | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(device, dtype)

        self.param = param
        self.model = model
        self.rcov = rcov
        self.r4r2 = r4r2
        self.cutoff = cutoff
        self.counting_function = counting_function

        # Build the default per instance instead of sharing one object that
        # was created once at import time as a default argument.
        self.damping_function = (
            d4.damping.RationalDamping()
            if damping_function is None
            else damping_function
        )

    # pylint: disable=unused-argument
    @override
    def get_cache(
        self,
        *,
        numbers: Tensor | None = None,
        positions: Tensor | None = None,
        ihelp: IndexHelper | None = None,
        **_,
    ) -> DispersionD4SCCache:
        """
        Create restart data for individual interactions.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Cartesian coordinates of all atoms in the system.
        ihelp : IndexHelper
            Index mapping for the basis set. Not used by this interaction.

        Returns
        -------
        DispersionD4SCCache
            Restart data for the interaction.

        Raises
        ------
        ValueError
            If ``numbers`` or ``positions`` is not given.
        TypeError
            If the stored cache is not a :class:`DispersionD4SCCache`.

        Note
        ----
        If the :class:`.DispersionD4SC` interaction is evaluated within the
        :class:`dxtb.components.InteractionList`, ``positions`` will be passed
        as an argument, too. Hence, it is necessary to absorb the
        ``positions`` in the signature of the function (also see
        :meth:`dxtb.components.Interaction.get_cache`).
        """
        if numbers is None:
            raise ValueError(
                "Atomic numbers are required for DispersionD4SC cache."
            )
        if positions is None:
            raise ValueError(
                "Positions are required for DispersionD4SC cache."
            )

        cachvars = (numbers.detach().clone(), positions.detach().clone())

        if self.cache_is_latest(cachvars) is True:
            if not isinstance(self.cache, DispersionD4SCCache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}"
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        # if the cache is built, store the cachevar for validation
        self._cachevars = cachvars

        # Electronegativity-dependent pair weights for the coordination number
        en = PAULING(**self.dd)[numbers]
        endiff = torch.abs(en.unsqueeze(-2) - en.unsqueeze(-1))
        weight = d4_defaults.D4_K4 * torch.exp(
            -((endiff + d4_defaults.D4_K5) ** 2.0) / d4_defaults.D4_K6
        )

        cn = coordination_number(
            numbers,
            positions,
            counting_function=self.counting_function,
            rcov=self.rcov,
            cutoff=self.cutoff.cn,
            pair_weight=weight,
        )

        # tblite: disp/d4.f90::get_dispersion_matrix
        # Instead of multiplying with C6 (from `get_atomic_c6`), we multiply
        # with the reference C6 coefficients that have not been multiplied with
        # the Gaussian weights yet. Correspondingly, we have to set the C6
        # argument of `dispersion2` to 1.
        # NOTE(review): `self.damping_function` and `self.cutoff.disp2` are
        # not forwarded here, so `dispersion2` uses its own defaults — confirm
        # whether that is intended.
        edisp = d4.dispersion.dispersion2(
            numbers,
            positions,
            self.param,
            torch.ones((*numbers.shape, numbers.shape[-1]), **self.dd),
            self.r4r2,
            as_matrix=True,
        )

        # Append the two reference dimensions to match `model.rc6`.
        dispmat = edisp.unsqueeze(-1).unsqueeze(-1) * self.model.rc6

        self.cache = DispersionD4SCCache(cn, dispmat, self.model)
        return self.cache

    @override
    def get_monopole_atom_energy(
        self, cache: InteractionCache, qat: Tensor, **_: Any
    ) -> Tensor:
        """
        Calculate the D4 dispersion correction energy.

        Parameters
        ----------
        cache : DispersionD4SCCache
            Restart data for the interaction.
        qat : Tensor
            Atomic charges of all atoms.

        Returns
        -------
        Tensor
            Atomwise D4 dispersion correction energies.

        Raises
        ------
        TypeError
            If ``cache`` is not a :class:`DispersionD4SCCache`.
        """
        if not isinstance(cache, DispersionD4SCCache):
            raise TypeError(
                f"Cache in {self.label} is not of type 'DispersionD4SCCache'."
            )

        # `numbers` in model are updated in cache (for culling)
        weights = self.model.weight_references(cache.cn, qat)

        # Contract the dispersion matrix with the reference weights of both
        # atoms; the factor 0.5 compensates for visiting every pair twice
        # (assuming a symmetric dispersion matrix).
        return 0.5 * einsum(
            "...ijab,...ia,...jb->...j",
            *(cache.dispmat, weights, weights),
            optimize=[(0, 1), (0, 1)],
        )

    @override
    def get_monopole_atom_potential(
        self, cache: InteractionCache, qat: Tensor, *_: Any, **__: Any
    ) -> Tensor:
        """
        Calculate the D4 dispersion correction potential.

        Parameters
        ----------
        cache : DispersionD4SCCache
            Restart data for the interaction.
        qat : Tensor
            Atomic charges of all atoms.

        Returns
        -------
        Tensor
            Atomwise dispersion correction potential.

        Raises
        ------
        TypeError
            If ``cache`` is not a :class:`DispersionD4SCCache`.
        """
        if not isinstance(cache, DispersionD4SCCache):
            raise TypeError(
                f"Cache in {self.label} is not of type 'DispersionD4SCCache'."
            )

        # Reference weights and their charge derivatives (`dgwdq`).
        weights, dgwdq = self.model.weight_references(
            cache.cn, qat, with_dgwdq=True
        )

        return einsum(
            "...ijab,...jb,...ia->...i", cache.dispmat, weights, dgwdq
        )
def new_d4sc(
    numbers: Tensor,
    par: Param | ParamModule,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> DispersionD4SC | None:
    """
    Create new instance of :class:`.DispersionD4SC`.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
    par : Param | ParamModule
        Representation of an extended tight-binding model.
    device : torch.device | None, optional
        Device for the created tensors. If given, it must match the device of
        ``numbers``. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Floating point dtype. Defaults to ``None``, in which case
        :func:`tad_mctc.typing.get_default_dtype` is used.

    Returns
    -------
    DispersionD4SC | None
        Instance of the :class:`.DispersionD4SC` class or ``None`` if no
        :class:`.DispersionD4SC` is used.

    Raises
    ------
    DeviceError
        If ``device`` is given and differs from the device of ``numbers``.
    """
    dd: DD = {
        "device": device,
        "dtype": dtype if dtype is not None else get_default_dtype(),
    }

    # compatibility with previous version based on `Param`
    if not isinstance(par, ParamModule):
        par = ParamModule(par, **dd)

    # No self-consistent D4 without a D4 dispersion section, or if its `sc`
    # flag is reported as false by the parametrization.
    if "dispersion" not in par or par.is_none("dispersion"):
        return None
    if par.is_none("dispersion.d4"):
        return None
    if par.is_false("dispersion.d4.sc"):
        return None

    if device is not None and device != numbers.device:
        raise DeviceError(
            f"Passed device ({device}) and device of `numbers` tensor "
            f"({numbers.device}) do not match."
        )

    param = d4.Param(
        a1=par.get("dispersion.d4.a1"),
        a2=par.get("dispersion.d4.a2"),
        s6=par.get("dispersion.d4.s6"),
        s8=par.get("dispersion.d4.s8"),
        s9=par.get("dispersion.d4.s9"),
        s10=par.get("dispersion.d4.s10"),
    )

    rcov = d4.data.COV_D3(**dd)[numbers]
    r4r2 = d4.data.R4R2(**dd)[numbers]
    model = d4.model.D4Model(numbers, ref_charges="gfn2", **dd)

    # Real-space cutoffs for the two- and three-body terms (units presumably
    # Bohr — confirm against `d4.cutoff.Cutoff`).
    cutoff = d4.cutoff.Cutoff(disp2=50.0, disp3=25.0, **dd)

    return DispersionD4SC(
        param, model=model, rcov=rcov, r4r2=r4r2, cutoff=cutoff, **dd
    )