Source code for dxtb._src.components.base

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Components: Base Class
======================

Base class for all tight-binding components.
"""

from __future__ import annotations

import torch

from dxtb.__version__ import __tversion__
from dxtb._src.typing import Any, Tensor, TensorLike
from dxtb._src.utils.misc import get_all_slots

__all__ = ["Component", "ComponentCache"]


class ComponentCache(TensorLike):
    """
    Cache of a component.

    Slots whose names start with an underscore are treated as private and are
    neither counted by ``len()`` nor listed in the string representation.
    """

    def __len__(self) -> int:
        # Count the public slots only.
        return sum(
            1 for name in get_all_slots(self) if not name.startswith("_")
        )

    def __str__(self) -> str:  # pragma: no cover
        public_names = [
            name for name in get_all_slots(self) if not name.startswith("_")
        ]
        listing = ", ".join(public_names)
        return f"{self.__class__.__name__}({listing})"

    def __repr__(self) -> str:  # pragma: no cover
        # The representation intentionally mirrors ``__str__``.
        return str(self)
class Component(TensorLike):
    """
    Base class for all tight-binding terms.
    """

    label: str
    """Label for the tight-binding component."""

    _cache: ComponentCache | None
    """Cache for the component."""

    _cachevars: tuple[Tensor, ...] | None
    """
    Cache variable for the component. If this variable changes, the cache has
    to be rebuilt.
    """

    _cache_enabled: bool
    """Flag to enable or disable the cache."""

    __slots__ = ["label", "_cache", "_cachevars", "_cache_enabled"]

    def __init__(
        self,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        *,
        _cache: ComponentCache | None = None,
        _cachevars: tuple[Tensor, ...] | None = None,
    ):
        super().__init__(device, dtype)
        self.label = self.__class__.__name__
        self._cache = _cache
        self._cachevars = _cachevars
        self._cache_enabled = True

    ############################################################################

    @property
    def cache(self) -> ComponentCache | None:
        """Cache for the component."""
        return self._cache

    @cache.setter
    def cache(self, value: ComponentCache | None) -> None:
        self._cache = value

    ############################################################################

    def update(self, **kwargs: Any) -> None:
        """
        Update the attributes of the :class:`.Component` instance.

        This method updates the attributes of the :class:`.Component` instance
        based on the provided keyword arguments. Only the attributes defined
        in ``__slots__`` can be updated. After a successful update, the cache
        is invalidated.

        Parameters
        ----------
        kwargs : dict[str, Any]
            Keyword arguments where keys are attribute names and values are the
            new values for those attributes. Valid keys are those defined in
            ``__slots__`` of this class.

        Raises
        ------
        AttributeError
            If any key in kwargs is not an attribute defined in ``__slots__``.
            In this case, no attribute is modified.

        Examples
        --------
        .. code-block:: python

            import torch
            from dxtb.components.field import ElectricField

            ef = ElectricField(field=torch.tensor([0.0, 0.0, 0.0]))
            ef.update(field=torch.tensor([1.0, 0.0, 0.0]))
        """
        # Validate every key before touching any attribute. Otherwise an
        # invalid key late in ``kwargs`` would leave the instance partially
        # updated while the cache is never invalidated (stale cache).
        for key in kwargs:
            if key not in self.__slots__:
                raise AttributeError(
                    f"Cannot update '{key}' of the '{self.__class__.__name__}' "
                    "interaction. Invalid attribute.\nPossible attributes are: "
                    f"{', '.join(self.__slots__)}."
                )

        for key, value in kwargs.items():
            setattr(self, key, value)

        self.cache_invalidate()

    def reset(self) -> None:
        """
        Detach the tensor attributes of the
        :class:`dxtb.components.base.Component` instance from any autograd
        graph.

        This method iterates through the attributes defined in ``__slots__``
        and replaces every tensor attribute with a detached clone of its
        current value. The ``requires_grad`` status of each tensor is
        preserved. Afterwards, the cache is invalidated.

        Examples
        --------
        .. code-block:: python

            import torch
            from dxtb.components.field import ElectricField

            ef = ElectricField(field=torch.tensor([0.0, 0.0, 0.0]))
            ef.reset()

        Notes
        -----
        Only tensor attributes defined in ``__slots__`` are reset. Non-tensor
        attributes are left unchanged.
        """
        for slot in self.__slots__:
            attr = getattr(self, slot)
            if isinstance(attr, Tensor):
                reset = attr.detach().clone()
                # ``clone`` of a detached tensor is a leaf, so the flag can be
                # restored directly.
                reset.requires_grad = attr.requires_grad
                setattr(self, slot, reset)

        self.cache_invalidate()

    ############################################################################

    def cache_is_latest(
        self, cvars: tuple[Tensor, ...], tol: float | None = None
    ) -> bool:
        """
        Check if the cache is enabled, set up, and up to date.

        Parameters
        ----------
        cvars : tuple[Tensor, ...]
            Current values of the cache variables. They are compared with the
            variables stored when the cache was built (``_cachevars``).
        tol : float | None, optional
            Tolerance for floating point (and complex) tensors, applied to the
            sum of absolute differences. If ``None``, the tolerance defaults to
            ``eps ** 0.75`` of the respective tensor's dtype. Integer and
            boolean tensors are always compared exactly.

        Returns
        -------
        bool
            ``True`` if the cache can be reused, ``False`` otherwise.
        """
        if self._cache_enabled is False:
            return False

        if self.cache is None:
            return False

        if self._cachevars is None:
            return False

        cvars = tuple(cvars)

        # ``zip`` silently truncates; a different number of variables can
        # never describe the same state.
        if len(cvars) != len(self._cachevars):
            return False

        for v1, v2 in zip(cvars, self._cachevars):
            # functorch makes problems here, just disable cache for now
            if __tversion__ >= (1, 13, 0):
                if torch._C._functorch.is_gradtrackingtensor(v1):
                    return False
                if torch._C._functorch.is_gradtrackingtensor(v2):
                    return False

            if v1.dtype != v2.dtype:
                return False

            if v1.device != v2.device:
                return False

            # Without this check, the subtraction below would broadcast
            # (false positive) or raise for incompatible shapes.
            if v1.shape != v2.shape:
                return False

            if v1.is_floating_point() or v1.is_complex():
                # Default tolerance is derived per tensor, so mixed-precision
                # cache variables each use their own machine epsilon.
                if tol is None:
                    atol = torch.finfo(v1.dtype).eps ** 0.75
                else:
                    atol = tol
                if (v2 - v1).abs().sum() > atol:
                    return False
            else:
                # Integer and boolean dtypes: exact comparison (``torch.finfo``
                # is not defined for these dtypes).
                if not torch.equal(v1, v2):
                    return False

        return True

    def cache_invalidate(self) -> None:
        """Invalidate the cache to require renewed setup."""
        self._cache = None
        self._cachevars = None

    @property
    def cache_is_setup(self) -> bool:
        """Whether the cache has been set up."""
        return self._cache is not None and self._cachevars is not None

    def cache_enable(self) -> None:
        """Enable the cache."""
        self._cache_enabled = True

    def cache_disable(self) -> None:
        """Disable the cache."""
        self._cache_enabled = False

    ############################################################################

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.label})"

    def __repr__(self) -> str:
        return str(self)