# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
External Fields: Electric Field
===============================
Interaction of the charge density with an external electric field.
"""
from __future__ import annotations
import torch
from tad_mctc.math import einsum
from dxtb import IndexHelper
from dxtb._src.typing import Any, Slicers, Tensor, override
from dxtb._src.typing.exceptions import DeviceError, DtypeError
from ..base import Interaction, InteractionCache
# Public API of this module.
__all__ = [
    "ElectricField",
    "LABEL_EFIELD",
    "new_efield",
]

LABEL_EFIELD = "ElectricField"
"""Label for the 'ElectricField' interaction, coinciding with the class name."""
class ElectricFieldCache(InteractionCache):
    """
    Restart data for the electric field interaction.

    Holds the field-dependent atomic potentials and supports removing
    (:meth:`cull`) and bringing back (:meth:`restore`) entries along the
    leading dimension.
    """

    __store: Store | None
    """Storage for cache (required for culling)."""

    vat: Tensor
    """
    Atom-resolved monopolar potential from instantaneous electric field.
    """

    vdp: Tensor
    """
    Atom-resolved dipolar potential from instantaneous electric field.
    """

    __slots__ = ["__store", "vat", "vdp"]

    def __init__(
        self,
        vat: Tensor,
        vdp: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        vat : Tensor
            Atom-resolved monopolar potential.
        vdp : Tensor
            Atom-resolved dipolar potential.
        device : torch.device | None, optional
            Device of the cache. If ``None`` (default), the device of `vat`
            is used.
        dtype : torch.dtype | None, optional
            Data type of the cache. If ``None`` (default), the dtype of `vat`
            is used.
        """
        # Fall back to the properties of `vat` only when nothing was given.
        # (The previous condition was inverted and forwarded ``None``.)
        super().__init__(
            device=vat.device if device is None else device,
            dtype=vat.dtype if dtype is None else dtype,
        )
        self.vat = vat
        self.vdp = vdp
        self.__store = None

    class Store:
        """
        Storage container for cache containing ``__slots__`` before culling.
        """

        vat: Tensor
        """
        Atom-resolved monopolar potential from instantaneous electric field.
        """

        vdp: Tensor
        """
        Atom-resolved dipolar potential from instantaneous electric field.
        """

        def __init__(self, vat: Tensor, vdp: Tensor) -> None:
            self.vat = vat
            self.vdp = vdp

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        """
        Drop the entries selected by `conv` from the cached potentials.

        On the first call the full potentials are saved so that
        :meth:`restore` can bring them back. Later calls keep that first
        snapshot.

        Parameters
        ----------
        conv : Tensor
            Mask over the leading dimension. Entries where it is ``True`` are
            removed. It is assumed to be boolean and presumably flags
            converged systems.
        slicers : Slicers
            Slicers for the remaining dimensions. Only ``slicers["atom"]`` is
            used.
        """
        if self.__store is None:
            self.__store = self.Store(self.vat, self.vdp)

        slicer = slicers["atom"]
        self.vat = self.vat[tuple([~conv, *slicer])]
        # Trailing Ellipsis keeps the cartesian axis of the dipolar potential.
        self.vdp = self.vdp[tuple([~conv, *slicer, ...])]

    def restore(self) -> None:
        """
        Restore the potentials saved by the first call to :meth:`cull`.

        Raises
        ------
        RuntimeError
            If :meth:`cull` has never been called.
        """
        if self.__store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")

        self.vat = self.__store.vat
        self.vdp = self.__store.vdp
class ElectricField(Interaction):
    """
    Instantaneous electric field.

    The field couples to the atomic charges through the monopolar potential
    ``R . F`` and to the atomic dipole moments through the dipolar potential
    ``F`` (see :meth:`get_cache`).
    """

    field: Tensor
    """Instantaneous electric field vector."""

    __slots__ = ["field"]

    def __init__(
        self,
        field: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """
        Parameters
        ----------
        field : Tensor
            Electric field vector. It is expected to have shape ``(3,)``,
            but this is not validated here; ``new_efield`` does the check.
        device : torch.device | None, optional
            Device of the interaction. If ``None`` (default), the device of
            `field` is used.
        dtype : torch.dtype | None, optional
            Data type of the interaction. If ``None`` (default), the dtype of
            `field` is used.
        """
        # Infer from `field` only when nothing was given.
        # (The previous condition was inverted and forwarded ``None``.)
        super().__init__(
            device=field.device if device is None else device,
            dtype=field.dtype if dtype is None else dtype,
        )
        self.field = field

    # pylint: disable=unused-argument
    @override
    def get_cache(
        self,
        *,
        numbers: Tensor | None = None,
        positions: Tensor | None = None,
        ihelp: IndexHelper | None = None,
    ) -> ElectricFieldCache:
        """
        Create restart data for individual interactions.

        If :meth:`cache_is_latest` reports that the positions and the field
        are unchanged, the existing cache is returned. Otherwise the
        potentials are recomputed.

        Parameters
        ----------
        numbers : Tensor | None, optional
            Atomic numbers. Unused; accepted for interface compatibility.
        positions : Tensor | None, optional
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
            Required.
        ihelp : IndexHelper | None, optional
            Index helper. Unused; accepted for interface compatibility.

        Returns
        -------
        ElectricFieldCache
            Restart data for the interaction.

        Raises
        ------
        ValueError
            If `positions` is not given.
        TypeError
            If the up-to-date cache is not an :class:`ElectricFieldCache`.
            This can only happen after manual manipulation of the cache.

        Note
        ----
        If this interaction is evaluated within the `InteractionList`,
        `numbers` and `ihelp` are passed as keyword-only arguments, too,
        even though this interaction does not need them.
        """
        if positions is None:
            raise ValueError("Electric field requires atomic positions.")

        cachevars = (positions.detach().clone(), self.field.detach().clone())

        if self.cache_is_latest(cachevars) is True:
            if not isinstance(self.cache, ElectricFieldCache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}"
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        self._cachevars = cachevars

        # (nbatch, natoms, 3) * (3) -> (nbatch, natoms)
        vat = einsum("...ik,k->...i", positions, self.field)

        # (nbatch, natoms, 3); broadcast view of the field, no copy
        vdp = self.field.expand_as(positions)

        # Pass device/dtype explicitly so the cache matches its tensors.
        self.cache = ElectricFieldCache(
            vat, vdp, device=vat.device, dtype=vat.dtype
        )
        return self.cache

    @override
    def get_monopole_atom_energy(
        self, cache: ElectricFieldCache, qat: Tensor, **_: Any
    ) -> Tensor:
        """
        Calculate the monopolar contribution of the electric field energy.

        Parameters
        ----------
        cache : ElectricFieldCache
            Restart data for the interaction.
        qat : Tensor
            Atomic charges of all atoms.

        Returns
        -------
        Tensor
            Atom-wise electric field interaction energies.
        """
        return -cache.vat * qat

    @override
    def get_dipole_atom_energy(
        self,
        cache: ElectricFieldCache,
        qat: Tensor,
        qdp: Tensor | None = None,
        qqp: Tensor | None = None,
    ) -> Tensor:
        """
        Calculate the dipolar contribution of the electric field energy.

        Parameters
        ----------
        cache : ElectricFieldCache
            Restart data for the interaction.
        qat : Tensor
            Atom-resolved partial charges (shape: ``(..., nat)``). Unused.
        qdp : Tensor
            Atom-resolved dipole moments (shape: ``(..., nat, 3)``). Required.
        qqp : Tensor
            Atom-resolved quadrupole moments (shape: ``(..., nat, 6)``).
            Unused.

        Returns
        -------
        Tensor
            Atom-wise electric field interaction energies.

        Raises
        ------
        ValueError
            If `qdp` is not given.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if qdp is None:
            raise ValueError(
                "Dipolar electric field energy requires atomic dipole "
                "moments ('qdp')."
            )

        # equivalent: torch.sum(-cache.vdp * qdp, dim=-1)
        return einsum("...ix,...ix->...i", -cache.vdp, qdp)

    @override
    def get_monopole_atom_potential(
        self, cache: ElectricFieldCache, *_: Any, **__: Any
    ) -> Tensor:
        """
        Calculate the electric field potential.

        Parameters
        ----------
        cache : ElectricFieldCache
            Restart data for the interaction.

        Returns
        -------
        Tensor
            Atom-wise electric field potential.
        """
        return -cache.vat

    @override
    def get_dipole_atom_potential(
        self, cache: ElectricFieldCache, *_: Any, **__: Any
    ) -> Tensor:
        """
        Calculate the electric field dipole potential.

        Parameters
        ----------
        cache : ElectricFieldCache
            Restart data for the interaction.

        Returns
        -------
        Tensor
            Atom-wise electric field dipole potential.
        """
        return -cache.vdp

    def __str__(self) -> str:  # pragma: no cover
        return f"{self.__class__.__name__}(field={self.field})"

    def __repr__(self) -> str:  # pragma: no cover
        return str(self)
def new_efield(
    field: Tensor,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> ElectricField:
    """
    Create an instance of the electric field interaction.

    Parameters
    ----------
    field : Tensor
        Electric field vector consisting of the three cartesian components.
    device : torch.device | None, optional
        Device to store the tensor on. If ``None`` (default), the device is
        inferred from the `field` argument. If given, it must match the
        device of `field`.
    dtype : torch.dtype | None, optional
        Data type of the tensor. If ``None`` (default), the data type is
        inferred from the `field` argument. If given, it must match the
        dtype of `field`.

    Returns
    -------
    ElectricField
        Instance of the electric field interaction.

    Raises
    ------
    RuntimeError
        Shape of `field` is not a vector of length 3.
    DeviceError
        Passed `device` does not match the device of `field`.
    DtypeError
        Passed `dtype` does not match the dtype of `field`.
    """
    if field.shape != torch.Size([3]):
        raise RuntimeError("Electric field must be a vector of length 3.")

    if device is not None and device != field.device:
        raise DeviceError(
            f"Passed device ({device}) and device of electric field "
            f"({field.device}) do not match."
        )

    if dtype is not None and dtype != field.dtype:
        raise DtypeError(
            f"Passed dtype ({dtype}) and dtype of electric field "
            f"({field.dtype}) do not match."
        )

    # After the checks above, `field` determines device and dtype. Forward
    # them explicitly so nothing falls back to library defaults when the
    # caller passed ``None``. (The previous conditions forwarded ``None``.)
    return ElectricField(field, device=field.device, dtype=field.dtype)