# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Analytical linearized Poisson-Boltzmann model
=============================================
This module implements implicit solvation models of the generalized Born type.
Example
-------
.. code-block:: python
import torch
from dxtb.solvation.alpb import GeneralizedBorn
numbers = torch.tensor([14, 1, 1, 1, 1])
positions = torch.tensor([
[+0.00000000000000, -0.00000000000000, +0.00000000000000],
[+1.61768389755830, +1.61768389755830, -1.61768389755830],
[-1.61768389755830, -1.61768389755830, -1.61768389755830],
[+1.61768389755830, -1.61768389755830, +1.61768389755830],
[-1.61768389755830, +1.61768389755830, +1.61768389755830],
])
charges = torch.tensor([
-8.41282505804719e-2,
2.10320626451180e-2,
2.10320626451178e-2,
2.10320626451179e-2,
2.10320626451179e-2,
])
# Initialize the GeneralizedBorn model with a solvent dielectric constant
gb = GeneralizedBorn(numbers, torch.tensor(78.9), kernel="still")
# Build cache and use it for energy calculation
cache = gb.get_cache(numbers=numbers, positions=positions)
energy = gb.get_monopole_atom_energy(cache, charges)
total_energy = energy.sum(-1)
print(total_energy) # Output: tensor(-5.0762e-05)
"""
from __future__ import annotations
import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.convert import any_to_tensor
from tad_mctc.data import VDW_D3
from tad_mctc.math import einsum
from dxtb import IndexHelper
from dxtb._src.param import Param, ParamModule
from dxtb._src.typing import (
DD,
Any,
Tensor,
TensorLike,
TensorOrTensors,
get_default_dtype,
override,
)
from dxtb._src.typing.exceptions import DeviceError
from ..base import Interaction, InteractionCache
from .born import get_born_radii
# Numerator of the finite-dielectric correction: ``GeneralizedBorn`` sets
# ``alpbet = alpha / dielectric_constant`` when its ``alpb`` flag is true.
alpha = 0.571412
# Default value of the ``kernel`` argument of ``GeneralizedBorn``
# (a key of the ``born_kernel`` table).
DEFAULT_KERNEL = "p16"
# Default value of the ``alpb`` argument of ``GeneralizedBorn``.
DEFAULT_ALPB = True
# Default value of the ``born_scale`` argument of ``GeneralizedBorn``.
DEFAULT_BORN_SCALE = 1.0
# Default value of the ``born_offset`` argument of ``GeneralizedBorn``.
DEFAULT_BORN_OFFSET = 0.0
# Names exported by ``from <module> import *``.
__all__ = ["GeneralizedBorn", "new_solvation"]
# @torch.jit.script
def p16_kernel(r1: Tensor, ab: Tensor) -> Tensor:
    """
    Evaluate the P16 interaction kernel.

    With ``s = √ab`` and ``ζ = 1.028`` the kernel reads
    ``1 / (R + s · [s / (s + ζ·R/16)]¹⁶)``, which is the same as
    ``1 / (R + √ab / (1 + ζR/(16·√ab))¹⁶)``.

    Parameters
    ----------
    r1 : Tensor
        Distance between all atom pairs.
    ab : Tensor
        Product of Born radii for all atom pairs.

    Returns
    -------
    Tensor
        Interaction kernel between all atom pairs.
    """
    # Geometric mean of the two Born radii of every pair.
    root = ab.sqrt()

    # Damping factor that switches off the Born-radius term at long range.
    damping = (root / (root + r1 * 1.028 / 16)) ** 16

    return 1.0 / (r1 + root * damping)
# @torch.jit.script
def still_kernel(r1: Tensor, ab: Tensor) -> Tensor:
    """
    Evaluate the Still interaction kernel:
    1 / √(R² + ab · exp[-R²/(4·ab)])

    Note the negative sign in the exponent: the Born-radius term is damped
    with growing distance, so the kernel approaches 1/R for large R and
    1/√ab for R → 0.

    Parameters
    ----------
    r1 : Tensor
        Distance between all atom pairs.
    ab : Tensor
        Product of Born radii for all atom pairs.

    Returns
    -------
    Tensor
        Interaction kernel between all atom pairs.
    """
    r2 = torch.pow(r1, 2)

    # Gaussian-type damping of the Born-radius contribution.
    arg = torch.exp(-0.25 * r2 / ab)

    return 1.0 / torch.sqrt(r2 + ab * arg)
# Dispatch table from kernel name to its implementation; it is indexed with
# ``GeneralizedBorn.kernel`` when the interaction matrix is built.
born_kernel = {"p16": p16_kernel, "still": still_kernel}
def get_adet(positions: Tensor, rad: Tensor) -> Tensor:
    """
    Calculate electrostatic shape function based on the moments of inertia
    of solid spheres.

    Every atom is treated as a solid sphere of radius ``rad`` whose weight is
    its cubed radius. The inertia tensor of this set of spheres is built
    around the weighted center and its determinant is converted into an
    effective radius. For a single sphere the result is exactly its radius.

    Parameters
    ----------
    positions : Tensor
        Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
    rad : Tensor
        Radii of all atoms (shape: ``(..., nat)``).

    Returns
    -------
    Tensor
        Electrostatic shape function (shape: ``(...)``).
    """
    vol = torch.pow(rad, 3)

    # Volume-weighted center. ``keepdim=True`` keeps the denominator at shape
    # ``(..., 1)`` so that it broadcasts against ``(..., 3)`` for batched
    # input as well; without it a leading batch dimension is matched against
    # the Cartesian axis.
    center = (positions * vol.unsqueeze(-1)).sum(-2) / vol.sum(
        -1, keepdim=True
    )
    displ = positions - center.unsqueeze(-2)

    # |d|² plus the 2/5·r² self-term of a solid sphere (isotropic part).
    diag = torch.pow(displ, 2).sum(-1) + 2 * torch.pow(rad, 2) / 5

    # Per-atom tensor ``diag·1 - d⊗d``, weighted and summed over atoms.
    inertia = (
        vol.unsqueeze(-1).unsqueeze(-2)
        * (
            -displ.unsqueeze(-1) * displ.unsqueeze(-2)
            + torch.diag_embed(diag.unsqueeze(-1).expand(*positions.shape))
        )
    ).sum(-3)

    # Explicit 3x3 determinant (cofactor expansion along the first row).
    adet = (
        +inertia[..., 0, 0] * inertia[..., 1, 1] * inertia[..., 2, 2]
        - inertia[..., 0, 0] * inertia[..., 1, 2] * inertia[..., 2, 1]
        - inertia[..., 0, 1] * inertia[..., 1, 0] * inertia[..., 2, 2]
        + inertia[..., 0, 1] * inertia[..., 1, 2] * inertia[..., 2, 0]
        + inertia[..., 0, 2] * inertia[..., 1, 0] * inertia[..., 2, 1]
        - inertia[..., 0, 2] * inertia[..., 1, 1] * inertia[..., 2, 0]
    )

    return torch.sqrt(5 * torch.pow(adet, 1 / 3) / (2 * vol.sum(-1)))
class GeneralizedBornCache(InteractionCache, TensorLike):
    """
    Restart data for the generalized Born solvation model.

    The cache only stores the interaction matrix ``mat``. Device and dtype
    handed to the base classes default to those of ``mat``.
    """

    __slots__ = ["mat"]

    mat: Tensor
    """Coulomb matrix."""

    def __init__(
        self,
        mat: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """
        Store the interaction matrix.

        Parameters
        ----------
        mat : Tensor
            Interaction matrix to cache.
        device : torch.device | None, optional
            Device reported by the cache. Defaults to ``mat.device``.
        dtype : torch.dtype | None, optional
            Data type reported by the cache. Defaults to ``mat.dtype``.
        """
        # Fall back to the properties of the stored tensor when nothing is
        # given explicitly. The previous condition was inverted: it forwarded
        # ``None`` for missing arguments and ignored explicit ones.
        super().__init__(
            device=mat.device if device is None else device,
            dtype=mat.dtype if dtype is None else dtype,
        )
        self.mat = mat
class GeneralizedBorn(Interaction):
    """
    Implicit solvation model for describing the interaction with a dielectric
    continuum.

    The model builds a pairwise interaction matrix from the Born radii of the
    atoms and one of the kernels registered in ``born_kernel``. The matrix is
    contracted with the atomic partial charges to obtain the potential and
    the energy.
    """

    kernel: str
    """Interaction kernel."""

    alpbet: Tensor
    """Finite dielectric constant correction."""

    keps: Tensor
    """Dielectric function."""

    born_kwargs: dict[str, Any]
    """Parameters for Born radii integration."""

    def __init__(
        self,
        numbers: Tensor,
        dielectric_constant: Tensor,
        alpb: bool = DEFAULT_ALPB,
        kernel: str = DEFAULT_KERNEL,
        born_scale: float = DEFAULT_BORN_SCALE,
        born_offset: float = DEFAULT_BORN_OFFSET,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Set up the solvation model.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
            Used to select the default van-der-Waals radii.
        dielectric_constant : Tensor
            Dielectric constant of the solvent.
        alpb : bool, optional
            If ``True``, use the finite dielectric constant correction
            ``alpha / dielectric_constant``; otherwise the correction is zero.
        kernel : str, optional
            Name of the interaction kernel (a key of ``born_kernel``).
        born_scale : float, optional
            Forwarded to ``get_born_radii`` as ``born_scale``.
        born_offset : float, optional
            Forwarded to ``get_born_radii`` as ``born_offset``.
        device : torch.device | None, optional
            Device passed on to the ``Interaction`` base class.
        dtype : torch.dtype | None, optional
            Data type passed on to the ``Interaction`` base class.
        **kwargs : Any
            Only ``rvdw`` is read: van-der-Waals radii overriding the default
            ``VDW_D3(**self.dd)[numbers]``.
        """
        super().__init__(device, dtype)

        self.alpbet = (
            alpha / dielectric_constant
            if alpb
            else torch.tensor(0.0, **self.dd)
        )
        self.keps = (1 / dielectric_constant - 1) / (1 + self.alpbet)
        self.kernel = kernel

        # Keyword arguments for every call of ``get_born_radii``.
        self.born_kwargs = {
            "rvdw": kwargs.get("rvdw", VDW_D3(**self.dd)[numbers]),
            "born_scale": born_scale,
            "born_offset": born_offset,
        }

    # pylint: disable=unused-argument
    @override
    def get_cache(
        self,
        *,
        numbers: Tensor | None = None,
        positions: Tensor | None = None,
        ihelp: IndexHelper | None = None,
    ) -> GeneralizedBornCache:
        """
        Create restart data for individual interactions.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        ihelp : IndexHelper | None, optional
            Not used by this interaction (see note).

        Returns
        -------
        GeneralizedBornCache
            Cache object holding the Born interaction matrix.

        Raises
        ------
        ValueError
            If ``numbers`` or ``positions`` is missing.
        TypeError
            If an up-to-date cache exists but is not a
            :class:`.GeneralizedBornCache`.

        Note
        ----
        If the :class:`.GeneralizedBorn` interaction is evaluated
        within the :class:`dxtb.components.InteractionList`, the
        :class:`dxtb.IndexHelper` will be passed as an argument, too. Hence,
        it is necessary to absorb the ``ihelp`` argument in the signature of
        the function even though it is not needed here.
        """
        if numbers is None:
            raise ValueError("Atomic numbers are required for cache.")
        if positions is None:
            raise ValueError("Atomic positions are required for cache.")

        # Detached copies serve as the key for validating an existing cache.
        cachvars = (numbers.detach().clone(), positions.detach().clone())

        if self.cache_is_latest(cachvars) is True:
            if not isinstance(self.cache, GeneralizedBornCache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}."
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        # if the cache is built, store the positions for validation
        self._cachevars = cachvars

        born = get_born_radii(numbers, positions, **self.born_kwargs)

        # Entries outside the pair mask are replaced by machine epsilon
        # (presumably to keep the kernel finite for padded pairs).
        eps = torch.tensor(torch.finfo(positions.dtype).eps, **self.dd)
        mask = real_pairs(numbers, mask_diagonal=False)

        dist = torch.where(mask, storch.cdist(positions, p=2), eps)
        ab = torch.where(mask, born.unsqueeze(-1) * born.unsqueeze(-2), eps)

        mat = self.keps * born_kernel[self.kernel](dist, ab)

        if self.alpbet > 0:
            # Constant shift of the whole matrix by the shape-dependent
            # finite-dielectric term; ``adet`` is one value per system.
            adet = get_adet(positions, self.born_kwargs["rvdw"])
            mat = mat + self.keps * self.alpbet * adet.unsqueeze(
                -1
            ).unsqueeze(-2)

        self.cache = GeneralizedBornCache(mat)
        return self.cache

    @override
    def get_monopole_atom_energy(
        self, cache: GeneralizedBornCache, qat: Tensor, **_: Any
    ) -> Tensor:
        """
        Calculate the atom-resolved solvation energy.

        Parameters
        ----------
        cache : GeneralizedBornCache
            Restart data holding the interaction matrix.
        qat : Tensor
            Atomic partial charges.

        Returns
        -------
        Tensor
            Atom-wise energy, i.e., half the charge times the potential.
        """
        return 0.5 * qat * self.get_monopole_atom_potential(cache, qat)

    @override
    def get_monopole_atom_potential(
        self,
        cache: GeneralizedBornCache,
        qat: Tensor,
        qdp: Tensor | None = None,
        qqp: Tensor | None = None,
    ) -> Tensor:
        """
        Calculate the atom-resolved solvation potential.

        Parameters
        ----------
        cache : GeneralizedBornCache
            Restart data holding the interaction matrix.
        qat : Tensor
            Atomic partial charges.
        qdp : Tensor | None, optional
            Not used by this interaction.
        qqp : Tensor | None, optional
            Not used by this interaction.

        Returns
        -------
        Tensor
            Product of the interaction matrix with the atomic charges.
        """
        return einsum("...ik,...k->...i", cache.mat, qat)

    # TODO: Implement gradient before using solvation in SCF
    def get_atom_gradient(
        self,
        charges: Tensor,
        positions: Tensor,
        cache: GeneralizedBornCache,
        grad_outputs: TensorOrTensors | None = None,
        retain_graph: bool | None = True,
        create_graph: bool | None = None,
    ) -> Tensor:
        """
        Nuclear gradient of the solvation energy (not available yet).

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Solvation gradient not implemented")
def new_solvation(
    numbers: Tensor,
    par: Param | ParamModule,
    dielectric_constant: Tensor | float | int = 80.3,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> GeneralizedBorn | None:
    """
    Create new instance of the generalized Born solvation model.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
    par : Param | ParamModule
        Representation of an extended tight-binding model.
    dielectric_constant : Tensor | float | int, optional
        Dielectric constant of the solvent. Defaults to ``80.3``.
    device : torch.device | None, optional
        Device for the model. If given, it must equal ``numbers.device``.
    dtype : torch.dtype | None, optional
        Data type for the model. Defaults to the default dtype.

    Returns
    -------
    GeneralizedBorn | None
        Instance of the `GeneralizedBorn` class or ``None`` if no solvation model
        is used.

    Raises
    ------
    DeviceError
        If ``device`` is given and differs from the device of ``numbers``.
    """
    dd: DD = {
        "device": device,
        "dtype": dtype if dtype is not None else get_default_dtype(),
    }

    # compatibility with previous version based on `Param`
    if not isinstance(par, ParamModule):
        par = ParamModule(par, **dd)

    # No solvation section, or ALPB switched off: nothing to construct.
    if "solvation" not in par or par.is_none("solvation"):
        return None
    if par.is_none("solvation.alpb") or par.is_false("solvation.alpb"):
        return None

    if device is not None and device != numbers.device:
        raise DeviceError(
            f"Passed device ({device}) and device of `numbers` "
            f"({numbers.device}) do not match."
        )

    # NOTE(review): the dielectric constant is converted without ``dd``, so it
    # may not share the requested dtype/device - confirm this is intended.
    return GeneralizedBorn(
        numbers,
        dielectric_constant=any_to_tensor(dielectric_constant),
        alpb=par.get("solvation.alpb.alpb"),
        kernel=par.get("solvation.alpb.kernel"),
        born_scale=par.get("solvation.alpb.born_scale"),
        born_offset=par.get("solvation.alpb.born_offset"),
        **dd,
    )