# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Coulomb: Isotropic second-order electrostatic energy (ES2)
==========================================================
This module implements the second-order electrostatic energy for GFN1-xTB.
Example
-------

.. code-block:: python

    import torch

    from dxtb import GFN1_XTB, IndexHelper
    from dxtb._src.components.interactions.coulomb.secondorder import new_es2

    # Define atomic numbers, positions, and charges
    numbers = torch.tensor([14, 1, 1, 1, 1])
    positions = torch.tensor([
        [+0.00000000000000, -0.00000000000000, +0.00000000000000],
        [+1.61768389755830, +1.61768389755830, -1.61768389755830],
        [-1.61768389755830, -1.61768389755830, -1.61768389755830],
        [+1.61768389755830, -1.61768389755830, +1.61768389755830],
        [-1.61768389755830, +1.61768389755830, +1.61768389755830],
    ])
    q = torch.tensor([
        -8.41282505804719e-2,
        2.10320626451180e-2,
        2.10320626451178e-2,
        2.10320626451179e-2,
        2.10320626451179e-2,
    ])

    # Initialize the (atom-resolved) ES2 interaction with GFN1-xTB
    # parameters. Hubbard parameters are taken per unique species, and the
    # cache requires an IndexHelper to map them onto the atoms.
    ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)
    es = new_es2(torch.unique(numbers), GFN1_XTB, shell_resolved=False)

    # Calculate energy using the provided atomic charges and positions
    cache = es.get_cache(numbers=numbers, positions=positions, ihelp=ihelp)
    e = es.get_monopole_atom_energy(cache, q)

    torch.set_printoptions(precision=7)
    print(torch.sum(e, dim=-1))  # Output: tensor(0.0005078)
"""
from __future__ import annotations
import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.exceptions import DeviceError
from tad_mctc.math import einsum
from dxtb import IndexHelper
from dxtb._src.constants import xtb
from dxtb._src.param import Param, ParamModule
from dxtb._src.typing import (
DD,
Any,
Slicers,
Tensor,
TensorLike,
TensorOrTensors,
get_default_dtype,
override,
)
from ..base import Interaction, InteractionCache
from .average import AveragingFunction, averaging_function, harmonic_average
# Public names exported by this module.
__all__ = ["ES2", "LABEL_ES2", "new_es2"]

# Label of the ES2 interaction; equal to the class name ``ES2``.
LABEL_ES2 = "ES2"
"""Label for the 'ES2' interaction, coinciding with the class name."""
class ES2Cache(InteractionCache, TensorLike):
    """
    Cache for the Coulomb matrix of the second-order electrostatics (ES2).

    The cache can be culled (i.e., converged systems are removed from a
    batch) and restored to the state before the first culling.
    """

    __store: Store | None
    """Storage for cache (required for culling)."""

    mat: Tensor
    """Coulomb matrix."""

    shell_resolved: bool
    """Electrostatics is shell-resolved (default: ``True``)."""

    __slots__ = ["__store", "mat", "shell_resolved"]

    def __init__(
        self,
        mat: Tensor,
        shell_resolved: bool = True,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        mat : Tensor
            Atom- or shell-resolved Coulomb matrix.
        shell_resolved : bool, optional
            Whether ``mat`` is shell-resolved. Defaults to ``True``.
        device : torch.device | None, optional
            Device of the cache. Defaults to the device of ``mat``.
        dtype : torch.dtype | None, optional
            Floating point type of the cache. Defaults to the dtype of
            ``mat``.
        """
        # Fall back to the Coulomb matrix's device/dtype only when the caller
        # did not specify one; an explicit choice always takes precedence.
        super().__init__(
            device=mat.device if device is None else device,
            dtype=mat.dtype if dtype is None else dtype,
        )
        self.mat = mat
        self.shell_resolved = shell_resolved
        self.__store = None

    class Store:
        """
        Storage container for cache containing ``__slots__`` before culling.
        """

        mat: Tensor
        """Coulomb matrix"""

        def __init__(self, mat: Tensor) -> None:
            self.mat = mat

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        """
        Remove converged systems from the cached Coulomb matrix.

        The unculled matrix is stored on the first call only, so that
        :meth:`restore` always returns to the state before any culling.

        Parameters
        ----------
        conv : Tensor
            Boolean tensor flagging converged systems of the batch.
        slicers : Slicers
            Slicers for the ``"shell"`` and ``"atom"`` dimensions.
        """
        if self.__store is None:
            self.__store = self.Store(self.mat)

        # The matrix is square in either shells or atoms.
        dim_slicer = slicers["shell"] if self.shell_resolved else slicers["atom"]
        slicer = (~conv, *dim_slicer, *dim_slicer)
        self.mat = self.mat[slicer]

    def restore(self) -> None:
        """
        Restore the Coulomb matrix from before culling.

        Raises
        ------
        RuntimeError
            If nothing was stored, i.e., :meth:`cull` was never called.
        """
        if self.__store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")

        self.mat = self.__store.mat
class ES2(Interaction):
    """
    Isotropic second-order electrostatic energy (ES2).

    Depending on :attr:`shell_resolved`, the Coulomb matrix (and hence the
    potential, energy, and nuclear gradient) is atom- or shell-resolved.
    Shell resolution requires the ``lhubbard`` scaling factors; without
    them, the interaction is atom-resolved regardless of ``shell_resolved``.
    """

    hubbard: Tensor
    """Hubbard parameters of the unique species in the system."""

    lhubbard: Tensor | None
    """
    Shell-resolved scaling factors for Hubbard parameters.

    :default: ``None`` (i.e., no shell resolution).
    """

    average: AveragingFunction
    """
    Function to use for averaging the Hubbard parameters.

    :default: ``harmonic_average`` (from the ``average`` module of this
        package)
    """

    gexp: Tensor
    """
    Exponent of the second-order Coulomb interaction.

    :default: ``xtb.DEFAULT_ES2_GEXP`` (2.0)
    """

    shell_resolved: bool
    """
    Whether electrostatics is shell-resolved.

    :default: ``True``
    """

    __slots__ = [
        "hubbard",
        "lhubbard",
        "average",
        "gexp",
        "shell_resolved",
    ]

    def __init__(
        self,
        hubbard: Tensor,
        lhubbard: Tensor | None = None,
        average: AveragingFunction = harmonic_average,
        gexp: Tensor | None = None,
        shell_resolved: bool = True,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Parameters
        ----------
        hubbard : Tensor
            Hubbard parameters of the unique species.
        lhubbard : Tensor | None, optional
            Shell-resolved scaling factors for the Hubbard parameters.
            Defaults to ``None`` (atom-resolved electrostatics).
        average : AveragingFunction, optional
            Averaging function for pairs of Hubbard parameters.
        gexp : Tensor | None, optional
            Exponent of the interaction. Defaults to ``None``, which means
            ``xtb.DEFAULT_ES2_GEXP``.
        shell_resolved : bool, optional
            Request shell resolution (only effective if ``lhubbard`` is
            given). Defaults to ``True``.
        device : torch.device | None, optional
            Device of the interaction.
        dtype : torch.dtype | None, optional
            Floating point type of the interaction.
        """
        super().__init__(device, dtype)

        # Create the default per instance: a tensor in the signature would be
        # evaluated once and shared (and mutable) across all instances.
        if gexp is None:
            gexp = torch.tensor(xtb.DEFAULT_ES2_GEXP)

        self.hubbard = hubbard.to(**self.dd)
        self.lhubbard = lhubbard if lhubbard is None else lhubbard.to(**self.dd)
        self.gexp = gexp.to(**self.dd)
        self.average = average

        # Shell resolution is impossible without shell scaling factors.
        self.shell_resolved = shell_resolved and lhubbard is not None

    # pylint: disable=unused-argument
    @override
    def get_cache(
        self,
        *,
        numbers: Tensor | None = None,
        positions: Tensor | None = None,
        ihelp: IndexHelper | None = None,
    ) -> ES2Cache:
        """
        Obtain the cache object containing the Coulomb matrix.

        If ``cache_is_latest`` accepts the (detached) ``numbers`` and
        ``positions``, the existing cache is returned; otherwise, a new
        cache is built and stored.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        ihelp : IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        ES2Cache
            Cache object for second order electrostatics.

        Raises
        ------
        ValueError
            If ``numbers``, ``positions``, or ``ihelp`` is not given.
        TypeError
            If the stored cache is not an :class:`ES2Cache`, which can only
            happen after manual manipulation of the cache.

        Note
        ----
        The cache of an interaction requires ``positions`` as they do not change
        during the self-consistent charge iterations.
        """
        if numbers is None:
            raise ValueError("Atomic numbers are required for ES2 cache.")
        if positions is None:
            raise ValueError("Positions are required for ES2 cache.")
        if ihelp is None:
            raise ValueError("IndexHelper is required for ES2 cache creation.")

        cachvars = (numbers.detach().clone(), positions.detach().clone())

        if self.cache_is_latest(cachvars) is True:
            if not isinstance(self.cache, ES2Cache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}"
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        # if the cache is built, store the cachvar for validation
        self._cachevars = cachvars

        self.cache = ES2Cache(
            (
                self.get_shell_coulomb_matrix(numbers, positions, ihelp)
                if self.shell_resolved
                else self.get_atom_coulomb_matrix(numbers, positions, ihelp)
            ),
            shell_resolved=self.shell_resolved,
        )

        return self.cache

    def get_atom_coulomb_matrix(
        self, numbers: Tensor, positions: Tensor, ihelp: IndexHelper
    ) -> Tensor:
        """
        Calculate the atom-resolved Coulomb matrix.

        The matrix is computed via :class:`CoulombMatrixAG`, i.e., with an
        analytical nuclear derivative in the backward pass.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        ihelp : IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        Tensor
            Atom-resolved Coulomb matrix.
        """
        # only calculate mask once and save it for backward
        mask = real_pairs(numbers, mask_diagonal=True)

        # Always request the atom-resolved matrix (last argument ``False``),
        # independent of `self.shell_resolved`, so that this method returns
        # what its name promises.
        mat = CoulombMatrixAG.apply(
            mask,
            positions,
            ihelp,
            self.hubbard,
            self.lhubbard,
            self.gexp,
            self.average,
            False,
        )
        assert mat is not None
        return mat

    def get_shell_coulomb_matrix(
        self, numbers: Tensor, positions: Tensor, ihelp: IndexHelper
    ) -> Tensor:
        """
        Calculate the shell-resolved Coulomb matrix.

        Unlike :meth:`get_atom_coulomb_matrix`, all derivatives are obtained
        with PyTorch's autograd on :func:`coulomb_matrix_shell`.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        ihelp : IndexHelper
            Index mapping for the basis set.

        Returns
        -------
        Tensor
            Shell-resolved Coulomb matrix.

        Raises
        ------
        ValueError
            If no ``lhubbard`` scaling factors are set.
        """
        if self.lhubbard is None:
            raise ValueError("No 'lhubbard' parameters set.")

        # only calculate mask once and save it for backward
        mask = real_pairs(numbers, mask_diagonal=True)

        mat = coulomb_matrix_shell(
            mask,
            positions,
            ihelp,
            self.hubbard,
            self.lhubbard,
            self.gexp,
            self.average,
        )

        assert mat is not None
        return mat

    @override
    def get_monopole_atom_energy(
        self, cache: ES2Cache, qat: Tensor, **_: Any
    ) -> Tensor:
        """
        Atom-resolved energy ``0.5 * q * V``. Zero if this interaction is
        shell-resolved.

        Parameters
        ----------
        cache : ES2Cache
            Cache object for second order electrostatics.
        qat : Tensor
            Atom-resolved partial charges (shape: ``(..., nat)``).

        Returns
        -------
        Tensor
            Atom-resolved energy.
        """
        return (
            0.5 * qat * self.get_monopole_atom_potential(cache, qat)
            if not self.shell_resolved
            else torch.zeros_like(qat)
        )

    @override
    def get_monopole_shell_energy(
        self, cache: ES2Cache, qat: Tensor, **_: Any
    ) -> Tensor:
        """
        Shell-resolved energy ``0.5 * q * V``. Zero if this interaction is
        only atom-resolved.

        Parameters
        ----------
        cache : ES2Cache
            Cache object for second order electrostatics.
        qat : Tensor
            Shell-resolved partial charges (the parameter name is kept for
            backward compatibility).

        Returns
        -------
        Tensor
            Shell-resolved energy.
        """
        return (
            0.5 * qat * self.get_monopole_shell_potential(cache, qat)
            if self.shell_resolved
            else torch.zeros_like(qat)
        )

    @override
    def get_monopole_atom_potential(
        self,
        cache: ES2Cache,
        qat: Tensor,
        qdp: Tensor | None = None,
        qqp: Tensor | None = None,
    ) -> Tensor:
        """
        Calculate atom-resolved potential. Zero if this interaction is
        shell-resolved.

        Parameters
        ----------
        cache : ES2Cache
            Cache object for second order electrostatics.
        qat : Tensor
            Atom-resolved partial charges (shape: ``(..., nat)``).
        qdp : Tensor | None, optional
            Unused.
        qqp : Tensor | None, optional
            Unused.

        Returns
        -------
        Tensor
            Atom-resolved potential.
        """
        return (
            torch.zeros_like(qat)
            if self.shell_resolved
            else einsum("...ik,...k->...i", cache.mat, qat)
        )

    @override
    def get_monopole_shell_potential(
        self,
        cache: ES2Cache,
        qsh: Tensor,
        qdp: Tensor | None = None,
        qqp: Tensor | None = None,
    ) -> Tensor:
        """
        Calculate shell-resolved potential. Zero if this interaction is only
        atom-resolved.

        Parameters
        ----------
        cache : ES2Cache
            Cache object for second order electrostatics.
        qsh : Tensor
            Shell-resolved partial charges.
        qdp : Tensor | None, optional
            Unused.
        qqp : Tensor | None, optional
            Unused.

        Returns
        -------
        Tensor
            Shell-resolved potential.
        """
        return (
            einsum("...ik,...k->...i", cache.mat, qsh)
            if self.shell_resolved
            else torch.zeros_like(qsh)
        )

    @override
    def get_atom_gradient(
        self,
        charges: Tensor,
        positions: Tensor,
        cache: ES2Cache,
        grad_outputs: TensorOrTensors | None = None,
        retain_graph: bool | None = True,
        create_graph: bool | None = None,
    ) -> Tensor:
        """
        Calculates nuclear gradient of an second order electrostatic energy
        contribution via PyTorch's autograd engine. Zero if this interaction
        is shell-resolved.

        Parameters
        ----------
        charges : Tensor
            Atom-resolved partial charges (shape: ``(..., nat)``).
        positions : Tensor
            Nuclear positions. Needs ``requires_grad=True``.
        cache : ES2Cache
            Cache object for second order electrostatics.
        grad_outputs : TensorOrTensors | None, optional
            Gradient of previous computation, i.e., "vector" in VJP of this
            gradient computation. Defaults to ``None`` (a tensor of ones).
        retain_graph : bool | None, optional
            Whether to keep the graph after the gradient computation.
            Defaults to ``True``.
        create_graph : bool | None, optional
            Whether to record the graph of the gradient computation.
            Defaults to ``None``, i.e., ``torch.is_grad_enabled()``.

        Returns
        -------
        Tensor
            Nuclear gradient of energy.

        Raises
        ------
        RuntimeError
            ``positions`` tensor does not have ``requires_grad=True``.
        """
        if self.shell_resolved:
            return torch.zeros_like(positions)

        energy = self.get_monopole_atom_energy(cache, charges)
        return self._gradient(
            energy,
            positions,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
        )

    @override
    def get_shell_gradient(
        self,
        charges: Tensor,
        positions: Tensor,
        cache: ES2Cache,
        grad_outputs: TensorOrTensors | None = None,
        retain_graph: bool | None = True,
        create_graph: bool | None = None,
    ) -> Tensor:
        """
        Calculates nuclear gradient of an second order electrostatic energy
        contribution via PyTorch's autograd engine. Zero if this interaction
        is only atom-resolved.

        Parameters
        ----------
        charges : Tensor
            Shell-resolved partial charges.
        positions : Tensor
            Nuclear positions. Needs ``requires_grad=True``.
        cache : ES2Cache
            Cache object for second order electrostatics.
        grad_outputs : TensorOrTensors | None, optional
            Gradient of previous computation, i.e., "vector" in VJP of this
            gradient computation. Defaults to ``None`` (a tensor of ones).
        retain_graph : bool | None, optional
            Whether to keep the graph after the gradient computation.
            Defaults to ``True``.
        create_graph : bool | None, optional
            Whether to record the graph of the gradient computation.
            Defaults to ``None``, i.e., ``torch.is_grad_enabled()``.

        Returns
        -------
        Tensor
            Nuclear gradient of energy.

        Raises
        ------
        RuntimeError
            ``positions`` tensor does not have ``requires_grad=True``.
        """
        if not self.shell_resolved:
            return torch.zeros_like(positions)

        energy = self.get_monopole_shell_energy(cache, charges)
        return self._gradient(
            energy,
            positions,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
        )

    def _gradient(
        self,
        energy: Tensor,
        positions: Tensor,
        grad_outputs: TensorOrTensors | None = None,
        retain_graph: bool | None = True,
        create_graph: bool | None = None,
    ) -> Tensor:
        """
        Calculates nuclear gradient of an second order electrostatic energy
        contribution via PyTorch's autograd engine.

        Parameters
        ----------
        energy : Tensor
            Atom- or shell-resolved energy.
        positions : Tensor
            Nuclear positions. Needs ``requires_grad=True``.
        grad_outputs : TensorOrTensors | None, optional
            Gradient of previous computation, i.e., "vector" in VJP of this
            gradient computation. Defaults to ``None`` (a tensor of ones).
        retain_graph : bool | None, optional
            Whether to keep the graph after the gradient computation.
            Defaults to ``True``.
        create_graph : bool | None, optional
            Whether to record the graph of the gradient computation.
            Defaults to ``None``, i.e., ``torch.is_grad_enabled()``.

        Returns
        -------
        Tensor
            Nuclear gradient of energy.

        Raises
        ------
        RuntimeError
            ``positions`` tensor does not have ``requires_grad=True``.
        """
        if not positions.requires_grad:
            raise RuntimeError("Position tensor needs ``requires_grad=True``.")

        # avoid autograd call if energy is zero (autograd fails anyway)
        if torch.equal(energy, torch.zeros_like(energy)):
            return torch.zeros_like(positions)

        if create_graph is None:
            create_graph = torch.is_grad_enabled()

        if grad_outputs is None:
            grad_outputs = torch.ones_like(energy)

        (gradient,) = torch.autograd.grad(
            energy,
            positions,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
        )
        return gradient

    # DEPRECATED
    def _get_atom_gradient(
        self,
        numbers: Tensor,
        positions: Tensor,
        charges: Tensor,
        cache: ES2Cache,
    ) -> Tensor:
        if self.shell_resolved:
            return torch.zeros_like(positions)

        zero = torch.tensor(0.0, device=positions.device, dtype=positions.dtype)
        mask = real_pairs(numbers, mask_diagonal=True)

        distances = torch.where(
            mask,
            storch.cdist(positions, positions, p=2),
            zero,
        )

        # (n_batch, atoms_i, atoms_j, 3)
        rij = torch.where(
            mask.unsqueeze(-1),
            positions.unsqueeze(-2) - positions.unsqueeze(-3),
            zero,
        )

        # (n_batch, atoms_i) -> (n_batch, atoms_i, 1)
        charges = charges.unsqueeze(-1)

        # (n_batch, atoms_i, atoms_j) * (n_batch, atoms_i, 1)
        # every column is multiplied by the charge vector
        dmat = (
            -(distances ** (self.gexp - 2.0)) * cache.mat * cache.mat**self.gexp
        ) * charges

        # (n_batch, atoms_i, atoms_j) -> (n_batch, atoms_i, atoms_j, 3)
        dmat = dmat.unsqueeze(-1) * rij

        # (n_batch, atoms_i, atoms_j, 3) -> (n_batch, atoms_i, 3)
        return einsum("...ijx,...jx->...ix", dmat, charges)

    # DEPRECATED
    def _get_shell_gradient(
        self,
        numbers: Tensor,
        positions: Tensor,
        charges: Tensor,
        cache: ES2Cache,
        ihelp: IndexHelper,
    ) -> Tensor:
        if not self.shell_resolved:
            return torch.zeros_like(positions)

        dd: DD = {"device": positions.device, "dtype": positions.dtype}
        zero = torch.tensor(0.0, **dd)
        eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)

        mask = real_pairs(numbers, mask_diagonal=True)

        # all distances to the power of "gexp" (R^2_AB from Eq.26)
        distances = ihelp.spread_atom_to_shell(
            torch.where(
                mask,
                storch.cdist(positions, positions, p=2),
                eps,
            ),
            (-1, -2),
        )

        # (n_batch, shells_i, shells_j, 3)
        positions = ihelp.spread_atom_to_shell(positions, dim=-2, extra=True)
        mask = ihelp.spread_atom_to_shell(mask, (-2, -1))
        rij = torch.where(
            mask.unsqueeze(-1),
            positions.unsqueeze(-2) - positions.unsqueeze(-3),
            zero,
        )

        # (n_batch, shells_i) -> (n_batch, shells_i, 1)
        charges = charges.unsqueeze(-1)

        # (n_batch, shells_i, shells_j) * (n_batch, shells_i, 1)
        # every column is multiplied by the charge vector
        dmat = (
            -(distances ** (self.gexp - 2.0)) * cache.mat * cache.mat**self.gexp
        ) * charges

        # (n_batch, shells_i, shells_j) -> (n_batch, shells_i, shells_j, 3)
        dmat = dmat.unsqueeze(-1) * rij

        # (n_batch, shells_i, shells_j, 3) -> (n_batch, atoms, shells_j, 3)
        dmat = ihelp.reduce_shell_to_atom(dmat, dim=-3, extra=True)

        # (n_batch, atoms, shells_j, 3) -> (n_batch, atoms, 3)
        return einsum("...ijx,...jx->...ix", dmat, charges)
def coulomb_matrix_atom(
    mask: Tensor,
    positions: Tensor,
    ihelp: IndexHelper,
    hubbard: Tensor,
    gexp: Tensor,
    average: AveragingFunction,
) -> Tensor:
    """
    Build the atom-resolved Coulomb matrix (Eq. 26).

    Off-diagonal entries combine the interatomic distance and the averaged
    hardness; diagonal entries contain the on-site hardness only.

    Parameters
    ----------
    mask : Tensor
        Boolean atom-pair mask without the diagonal, e.g. from
        ``real_pairs(numbers, mask_diagonal=True)``
        (shape: ``(..., nat, nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    ihelp : IndexHelper
        Index mapping for the basis set.
    hubbard : Tensor
        Hubbard parameters of the unique species.
    gexp : Tensor
        Exponent of the second-order Coulomb interaction.
    average : AveragingFunction
        Function used to average pairs of Hubbard parameters (e.g.
        ``harmonic_average``).

    Returns
    -------
    Tensor
        Atom-resolved Coulomb matrix; entries outside the mask (with the
        diagonal re-included) are zero.
    """
    opts: DD = {"device": positions.device, "dtype": positions.dtype}
    zero = torch.tensor(0.0, **opts)
    eps = torch.tensor(torch.finfo(positions.dtype).eps, **opts)

    # Hubbard parameter of every atom, taken from its species.
    hub_at = ihelp.spread_uspecies_to_atom(hubbard)

    # R_AB**gexp for masked pairs (eps elsewhere). The eps shift of the base
    # keeps the double backward of pow free of NaNs.
    rab = storch.cdist(positions, positions, p=2)
    rab_pow = torch.where(mask, torch.pow(rab + eps, gexp), eps)

    # Re-include the diagonal, which carries the on-site hardness.
    onsite = torch.diag_embed(torch.ones_like(hub_at).type(torch.bool))
    pair = mask | onsite

    # Eq.30: averaged hardness; Eq.26: Coulomb matrix
    eta = torch.where(pair, average(hub_at + eps), eps)
    denom = rab_pow + torch.where(pair, torch.pow(eta, -gexp), eps)
    return torch.where(pair, 1.0 / torch.pow(denom, 1.0 / gexp), zero)
def coulomb_matrix_atom_gradient(
    mask: Tensor, positions: Tensor, mat: Tensor, gexp: Tensor
) -> Tensor:
    """
    Nuclear gradient of atom-resolved Coulomb matrix.

    Each element is ``-R_AB**(gexp - 2) * mat**(gexp + 1) * (R_A - R_B)``.
    This is the derivative of ``(R_AB**gexp + eta_AB**-gexp)**(-1/gexp)``
    with respect to the position of atom A.

    Parameters
    ----------
    mask : Tensor
        Boolean atom-pair mask without the diagonal, e.g. from
        ``real_pairs(numbers, mask_diagonal=True)``
        (shape: ``(..., nat, nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    mat : Tensor
        Atom-resolved Coulomb matrix.
    gexp: Tensor
        Exponent of the second-order Coulomb interaction.

    Returns
    -------
    Tensor
        Derivative of atom-resolved Coulomb matrix. The derivative has the
        following shape: ``(n_batch, atoms_i, atoms_j, 3)``.
    """
    dd: DD = {"device": positions.device, "dtype": positions.dtype}
    zero = torch.tensor(0.0, **dd)
    eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)

    # Masked entries (diagonal, padding) use `eps` instead of zero, as in
    # `coulomb_matrix_shell_gradient`. With zero, 0**(gexp - 2) is infinite
    # for gexp < 2 (inf * 0 = NaN after multiplying with the zero `rij`).
    # Its log(0) term also poisons derivatives w.r.t. `gexp`.
    distances = torch.where(
        mask,
        storch.cdist(positions, positions, p=2),
        eps,
    )

    # (n_batch, atoms_i, atoms_j, 3)
    rij = torch.where(
        mask.unsqueeze(-1),
        positions.unsqueeze(-2) - positions.unsqueeze(-3),
        zero,
    )

    # (n_batch, atoms_i, atoms_j)
    dmat = -(distances ** (gexp - 2.0)) * mat * mat**gexp

    # (n_batch, atoms_i, atoms_j) -> (n_batch, atoms_i, atoms_j, 3)
    return dmat.unsqueeze(-1) * rij
def coulomb_matrix_shell(
    mask: Tensor,
    positions: Tensor,
    ihelp: IndexHelper,
    hubbard: Tensor,
    lhubbard: Tensor,
    gexp: Tensor,
    average: AveragingFunction,
) -> Tensor:
    """
    Build the shell-resolved Coulomb matrix (Eq. 26).

    Shell hardnesses are the species' Hubbard parameters scaled by the
    shell-specific ``lhubbard`` factors. Distances between shells are the
    distances of the atoms the shells belong to.

    Parameters
    ----------
    mask : Tensor
        Boolean atom-pair mask without the diagonal, e.g. from
        ``real_pairs(numbers, mask_diagonal=True)``
        (shape: ``(..., nat, nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    ihelp : IndexHelper
        Index mapping for the basis set.
    hubbard : Tensor
        Hubbard parameters of the unique species.
    lhubbard : Tensor
        Shell-resolved scaling factors for the Hubbard parameters (required
        here, one per unique shell).
    gexp : Tensor
        Exponent of the second-order Coulomb interaction.
    average : AveragingFunction
        Function used to average pairs of Hubbard parameters (e.g.
        ``harmonic_average``).

    Returns
    -------
    Tensor
        Shell-resolved Coulomb matrix; entries outside the spread mask (with
        the atomic diagonal re-included) are zero.
    """
    opts: DD = {"device": positions.device, "dtype": positions.dtype}
    eps = torch.tensor(torch.finfo(positions.dtype).eps, **opts)
    zero = torch.tensor(0.0, **opts)

    # Shell hardness: shell scaling factor times the species' Hubbard value.
    scale_sh = ihelp.spread_ushell_to_shell(lhubbard)
    hub_sh = scale_sh * ihelp.spread_uspecies_to_shell(hubbard)

    # Atomic R_AB**gexp (eps for masked pairs), spread to shell pairs. The
    # eps shift of the base keeps the double backward of pow free of NaNs.
    rab = storch.cdist(positions, positions, p=2)
    rab_pow_at = torch.where(mask, torch.pow(rab + eps, gexp), eps)
    rab_pow_sh = ihelp.spread_atom_to_shell(rab_pow_at, (-1, -2))

    # Re-include the atomic diagonal (on-site terms), then spread to shells.
    onsite = torch.diag_embed(
        torch.ones_like(ihelp.atom_to_unique).type(torch.bool)
    )
    pair_sh = ihelp.spread_atom_to_shell(mask | onsite, (-2, -1))

    # Eq.30: averaged hardness; Eq.26: Coulomb matrix
    eta = torch.where(pair_sh, average(hub_sh + eps), eps)
    denom = rab_pow_sh + torch.where(pair_sh, torch.pow(eta, -gexp), eps)
    return torch.where(pair_sh, 1.0 / torch.pow(denom, 1.0 / gexp), zero)
def coulomb_matrix_shell_gradient(
    mask: Tensor,
    positions: Tensor,
    mat: Tensor,
    ihelp: IndexHelper,
    gexp: Tensor,
) -> Tensor:
    """
    Nuclear gradient of the shell-resolved Coulomb matrix.

    Every shell is located at its atom. The element for a shell pair is
    ``-R_AB**(gexp - 2) * mat**(gexp + 1) * (R_A - R_B)``.

    Parameters
    ----------
    mask : Tensor
        Boolean atom-pair mask without the diagonal, e.g. from
        ``real_pairs(numbers, mask_diagonal=True)``
        (shape: ``(..., nat, nat)``).
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    mat : Tensor
        Shell-resolved Coulomb matrix.
    ihelp : IndexHelper
        Index mapping for the basis set.
    gexp : Tensor
        Exponent of the second-order Coulomb interaction.

    Returns
    -------
    Tensor
        Derivative of shell-resolved Coulomb matrix. The derivative has the
        following shape: ``(n_batch, shell_i, shell_j, 3)``.
    """
    opts: DD = {"device": positions.device, "dtype": positions.dtype}
    eps = torch.tensor(torch.finfo(positions.dtype).eps, **opts)
    zero = torch.tensor(0.0, **opts)

    # Atomic distances (eps for masked pairs), spread to shell pairs.
    rab_at = torch.where(mask, storch.cdist(positions, positions, p=2), eps)
    rab_sh = ihelp.spread_atom_to_shell(rab_at, (-1, -2))

    # Shell centers coincide with their atoms: (n_batch, shells, 3)
    pos_sh = ihelp.spread_atom_to_shell(positions, dim=-2, extra=True)
    pair_sh = ihelp.spread_atom_to_shell(mask, (-2, -1))

    # Difference vectors of shell centers, zero for masked pairs:
    # (n_batch, shells_i, shells_j, 3)
    diff = pos_sh.unsqueeze(-2) - pos_sh.unsqueeze(-3)
    rvec = torch.where(pair_sh.unsqueeze(-1), diff, zero)

    # Scalar prefactor: (n_batch, shells_i, shells_j)
    prefactor = -(rab_sh ** (gexp - 2.0)) * mat * mat**gexp

    # (n_batch, shells_i, shells_j) -> (n_batch, shells_i, shells_j, 3)
    return prefactor[..., None] * rvec
# pylint: disable=abstract-method,arguments-differ
class CoulombMatrixAG(torch.autograd.Function):
    """
    Autograd function for Coulomb matrix.

    The forward pass builds the atom- or shell-resolved Coulomb matrix. In
    the backward pass, the derivative with respect to ``positions`` uses the
    analytical expressions (:func:`coulomb_matrix_atom_gradient` /
    :func:`coulomb_matrix_shell_gradient`). Derivatives with respect to
    ``hubbard``, ``lhubbard``, and ``gexp`` come from ``torch.autograd.grad``
    on the graph recorded during the forward pass.
    """

    @staticmethod
    def forward(
        ctx,
        mask: Tensor,
        positions: Tensor,
        ihelp: IndexHelper,
        hubbard: Tensor,
        lhubbard: Tensor,
        gexp: Tensor,
        average: AveragingFunction,
        shell_resolved: bool,
    ) -> Tensor:
        """
        Compute the Coulomb matrix.

        Parameters
        ----------
        ctx
            Autograd context used to stash data for :meth:`backward`.
        mask : Tensor
            Boolean atom-pair mask without the diagonal.
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        ihelp : IndexHelper
            Index mapping for the basis set.
        hubbard : Tensor
            Hubbard parameters of the unique species.
        lhubbard : Tensor
            Shell scaling factors; only used if ``shell_resolved`` is true.
            NOTE(review): callers may pass ``None`` here when the
            atom-resolved matrix is requested.
        gexp : Tensor
            Exponent of the second-order Coulomb interaction.
        average : AveragingFunction
            Function used to average pairs of Hubbard parameters.
        shell_resolved : bool
            Select the shell- (``True``) or atom-resolved (``False``) matrix.

        Returns
        -------
        Tensor
            Coulomb matrix.
        """
        # forward() of an autograd.Function runs without gradient tracking;
        # re-enable it so that `mat` carries a graph, which backward() uses
        # to differentiate w.r.t. the parameters.
        with torch.enable_grad():
            if shell_resolved:
                mat = coulomb_matrix_shell(
                    mask, positions, ihelp, hubbard, lhubbard, gexp, average
                )
            else:
                mat = coulomb_matrix_atom(
                    mask, positions, ihelp, hubbard, gexp, average
                )

        # save tensor variables the intended way
        ctx.save_for_backward(mat, mask, positions, gexp, hubbard, lhubbard)

        # save non-tensor variables (required in backward) directly
        ctx.shell_resolved = shell_resolved
        ctx.ihelp = ihelp

        # NOTE(review): presumably a clone so that the returned output is a
        # tensor distinct from the saved, graph-carrying `mat` — confirm.
        return mat.clone()

    @staticmethod
    def backward(ctx, grad_out: Tensor) -> tuple[
        None,  # mask
        None | Tensor,  # positions
        None,  # ihelp
        None | Tensor,  # hubbard
        None | Tensor,  # lhubbard
        None | Tensor,  # gexp
        None,  # average
        None,  # shell_resolved
    ]:
        """
        Vector-Jacobian product of the Coulomb matrix.

        Parameters
        ----------
        ctx
            Autograd context filled in :meth:`forward`.
        grad_out : Tensor
            Gradient with respect to the output Coulomb matrix.

        Returns
        -------
        tuple
            One entry per input of :meth:`forward`; ``None`` for inputs that
            are not differentiable or do not require gradients.
        """
        # initialize gradients with ``None``
        positions_bar = hubbard_bar = lhubbard_bar = gexp_bar = None

        # check which of the input variables of `forward()` requires gradients
        (
            _,
            grad_positions,
            _,
            grad_hubbard,
            grad_lhubbard,
            grad_gexp,
            _,
            _,
        ) = ctx.needs_input_grad

        mat, mask, positions, gexp, hubbard, lhubbard = ctx.saved_tensors
        shell_resolved: bool = ctx.shell_resolved
        ihelp: IndexHelper = ctx.ihelp

        # analytical gradient for positions
        if grad_positions:
            # (n_batch, n, n, 3)
            if shell_resolved:
                g = coulomb_matrix_shell_gradient(
                    mask, positions, mat, ihelp, gexp
                )
            else:
                g = coulomb_matrix_atom_gradient(mask, positions, mat, gexp)

            # vjp: (nb, n, n) * (nb, n, n, 3) -> (nb, n, 3)
            # Each pair term depends on R_i - R_j, so it enters row i with a
            # plus sign and column j with a minus sign.
            _gi = einsum("...ij,...ijd->...id", grad_out, g)
            _gj = einsum("...ij,...ijd->...jd", grad_out, g)

            if shell_resolved:
                # Sum the shell contributions onto their atoms.
                positions_bar = ihelp.reduce_shell_to_atom(
                    _gi - _gj, dim=-2, extra=True
                )
            else:
                positions_bar = _gi - _gj

        # automatic gradient for parameters
        # (create_graph=True also retains the graph for the following calls)
        if grad_hubbard:
            (hubbard_bar,) = torch.autograd.grad(
                mat,
                hubbard,
                grad_outputs=grad_out,
                create_graph=True,
            )

        if grad_lhubbard:
            (lhubbard_bar,) = torch.autograd.grad(
                mat,
                lhubbard,
                grad_outputs=grad_out,
                create_graph=True,
            )

        if grad_gexp:
            (gexp_bar,) = torch.autograd.grad(
                mat,
                gexp,
                grad_outputs=grad_out,
                create_graph=True,
            )

        return (
            None,
            positions_bar,
            None,
            hubbard_bar,
            lhubbard_bar,
            gexp_bar,
            None,
            None,
        )
def new_es2(
    unique: Tensor,
    par: Param | ParamModule,
    shell_resolved: bool = True,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> ES2 | None:
    """
    Create new instance of :class:`.ES2`.

    Parameters
    ----------
    unique : Tensor
        Unique elements in the system (shape: ``(nunique,)``).
    par : Param | ParamModule
        Representation of an extended tight-binding model. A :class:`Param`
        is converted to a :class:`ParamModule` first.
    shell_resolved: bool
        Electrostatics is shell-resolved. Only if this is ``True`` are the
        shell scaling factors (``lgam``) loaded.
    device : torch.device | None, optional
        Device of the interaction. Must match the device of ``unique`` if
        given.
    dtype : torch.dtype | None, optional
        Floating point type of the interaction. Defaults to the default
        dtype.

    Returns
    -------
    ES2 | None
        Instance of the ES2 class or ``None`` if no ES2 is used.

    Raises
    ------
    DeviceError
        If ``device`` is given and differs from the device of ``unique``.
    """
    dd: DD = {
        "device": device,
        "dtype": dtype if dtype is not None else get_default_dtype(),
    }

    # compatibility with previous version based on `Param`
    if not isinstance(par, ParamModule):
        par = ParamModule(par, **dd)

    if "charge" not in par or par.is_none("charge"):
        return None

    if device is not None and device != unique.device:
        raise DeviceError(
            f"Passed device ({device}) and device of `unique` tensor "
            f"({unique.device}) do not match."
        )

    hubbard = par.get_elem_param(unique, "gam")
    lhubbard = (
        par.get_elem_param(unique, "lgam") if shell_resolved is True else None
    )

    # Forward `shell_resolved` explicitly instead of relying on ES2's default
    # and its fallback to atom resolution when `lhubbard` is None.
    return ES2(
        hubbard,
        lhubbard,
        average=averaging_function[par.get("charge.effective.average")],
        gexp=par.get("charge.effective.gexp"),
        shell_resolved=shell_resolved,
        **dd,
    )