Creating New Components
In the following, we explain how to create a new tight-binding component that
can be added to the Calculator. For the component to be evaluated correctly
within the Calculator, it must implement the corresponding methods of the base
Component class. We demonstrate this step by step for the electric field (which
is already implemented).
Step 1: Create class.
Because the electric field interacts with the charges, it contributes to the
charge-dependent Hamiltonian. For the implementation, this makes it a
“self-consistent” component, i.e., it must inherit from the Interaction class.
```python
class ElectricField(Interaction):
    """
    Instantaneous electric field.
    """
```
Step 2: Add constructor and parameters.
Typical tight-binding components store a parametrization for each atom in their
attributes. The electric field, by contrast, is fully described by a single
field vector. Therefore, the constructor of the electric field only takes the
field vector as an argument and assigns it to the field attribute.
```python
class ElectricField(Interaction):
    """
    Instantaneous electric field.
    """

    field: Tensor
    """Instantaneous electric field vector."""

    __slots__ = ["field"]

    def __init__(
        self,
        field: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        # Fall back to the device/dtype of the field if none are given.
        super().__init__(
            device=field.device if device is None else device,
            dtype=field.dtype if dtype is None else dtype,
        )
        self.field = field
```
The constructor should always pass the dtype and device arguments to the
constructor of the base class to ensure consistency. Both arguments bubble up
through the Interaction and Component classes and end up in the TensorLike
class. TensorLike makes it possible to change the device and dtype of all
tensors in the class with PyTorch’s well-known to() and type() methods. It also
provides the self.dtype, self.device, and self.dd properties.
Do not forget to add the __slots__ attribute to the class. Otherwise, the
to() and type() methods will not work.
All __slots__ should also be arguments of the constructor.
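The following pure-Python sketch shows why the two rules above matter. It mimics a TensorLike-style to() method. It is not the library's implementation. FakeTensor, MiniTensorLike, and MiniElectricField are hypothetical stand-ins. We also assume that to() rebuilds the object through its constructor from the class's __slots__. We chose this assumption only because it explains why every slot must also be a constructor argument.

```python
from __future__ import annotations


class FakeTensor:
    """Hypothetical stand-in for a tensor: stores values and a device name."""

    def __init__(self, values: list[float], device: str = "cpu") -> None:
        self.values = values
        self.device = device

    def to(self, device: str) -> FakeTensor:
        # Return a new "tensor" on the requested device.
        return FakeTensor(list(self.values), device)


class MiniTensorLike:
    """Hypothetical, heavily simplified TensorLike-style base class."""

    __slots__ = ["device"]

    def __init__(self, device: str | None = None) -> None:
        self.device = "cpu" if device is None else device

    def to(self, device: str) -> MiniTensorLike:
        # Assumption: the object is rebuilt from the __slots__ of the concrete
        # class; private "__" slots (e.g. a culling store) are skipped.
        kwargs = {}
        for name in type(self).__slots__:
            if name.startswith("__"):
                continue
            value = getattr(self, name)
            if isinstance(value, FakeTensor):
                value = value.to(device)
            kwargs[name] = value
        # The rebuild goes through __init__, so every slot must be an argument.
        return type(self)(**kwargs, device=device)


class MiniElectricField(MiniTensorLike):
    """Hypothetical miniature of the ElectricField class above."""

    __slots__ = ["field"]

    def __init__(self, field: FakeTensor, device: str | None = None) -> None:
        # Fall back to the field's device if none is given.
        super().__init__(device=field.device if device is None else device)
        self.field = field


ef = MiniElectricField(FakeTensor([0.0, 0.0, 0.5]))
ef_gpu = ef.to("cuda")
print(ef.device, ef.field.device)  # cpu cpu
print(ef_gpu.device, ef_gpu.field.device)  # cuda cuda
```

If the subclass omitted its own __slots__, `type(self).__slots__` would fall back to the parent's list. The field would then never be moved.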
Step 3: Create cache.
The cache should inherit from InteractionCache, the cache class of the
Interaction base class (which makes it TensorLike as well). Accordingly, its
constructor mirrors that of the electric field itself, and the attributes and
__slots__ are initialized in the same way. For the electric field, the cache
contains the atom-resolved monopolar and dipolar potentials.
```python
class ElectricFieldCache(InteractionCache):
    """
    Restart data for the electric field interaction.
    """

    vat: Tensor
    """
    Atom-resolved monopolar potential from instantaneous electric field.
    """

    vdp: Tensor
    """
    Atom-resolved dipolar potential from instantaneous electric field.
    """

    __slots__ = ["vat", "vdp"]

    def __init__(
        self,
        vat: Tensor,
        vdp: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        # Fall back to the device/dtype of the potentials if none are given.
        super().__init__(
            device=vat.device if device is None else device,
            dtype=vat.dtype if dtype is None else dtype,
        )
        self.vat = vat
        self.vdp = vdp
```
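The two cached quantities can be written as equations. Take atom A at position r_A in a homogeneous field F. The notation below is ours and follows the einsum and expand_as calls in get_cache() (Step 5) and the energy expressions in Step 6. Here q_A is the atomic charge (qat) and mu_A the atomic dipole moment (qdp).

```latex
v^{\mathrm{at}}_A = \mathbf{r}_A \cdot \mathbf{F} = \sum_{k=1}^{3} r_{A,k}\, F_k ,
\qquad
\mathbf{v}^{\mathrm{dp}}_A = \mathbf{F} ,
\qquad
E_A = -\, v^{\mathrm{at}}_A\, q_A \;-\; \mathbf{v}^{\mathrm{dp}}_A \cdot \boldsymbol{\mu}_A .
```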
Step 4: Modify cache for culling in batched SCF.
This step is less straightforward. In a batched SCF, a system is removed from
the batch dimension once it has converged (“culling”), and the Cache must be
updated accordingly. At the same time, all cache variables must be stored so
that they can be restored after the SCF for the final energy evaluation.
Correspondingly, we add a simple Store class and a corresponding attribute
(__store) to the Cache. The __store attribute is initialized to None and is
only filled on the first call of the cull() method. The cull() method takes a
boolean mask of the converged systems that are removed from the batch (the conv
tensor). It also takes a collection of slicers (Slicers class), which are used
to shrink the tensors if the largest system was culled from the batch. For the
atom-resolved monopolar and dipolar potentials, the atom-resolved slicer is
selected. The attributes are sliced, while the original, unsliced tensors
remain in the Store. The restore() method restores the cache by copying the
Store attributes back to it.
```python
class ElectricFieldCache(InteractionCache, TensorLike):
    """
    Restart data for the electric field interaction.
    """

    __store: Store | None
    """Storage for cache (required for culling)."""

    vat: Tensor
    """
    Atom-resolved monopolar potential from instantaneous electric field.
    """

    vdp: Tensor
    """
    Atom-resolved dipolar potential from instantaneous electric field.
    """

    __slots__ = ["__store", "vat", "vdp"]

    def __init__(
        self,
        vat: Tensor,
        vdp: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            device=vat.device if device is None else device,
            dtype=vat.dtype if dtype is None else dtype,
        )
        self.vat = vat
        self.vdp = vdp
        self.__store = None

    class Store:
        """
        Storage container for cache containing ``__slots__`` before culling.
        """

        vat: Tensor
        """
        Atom-resolved monopolar potential from instantaneous electric field.
        """

        vdp: Tensor
        """
        Atom-resolved dipolar potential from instantaneous electric field.
        """

        def __init__(self, vat: Tensor, vdp: Tensor) -> None:
            self.vat = vat
            self.vdp = vdp

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        # Keep the full (unculled) tensors only on the first call.
        if self.__store is None:
            self.__store = self.Store(self.vat, self.vdp)

        # Drop converged systems (batch dimension) and trim the atom dimension.
        slicer = slicers["atom"]
        self.vat = self.vat[tuple([~conv, *slicer])]
        self.vdp = self.vdp[tuple([~conv, *slicer, ...])]

    def restore(self) -> None:
        if self.__store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")

        self.vat = self.__store.vat
        self.vdp = self.__store.vdp
```
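The bookkeeping in cull() and restore() can be illustrated without PyTorch. The sketch below is hypothetical. Nested lists stand in for tensors, a list of booleans stands in for the conv mask, and a single slice object stands in for slicers["atom"]. It also uses a plain `_store` attribute instead of the name-mangled `__store`. The sketch only shows the idea: mask out converged systems, trim padded atoms, and restore the full-size data afterwards.

```python
from __future__ import annotations


class MiniCache:
    """Hypothetical cache mimicking the cull()/restore() pattern with lists."""

    def __init__(self, vat: list[list[float]]) -> None:
        self.vat = vat  # shape (nbatch, natoms), padded with 0.0
        self._store: list[list[float]] | None = None

    def cull(self, conv: list[bool], atom_slice: slice) -> None:
        # Keep a reference to the full data only on the first call.
        if self._store is None:
            self._store = self.vat
        # Drop converged systems (conv[i] is True) and trim the atom axis.
        self.vat = [
            row[atom_slice] for row, done in zip(self.vat, conv) if not done
        ]

    def restore(self) -> None:
        if self._store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")
        self.vat = self._store


# Three systems; the largest one (three atoms) converges first.
cache = MiniCache([[0.1, 0.2, 0.3], [0.4, 0.5, 0.0], [0.6, 0.7, 0.0]])
cache.cull([True, False, False], slice(0, 2))
print(cache.vat)  # [[0.4, 0.5], [0.6, 0.7]]
cache.restore()
print(cache.vat)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.0], [0.6, 0.7, 0.0]]
```

Culling builds new lists, so the stored original data is never modified. This is analogous to how boolean-mask indexing in PyTorch returns new tensors.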
Step 5: Populate the cache (get_cache).
The cacheable quantities are computed within the get_cache() method, which
instantiates and returns the Cache. If the interaction is evaluated within the
InteractionList, numbers and an IndexHelper are passed as arguments, too. This
accommodates the different requirements of the individual caches while
retaining a (somewhat) consistent API. The electric field cache only needs the
position tensor. Correspondingly, numbers and ihelp are not stored in the
cachvars tuple. This tuple holds copies of the positions and the field and is
used to check whether the cache is up to date.
```python
@override
def get_cache(
    self,
    *,
    numbers: Tensor | None = None,
    positions: Tensor | None = None,
    ihelp: IndexHelper | None = None,
) -> ElectricFieldCache:
    """
    Create restart data for individual interactions.

    Returns
    -------
    ElectricFieldCache
        Restart data for the interaction.

    Note
    ----
    If this interaction is evaluated within the `InteractionList`, `numbers`
    and `ihelp` (an `IndexHelper`) will be passed as keyword-only arguments,
    too. They are accepted for API consistency (see
    `Interaction.get_cache()`) but not used by the electric field.
    """
    if positions is None:
        raise ValueError("Electric field requires atomic positions.")

    # Snapshot of all inputs the cache depends on.
    cachvars = (positions.detach().clone(), self.field.detach().clone())

    if self.cache_is_latest(cachvars):
        if not isinstance(self.cache, ElectricFieldCache):
            raise TypeError(
                f"Cache in {self.label} is not of type '{self.label}Cache'. "
                "This can only happen if you manually manipulate the cache."
            )
        return self.cache

    self._cachevars = cachvars

    # (nbatch, natoms, 3) * (3) -> (nbatch, natoms)
    vat = torch.einsum("...ik,k->...i", positions, self.field)

    # (3) -> (nbatch, natoms, 3)
    vdp = self.field.expand_as(positions)

    self.cache = ElectricFieldCache(vat, vdp)
    return self.cache
```
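The core of get_cache() can be mirrored in plain Python: compute the potentials, then skip recomputation while the inputs are unchanged. This is a hypothetical sketch. field_potentials and MiniFieldCacheOwner are illustrative names, and a simple equality check on copied inputs stands in for cache_is_latest(), whose exact behavior is not shown here.

```python
import copy


def field_potentials(positions, field):
    """Return (vat, vdp) for nested lists of shape (nbatch, natoms, 3)."""
    # vat: dot product r_A . F per atom, like einsum("...ik,k->...i")
    vat = [
        [sum(r_k * f_k for r_k, f_k in zip(atom, field)) for atom in mol]
        for mol in positions
    ]
    # vdp: the field vector repeated for every atom, like expand_as()
    vdp = [[list(field) for _ in mol] for mol in positions]
    return vat, vdp


class MiniFieldCacheOwner:
    """Hypothetical component that recomputes its cache only on new inputs."""

    def __init__(self, field):
        self.field = list(field)
        self.cache = None
        self._cachevars = None
        self.n_computed = 0  # number of actual (re)computations

    def get_cache(self, positions):
        # Snapshot of all inputs the cache depends on (cf. detach().clone()).
        cachevars = (copy.deepcopy(positions), list(self.field))
        # Assumption: a plain equality check stands in for cache_is_latest().
        if self.cache is not None and cachevars == self._cachevars:
            return self.cache
        self._cachevars = cachevars
        self.cache = field_potentials(positions, self.field)
        self.n_computed += 1
        return self.cache


positions = [[[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]]  # one system, two atoms
owner = MiniFieldCacheOwner([0.0, 0.0, 0.5])
vat, vdp = owner.get_cache(positions)
print(vat)  # [[0.0, 1.0]]
owner.get_cache(positions)  # unchanged inputs: the cached result is reused
print(owner.n_computed)  # 1
```

The field is part of the snapshot, as in the tutorial's cachvars. Changing the field vector therefore invalidates the cache even if the positions stay the same.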
Step 6: Implement the energy evaluation.
The energy from the electric field has a monopolar and a dipolar contribution.
Hence, both get_monopole_atom_energy() and get_dipole_atom_energy() must be
implemented. They override the corresponding methods of the base class, which
return zero. In general, every contribution that is not implemented in the
derived class evaluates to zero.
```python
@override
def get_monopole_atom_energy(
    self, cache: ElectricFieldCache, qat: Tensor, **_: Any
) -> Tensor:
    """
    Calculate the monopolar contribution of the electric field energy.

    Parameters
    ----------
    cache : ElectricFieldCache
        Restart data for the interaction.
    qat : Tensor
        Atomic charges of all atoms.

    Returns
    -------
    Tensor
        Atom-wise electric field interaction energies.
    """
    return -cache.vat * qat


@override
def get_dipole_atom_energy(
    self,
    cache: ElectricFieldCache,
    qat: Tensor,
    qdp: Tensor | None = None,
    qqp: Tensor | None = None,
) -> Tensor:
    """
    Calculate the dipolar contribution of the electric field energy.

    Parameters
    ----------
    cache : ElectricFieldCache
        Restart data for the interaction.
    qat : Tensor
        Atom-resolved partial charges (shape: ``(..., nat)``).
    qdp : Tensor | None
        Atom-resolved dipole moments (shape: ``(..., nat, 3)``). Required.
    qqp : Tensor | None
        Atom-resolved quadrupole moments (shape: ``(..., nat, 6)``). Unused.

    Returns
    -------
    Tensor
        Atom-wise electric field interaction energies.
    """
    if qdp is None:
        raise ValueError("The dipolar energy requires atomic dipole moments.")

    # equivalent: torch.sum(-cache.vdp * qdp, dim=-1)
    return torch.einsum("...ix,...ix->...i", -cache.vdp, qdp)
```
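The "defaults to zero" rule can be pictured with a toy class hierarchy. The sketch below is hypothetical. MiniInteraction, its get_quadrupole_atom_energy() and get_atom_energy() methods, and the (vat, vdp) tuple used as the cache are our own stand-ins, not the library's classes. A base class returns zero for every multipole contribution, and the subclass overrides only the monopolar and dipolar parts with the formulas from this step.

```python
class MiniInteraction:
    """Hypothetical base class: every multipole contribution defaults to zero."""

    def get_monopole_atom_energy(self, cache, qat, **_kwargs):
        return [0.0 for _ in qat]

    def get_dipole_atom_energy(self, cache, qat, qdp=None, qqp=None):
        return [0.0 for _ in qat]

    def get_quadrupole_atom_energy(self, cache, qat, qdp=None, qqp=None):
        return [0.0 for _ in qat]

    def get_atom_energy(self, cache, qat, qdp=None, qqp=None):
        # Sum all contributions atom by atom; missing overrides add zero.
        parts = (
            self.get_monopole_atom_energy(cache, qat),
            self.get_dipole_atom_energy(cache, qat, qdp, qqp),
            self.get_quadrupole_atom_energy(cache, qat, qdp, qqp),
        )
        return [sum(values) for values in zip(*parts)]


class MiniElectricField(MiniInteraction):
    """Overrides only the monopolar and dipolar contributions."""

    def get_monopole_atom_energy(self, cache, qat, **_kwargs):
        vat, _ = cache
        return [-v * q for v, q in zip(vat, qat)]

    def get_dipole_atom_energy(self, cache, qat, qdp=None, qqp=None):
        if qdp is None:
            raise ValueError("The dipolar energy requires atomic dipole moments.")
        _, vdp = cache
        return [-sum(f * m for f, m in zip(fa, mu)) for fa, mu in zip(vdp, qdp)]


# cache = (vat, vdp) for one system with two atoms in a field F = (0, 0, 0.5)
cache = ([0.0, 1.0], [[0.0, 0.0, 0.5], [0.0, 0.0, 0.5]])
qat = [0.5, -0.5]
qdp = [[0.0, 0.0, 0.25], [0.0, 0.0, -0.5]]
print(MiniElectricField().get_atom_energy(cache, qat, qdp))  # [-0.125, 0.75]
print(MiniInteraction().get_atom_energy(cache, qat, qdp))  # [0.0, 0.0]
```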
Step 7: Implement the potential evaluation.
Similar to the energy evaluation, the potential evaluation is split into a monopolar and a dipolar contribution (to the charge-dependent Hamiltonian). The electric field potential does not depend on the charges. They are still passed for API consistency and absorbed by the *_ and **__ arguments.
```python
@override
def get_monopole_atom_potential(
    self, cache: ElectricFieldCache, *_: Any, **__: Any
) -> Tensor:
    """
    Calculate the electric field potential.

    Parameters
    ----------
    cache : ElectricFieldCache
        Restart data for the interaction.

    Returns
    -------
    Tensor
        Atom-wise electric field potential.
    """
    return -cache.vat


@override
def get_dipole_atom_potential(
    self, cache: ElectricFieldCache, *_: Any, **__: Any
) -> Tensor:
    """
    Calculate the electric field dipole potential.

    Parameters
    ----------
    cache : ElectricFieldCache
        Restart data for the interaction.

    Returns
    -------
    Tensor
        Atom-wise electric field dipole potential.
    """
    return -cache.vdp
```
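The energy is linear in the charges and dipole moments, so the potentials returned here are exactly its derivatives: dE/dq_A = -vat_A and dE/dmu_A = -vdp_A. A pure-Python finite-difference check makes this explicit. The check is our own illustration. total_energy, charge_gradient, and dipole_gradient are hypothetical helpers for a single system without a batch dimension.

```python
import math


def total_energy(vat, vdp, qat, qdp):
    """E = -sum_A vat_A q_A - sum_A vdp_A . mu_A (single system)."""
    mono = sum(-v * q for v, q in zip(vat, qat))
    dip = sum(-sum(f * m for f, m in zip(fa, mu)) for fa, mu in zip(vdp, qdp))
    return mono + dip


def charge_gradient(vat, vdp, qat, qdp, h=1e-6):
    """Central finite differences of E with respect to each atomic charge."""
    grad = []
    for a in range(len(qat)):
        up, down = list(qat), list(qat)
        up[a] += h
        down[a] -= h
        diff = total_energy(vat, vdp, up, qdp) - total_energy(vat, vdp, down, qdp)
        grad.append(diff / (2 * h))
    return grad


def dipole_gradient(vat, vdp, qat, qdp, h=1e-6):
    """Central finite differences of E with respect to each dipole component."""
    grad = []
    for a in range(len(qdp)):
        row = []
        for k in range(3):
            up = [list(mu) for mu in qdp]
            down = [list(mu) for mu in qdp]
            up[a][k] += h
            down[a][k] -= h
            diff = total_energy(vat, vdp, qat, up) - total_energy(vat, vdp, qat, down)
            row.append(diff / (2 * h))
        grad.append(row)
    return grad


vat = [0.0, 1.0]
vdp = [[0.0, 0.0, 0.5], [0.0, 0.0, 0.5]]
qat = [0.5, -0.5]
qdp = [[0.0, 0.0, 0.25], [0.0, 0.0, -0.5]]

# Both gradients should match the potentials -vat and -vdp.
g_q = charge_gradient(vat, vdp, qat, qdp)
g_mu = dipole_gradient(vat, vdp, qat, qdp)
print(all(math.isclose(g, -v, abs_tol=1e-6) for g, v in zip(g_q, vat)))  # True
```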
Step 8: String representation (optional).
As good practice, the __str__() and __repr__() methods should be
implemented to provide a human-readable representation of the component.
```python
def __str__(self) -> str:
    return f"{self.__class__.__name__}(field={self.field})"

def __repr__(self) -> str:
    return str(self)
```
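Delegating __repr__() to __str__() also matters inside containers, because lists and similar containers format their elements with repr(). Below, a hypothetical stand-in class with a plain list instead of a tensor shows both code paths.

```python
class MiniElectricField:
    """Hypothetical stand-in using the same __str__/__repr__ pattern."""

    def __init__(self, field):
        self.field = field

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(field={self.field})"

    def __repr__(self) -> str:
        return str(self)


ef = MiniElectricField([0.0, 0.0, 0.5])
print(ef)  # MiniElectricField(field=[0.0, 0.0, 0.5])
print([ef])  # [MiniElectricField(field=[0.0, 0.0, 0.5])]
```

Because the name comes from `self.__class__.__name__`, subclasses automatically report their own class name without overriding these methods.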
Step 9: Add to the Calculator.
To use the electric field in a calculation, it must be added to the Calculator. This is done by passing an instance of the electric field to the constructor of the Calculator.
```python
import torch

from dxtb import GFN1_XTB, Calculator
from dxtb.typing import DD

dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

# electric field vector (here: zero field)
field = torch.tensor([0.0, 0.0, 0.0], **dd)
ef = ElectricField(field=field, **dd)

# atomic numbers must be integers, not floating point (here: LiH)
numbers = torch.tensor([3, 1], device=dd["device"])

calc = Calculator(
    numbers,
    GFN1_XTB,
    interactions=[ef],
    **dd,
)
```