Source code for dxtb._src.components.interactions.list

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Container for interactions.
"""

from __future__ import annotations

import torch

from dxtb import IndexHelper
from dxtb._src.typing import (
    Any,
    Literal,
    Slicers,
    Tensor,
    TensorOrTensors,
    overload,
    override,
)

from ..list import ComponentList, ComponentListCache
from ..utils import _docstring_reset, _docstring_update
from .base import Interaction
from .container import Charges, Potential
from .coulomb.secondorder import ES2, LABEL_ES2
from .coulomb.thirdorder import ES3, LABEL_ES3
from .dispersion.d4sc import LABEL_DISPERSIOND4SC, DispersionD4SC
from .field.efield import LABEL_EFIELD, ElectricField
from .field.efieldgrad import LABEL_EFIELD_GRAD, ElectricFieldGrad

__all__ = ["InteractionList", "InteractionListCache"]


class InteractionListCache(ComponentListCache):
    """
    Restart data for individual interactions, extended by subclasses as needed.

    The container is iterated via ``self.values()``; every stored entry is
    expected to provide its own ``cull`` and ``restore`` methods.
    """

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        """
        Cull all interaction caches.

        Parameters
        ----------
        conv : Tensor
            Mask of converged systems.
        slicers : Slicers
            Slicing information, passed through unchanged to the ``cull``
            method of every stored cache.
        """
        for entry in self.values():
            entry.cull(conv, slicers)

    def restore(self) -> None:
        """
        Restore all interaction caches.

        Calls ``restore()`` on every stored cache.
        """
        for entry in self.values():
            entry.restore()
class InteractionList(ComponentList[Interaction]):
    """
    List of interactions.

    The methods of this class fan out to every interaction stored in
    ``self.components`` and combine the individual results. Restart data of
    an interaction is looked up in the cache under the interaction's
    ``label``.
    """

    @override
    def get_energy(
        self,
        charges: Charges | Tensor,
        cache: InteractionListCache,
        ihelp: IndexHelper,
    ) -> Tensor:
        """
        Compute the energy for a list of interactions.

        Parameters
        ----------
        charges : Charges | Tensor
            Collection of charges. Monopolar partial charges are
            orbital-resolved. A bare tensor is wrapped as
            ``Charges(mono=charges)``.
        cache : InteractionListCache
            Restart data for the interactions.
        ihelp : IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        Tensor
            Atom-resolved energy vector for orbital partial charges.
        """
        if isinstance(charges, Tensor):
            charges = Charges(mono=charges)

        # Without interactions, return zeros reduced to atom resolution.
        if len(self.components) == 0:
            return ihelp.reduce_orbital_to_atom(torch.zeros_like(charges.mono))

        # Stack the individual energies and sum over the interaction axis.
        return torch.stack(
            [
                interaction.get_energy(cache[interaction.label], charges, ihelp)
                for interaction in self.components
            ]
        ).sum(dim=0)

    def get_energy_as_dict(
        self,
        charges: Charges | Tensor,
        cache: InteractionListCache,
        ihelp: IndexHelper,
    ) -> dict[str, Tensor]:
        """
        Compute the energy of each interaction separately.

        Parameters
        ----------
        charges : Charges | Tensor
            Collection of charges. Monopolar partial charges are
            orbital-resolved. A bare tensor is wrapped as
            ``Charges(mono=charges)``, mirroring :meth:`get_energy`.
        cache : InteractionListCache
            Restart data for the interactions.
        ihelp : IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        dict[str, Tensor]
            Mapping from interaction label to the energy returned by that
            interaction. If the list is empty, the dictionary holds a single
            ``"none"`` entry with zeros.
        """
        if isinstance(charges, Tensor):
            charges = Charges(mono=charges)

        # Use the same reduction as the empty case of `get_energy`, so the
        # placeholder has the resolution of a regular interaction energy
        # instead of the orbital-resolved shape of the charges.
        if len(self.components) == 0:
            return {
                "none": ihelp.reduce_orbital_to_atom(
                    torch.zeros_like(charges.mono)
                )
            }

        return {
            interaction.label: interaction.get_energy(
                cache[interaction.label], charges, ihelp
            )
            for interaction in self.components
        }

    @override
    def get_gradient(
        self,
        charges: Charges,
        positions: Tensor,
        cache: InteractionListCache,
        ihelp: IndexHelper,
        grad_outputs: TensorOrTensors | None = None,
        retain_graph: bool | None = True,
        create_graph: bool | None = None,
    ) -> Tensor:
        """
        Calculate gradient for a list of interactions.

        Parameters
        ----------
        charges : Charges
            Collection of charges. Monopolar partial charges are
            orbital-resolved.
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        cache : InteractionListCache
            Restart data for the interactions.
        ihelp : IndexHelper
            Index mapping for the basis set.
        grad_outputs : TensorOrTensors | None, optional
            Forwarded unchanged to ``get_gradient`` of every interaction.
            Defaults to ``None``.
        retain_graph : bool | None, optional
            Forwarded unchanged to ``get_gradient`` of every interaction.
            Defaults to ``True``.
        create_graph : bool | None, optional
            Forwarded unchanged to ``get_gradient`` of every interaction.
            Defaults to ``None``.

        Returns
        -------
        Tensor
            Nuclear gradient of all interactions.
        """
        if len(self.components) == 0:
            return torch.zeros_like(positions)

        # Stack the individual gradients and sum over the interaction axis.
        return torch.stack(
            [
                interaction.get_gradient(
                    charges,
                    positions,
                    cache[interaction.label],
                    ihelp,
                    grad_outputs=grad_outputs,
                    retain_graph=retain_graph,
                    create_graph=create_graph,
                )
                for interaction in self.components
            ]
        ).sum(dim=0)

    @override
    def get_cache(
        self, numbers: Tensor, positions: Tensor, ihelp: IndexHelper
    ) -> InteractionListCache:
        """
        Create restart data for individual interactions.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        ihelp: IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        InteractionListCache
            Restart data for the interactions, stored under the label of the
            respective interaction.
        """
        cache = InteractionListCache()
        cache.update(
            **{
                interaction.label: interaction.get_cache(
                    numbers=numbers, positions=positions, ihelp=ihelp
                )
                for interaction in self.components
            }
        )
        return cache

    def get_potential(
        self, cache: InteractionListCache, charges: Charges, ihelp: IndexHelper
    ) -> Potential:
        """
        Compute the potential for a list of interactions.

        Parameters
        ----------
        cache : InteractionListCache
            Restart data for the interactions.
        charges : Charges
            Collection of charges. Monopolar partial charges are
            orbital-resolved.
        ihelp : IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        Potential
            Sum of the potentials of all interactions. With no interactions
            present, the monopolar part is all zeros (same shape as
            ``charges.mono``) and the dipole and quadrupole parts are ``None``.
        """
        # Start from an empty potential. It is returned unchanged when no
        # interactions are present, because the loop below is then a no-op.
        pot = Potential(
            torch.zeros_like(charges.mono),
            dipole=None,
            quad=None,
            batch_mode=ihelp.batch_mode,
        )

        # add up potentials from all interactions
        for interaction in self.components:
            pot += interaction.get_potential(
                cache[interaction.label], charges, ihelp
            )

        return pot

    ###########################################################################

    @overload
    def get_interaction(
        self, name: Literal["DispersionD4SC"]
    ) -> DispersionD4SC: ...

    @overload
    def get_interaction(
        self, name: Literal["ElectricField"]
    ) -> ElectricField: ...

    @overload
    def get_interaction(
        self, name: Literal["ElectricFieldGrad"]
    ) -> ElectricFieldGrad: ...

    @overload
    def get_interaction(self, name: Literal["ES2"]) -> ES2: ...

    @overload
    def get_interaction(self, name: Literal["ES3"]) -> ES3: ...

    @override  # generic implementation for typing
    def get_interaction(self, name: str) -> Interaction:
        """
        Look up an interaction by name.

        The lookup itself is delegated to the parent class; the overloads
        above only narrow the static return type for the listed names.

        Parameters
        ----------
        name : str
            Name of the interaction.

        Returns
        -------
        Interaction
            Result of the parent class lookup.
        """
        return super().get_interaction(name)

    ###########################################################################

    @_docstring_reset
    def reset_d4sc(self) -> Interaction:
        """Reset tensor attributes to a detached clone of the current state."""
        return self.reset(LABEL_DISPERSIOND4SC)

    @_docstring_reset
    def reset_efield(self) -> Interaction:
        """Reset tensor attributes to a detached clone of the current state."""
        return self.reset(LABEL_EFIELD)

    @_docstring_reset
    def reset_efield_grad(self) -> Interaction:
        """Reset tensor attributes to a detached clone of the current state."""
        return self.reset(LABEL_EFIELD_GRAD)

    @_docstring_reset
    def reset_es2(self) -> Interaction:
        """Reset tensor attributes to a detached clone of the current state."""
        return self.reset(LABEL_ES2)

    @_docstring_reset
    def reset_es3(self) -> Interaction:
        """Reset tensor attributes to a detached clone of the current state."""
        return self.reset(LABEL_ES3)

    ###########################################################################

    # The update wrappers carry no docstring of their own; they are decorated
    # with `_docstring_update` and forward to `self.update` with the label of
    # the respective interaction.

    @_docstring_update
    def update_d4sc(self, **kwargs: Any) -> Interaction:
        return self.update(LABEL_DISPERSIOND4SC, **kwargs)

    @_docstring_update
    def update_efield(
        self,
        *,
        field: Tensor | None = None,
    ) -> Interaction:
        # Only the `field` keyword is accepted and forwarded.
        return self.update(LABEL_EFIELD, field=field)

    @_docstring_update
    def update_efield_grad(
        self,
        *,
        field_grad: Tensor | None = None,
    ) -> Interaction:
        # Only the `field_grad` keyword is accepted and forwarded.
        return self.update(LABEL_EFIELD_GRAD, field_grad=field_grad)

    @_docstring_update
    def update_es2(self, **kwargs: Any) -> Interaction:
        return self.update(LABEL_ES2, **kwargs)

    @_docstring_update
    def update_es3(self, **kwargs: Any) -> Interaction:
        return self.update(LABEL_ES3, **kwargs)