# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
External Fields: Field Gradient
===============================
Interaction of the charge density with external electric field gradient.
"""
from __future__ import annotations
import torch
from tad_mctc.exceptions import DeviceError, DtypeError
from tad_mctc.math import einsum
from dxtb import IndexHelper
from dxtb._src.typing import Tensor, TensorLike, override
from ..base import Interaction, InteractionCache
__all__ = ["ElectricFieldGrad", "LABEL_EFIELD_GRAD", "new_efield_grad"]
LABEL_EFIELD_GRAD = "ElectricFieldGrad"
"""Label for the 'ElectricField' interaction, coinciding with the class name."""
class ElectricFieldCache(InteractionCache, TensorLike):
"""
Restart data for the electric field interaction.
Note
----
This cache is not culled, and hence, does not contain a `Store`.
"""
efg: Tensor
"""Reshaped electric field gradient."""
__slots__ = ["efg"]
def __init__(
self,
efg: Tensor,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__(
device=device if device is None else efg.device,
dtype=dtype if dtype is None else efg.dtype,
)
self.efg = efg
[docs]
class ElectricFieldGrad(Interaction):
"""
Electric field gradient.
"""
field_grad: Tensor
"""Electric field gradient."""
__slots__ = ["field_grad"]
def __init__(
self,
field_grad: Tensor,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__(
device=device if device is None else field_grad.device,
dtype=dtype if dtype is None else field_grad.dtype,
)
self.field_grad = field_grad
# pylint: disable=unused-argument
[docs]
@override
def get_cache(
self,
*,
numbers: Tensor | None = None,
positions: Tensor | None = None,
ihelp: IndexHelper | None = None,
) -> ElectricFieldCache:
"""
Create restart data for individual interactions.
Returns
-------
ElectricFieldCache
Restart data for the interaction.
Note
----
Here, this is only a dummy.
"""
cachvars = (self.field_grad.detach().clone(),)
if self.cache_is_latest(cachvars) is True:
if not isinstance(self.cache, ElectricFieldCache):
raise TypeError(
f"Cache in {self.label} is not of type '{self.label}."
"Cache'. This can only happen if you manually manipulate "
"the cache."
)
return self.cache
self._cachevars = cachvars
efg = self.field_grad[torch.tril_indices(3, 3).unbind()]
self.cache = ElectricFieldCache(efg)
return self.cache
# TODO: This is probably not correct...
[docs]
@override
def get_quadrupole_atom_energy(
self,
cache: ElectricFieldCache,
qat: Tensor,
qdp: Tensor | None = None,
qqp: Tensor | None = None,
) -> Tensor:
"""
Calculate the quadrupolar contribution of the electric field energy.
Parameters
----------
cache : ElectricFieldCache
Restart data for the interaction.
qat : Tensor
Atom-resolved partial charges (shape: ``(..., nat)``).
qdp : Tensor
Atom-resolved shadow charges (shape: ``(..., nat, 3)``).
qqp : Tensor
Atom-resolved quadrupole moments (shape: ``(..., nat, 6)``).
Returns
-------
Tensor
Atom-wise electric field interaction energies.
"""
assert qqp is not None
# equivalent: torch.sum(-cache.vqp * charges, dim=-1)
return 0.5 * einsum("...x,...ix->...i", cache.efg, qqp)
def __str__(self) -> str: # pragma: no cover
return f"{self.__class__.__name__}(field_grad={self.field_grad})"
def __repr__(self) -> str: # pragma: no cover
return str(self)
[docs]
def new_efield_grad(
field_grad: Tensor,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> ElectricFieldGrad:
"""
Create an instance of the electric field gradient interaction.
Parameters
----------
field_grad : Tensor
Electric field gradient consisting of the 3x3 cartesian components.
device : torch.device | None, optional
Device to store the tensor on. If ``None`` (default), the device is
inferred from the `field` argument.
dtype : torch.dtype | None, optional
Data type of the tensor. If ``None`` (default), the data type is inferred
from the `field` argument.
Returns
-------
ElectricFieldGrad
Instance of the electric field gradient interaction.
Raises
------
RuntimeError
Shape of `field_grad` is not a 3x3 matrix.
"""
if field_grad.shape != torch.Size((3, 3)):
raise RuntimeError("Electric field gradient must be a 3 by 3 matrix.")
if device is not None:
if device != field_grad.device:
raise DeviceError(
f"Passed device ({device}) and device of electric field "
f"gradient ({field_grad.device}) do not match."
)
if dtype is not None:
if dtype != field_grad.dtype:
raise DtypeError(
f"Passed dtype ({dtype}) and dtype of electric field "
f"gradient ({field_grad.dtype}) do not match."
)
return ElectricFieldGrad(
field_grad,
device=device if device is None else field_grad.device,
dtype=dtype if dtype is None else field_grad.dtype,
)