# Source code for dxtb._src.components.classicals.halogen.hal

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Halogen Bond Correction: Class
==============================

This module implements the halogen bond correction class. The
:class:`dxtb.components.Halogen` class is constructed similarly to the
:class:`dxtb.components.Repulsion` class.
"""

from __future__ import annotations

import torch
from tad_mctc.batch import pack
from tad_mctc.convert import any_to_tensor
from tad_mctc.data.radii import ATOMIC_RADII
from tad_mctc.typing import Any, Tensor, override

from dxtb import IndexHelper
from dxtb._src.constants import xtb

from ..base import Classical, ClassicalCache, ComponentCache

# Names exported by this module; the cache class defined below is not
# part of this list.
__all__ = ["Halogen", "LABEL_HALOGEN"]


# String label identifying the halogen bond correction component.
LABEL_HALOGEN = "Halogen"
"""
Label for the :class:`.Halogen` component, coinciding with the class name.
"""


class HalogenCache(ClassicalCache):
    """
    Cache for the halogen bond parameters.

    Parameters
    ----------
    numbers : Tensor
        Atomic numbers the cached parameters belong to.
    xbond : Tensor
        Halogen bond strengths.
    device : torch.device | None, optional
        Device handed to the base class. If ``None``, the device of
        ``xbond`` is used.
    dtype : torch.dtype | None, optional
        Data type handed to the base class. If ``None``, the data type of
        ``xbond`` is used.

    Note
    ----
    The tensors are stored as given; they are not moved or cast to
    ``device``/``dtype`` here.
    """

    numbers: Tensor
    """Atomic numbers the cached parameters belong to."""

    xbond: Tensor
    """Halogen bond strengths."""

    __slots__ = ["numbers", "xbond"]

    def __init__(
        self,
        numbers: Tensor,
        xbond: Tensor,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        # Fall back to the properties of the stored tensor only when the
        # caller did not request anything explicitly. (The previous
        # condition was inverted: it forwarded ``None`` when nothing was
        # given and ignored an explicitly passed device/dtype.)
        super().__init__(
            device=xbond.device if device is None else device,
            dtype=xbond.dtype if dtype is None else dtype,
        )
        self.numbers = numbers
        self.xbond = xbond


class Halogen(Classical):
    """
    Representation of the halogen bond correction.

    The energy is accumulated over triples made of a halogen atom, a base
    atom within :attr:`cutoff` of that halogen, and the nearest neighbor of
    the halogen.
    """

    halogens: list[int]
    """Atomic numbers of halogen atoms considered in correction."""

    bases: list[int]
    """Atomic numbers of base atoms considered in correction."""

    damp: Tensor
    """Damping factor in Lennard-Jones like potential."""

    rscale: Tensor
    """Scaling factor for atomic radii."""

    bond_strength: Tensor
    """Halogen bond strengths for unique species."""

    cutoff: Tensor | float | int
    """
    Real space cutoff for halogen bonding interactions.

    :default: :data:`xtb.DEFAULT_XB_CUTOFF`
    """

    # NOTE(review): ``halogens`` and ``bases`` are assigned in ``__init__``
    # but are not listed here. The slots mirror the constructor arguments;
    # presumably this is intentional -- confirm before extending the list.
    __slots__ = ["damp", "rscale", "bond_strength", "cutoff"]

    def __init__(
        self,
        damp: Tensor,
        rscale: Tensor,
        bond_strength: Tensor,
        cutoff: Tensor | float | int = xtb.DEFAULT_XB_CUTOFF,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Set up the halogen bond correction.

        Parameters
        ----------
        damp : Tensor
            Damping factor in Lennard-Jones like potential.
        rscale : Tensor
            Scaling factor for atomic radii.
        bond_strength : Tensor
            Halogen bond strengths for unique species.
        cutoff : Tensor | float | int, optional
            Real space cutoff for halogen bonding interactions.
        device : torch.device | None, optional
            Device forwarded to the base class.
        dtype : torch.dtype | None, optional
            Data type forwarded to the base class.
        """
        super().__init__(device, dtype)

        # bring all parameters to the device/dtype of this component
        self.damp = damp.to(**self.dd)
        self.rscale = rscale.to(**self.dd)
        self.bond_strength = bond_strength.to(**self.dd)
        self.cutoff = any_to_tensor(cutoff, **self.dd)

        # element numbers of halogens and bases
        self.halogens = [17, 35, 53, 85]
        self.bases = [7, 8, 15, 16]

    @override
    def get_cache(
        self, numbers: Tensor, ihelp: IndexHelper | None = None, **kwargs: Any
    ) -> HalogenCache:
        """
        Store variables for energy calculation.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        ihelp : IndexHelper
            Helper class for indexing.

        Returns
        -------
        HalogenCache
            Cache for halogen bond correction.

        Raises
        ------
        ValueError
            If no ``ihelp`` is given.
        TypeError
            If an up-to-date cache exists but is not a ``HalogenCache``.

        Note
        ----
        The cache of a classical contribution does not require ``positions``
        as it only becomes useful if ``numbers`` remain unchanged and
        ``positions`` vary, i.e., during geometry optimization.
        """
        if ihelp is None:
            raise ValueError(
                "IndexHelper is required for halogen bond correction."
            )

        cachevars = (numbers.detach().clone(),)

        # reuse the stored cache if it was built for the same numbers
        if self.cache_is_latest(cachevars) is True:
            if not isinstance(self.cache, HalogenCache):
                raise TypeError(
                    f"Cache in {self.label} is not of type '{self.label}."
                    "Cache'. This can only happen if you manually manipulate "
                    "the cache."
                )
            return self.cache

        self._cachevars = cachevars

        # expand the per-species bond strengths to one value per atom
        xbond = ihelp.spread_uspecies_to_atom(self.bond_strength)
        self.cache = HalogenCache(numbers, xbond)

        return self.cache

    @override
    def get_energy(
        self, positions: Tensor, cache: ComponentCache, **_: Any
    ) -> Tensor:
        """
        Handle batchwise and single calculation of halogen bonding energy.

        Parameters
        ----------
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        cache : ComponentCache
            Cache for the halogen bond parameters.

        Returns
        -------
        Tensor
             Atomwise energy contributions from halogen bonds.

        Raises
        ------
        TypeError
            If ``cache`` is not a ``HalogenCache``.
        """
        if not isinstance(cache, HalogenCache):
            raise TypeError(
                f"Cache in {self.label} is not of type 'HalogenCache'."
            )

        # single system: evaluate directly
        if cache.numbers.ndim <= 1:
            return self._xbond_energy(cache.numbers, positions, cache.xbond)

        # batch: evaluate every system separately and pack the results
        return pack(
            [
                self._xbond_energy(
                    cache.numbers[_batch],
                    positions[_batch],
                    cache.xbond[_batch],
                )
                for _batch in range(cache.numbers.shape[0])
            ]
        )

    def _xbond_list(self, numbers: Tensor, positions: Tensor) -> Tensor | None:
        """
        Calculate triples for halogen bonding interactions.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers of a single system (shape: ``(nat,)``).
        positions : Tensor
            Cartesian coordinates of a single system (shape: ``(nat, 3)``).

        Returns
        -------
        Tensor | None
            Index triples ``(halogen, base, nearest neighbor of halogen)`` of
            shape ``(ntriples, 3)`` or ``None`` if no triples were found.

        Note
        ----
        We cannot use ``self.numbers`` here, because it is not batched.
        """
        species = numbers.tolist()
        halogen_idx = [i for i, z in enumerate(species) if z in self.halogens]
        base_idx = [j for j, z in enumerate(species) if z in self.bases]
        if not halogen_idx or not base_idx:
            return None

        # padding carries the atomic number zero and is never a neighbor
        is_atom = (numbers != 0).to(positions.device)

        triples: list[list[int]] = []
        for i in halogen_idx:
            # distances from the halogen to all atoms, evaluated only once
            # and reused for the cutoff test and the neighbor search
            dist = torch.norm(positions - positions[i, :], dim=-1)

            # bases within the cutoff
            partners = [
                j for j in base_idx if not bool(dist[j] > self.cutoff)
            ]
            if not partners:
                continue

            # Nearest neighbor of the halogen: exclude the halogen itself
            # (and anything else at zero distance) as well as padding. If
            # nothing qualifies, ``argmin`` yields index 0. For ties, the
            # lowest index is taken.
            candidates = is_atom & (dist > 0.0)
            masked = dist.masked_fill(
                ~candidates, torch.finfo(dist.dtype).max
            )
            k = int(torch.argmin(masked))

            triples.extend([i, j, k] for j in partners)

        if not triples:
            return None

        return torch.tensor(triples, dtype=torch.long, device=self.device)

    def _xbond_energy(
        self,
        numbers: Tensor,
        positions: Tensor,
        xbond: Tensor,
    ) -> Tensor:
        """
        Calculate atomwise energy contribution for each triple of halogen
        bonding interactions.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers of a single system (shape: ``(nat,)``).
        positions : Tensor
            Cartesian coordinates of a single system (shape: ``(nat, 3)``).
        xbond : Tensor
            Halogen bond strengths for every atom.

        Returns
        -------
        Tensor
            Atomwise energy contributions from halogen bonding interactions.
            Each contribution is assigned to the halogen atom of the triple.

        Note
        ----
        We cannot use ``self.numbers`` here, because it is not batched.
        """
        # triples for halogen bonding interactions (``None`` also covers
        # systems without any halogen or without any base)
        adj = self._xbond_list(numbers, positions)
        if adj is None:
            return torch.zeros(numbers.shape, **self.dd)

        # parameters
        rads = ATOMIC_RADII(**self.dd)[numbers] * self.rscale

        # init tensor for atomwise energies
        energies = positions.new_zeros(numbers.size(-1))

        # xat: halogen, jat: base, kat: nearest neighbor of the halogen
        for xat, jat, kat in adj.tolist():
            r0xj = rads[xat] + rads[jat]

            dxj = positions[jat, :] - positions[xat, :]
            dxk = positions[kat, :] - positions[xat, :]
            dkj = positions[kat, :] - positions[jat, :]

            d2xj = torch.sum(dxj * dxj)  # squared distance hal-acc
            d2xk = torch.sum(dxk * dxk)  # squared distance hal-neighbor
            d2kj = torch.sum(dkj * dkj)  # squared distance acc-neighbor

            rxj = torch.sqrt(d2xj)
            xy = torch.sqrt(d2xk * d2xj)

            # Lennard-Jones like potential
            lj6 = torch.pow(r0xj / rxj, 6.0)
            lj12 = torch.pow(lj6, 2.0)
            lj = (lj12 - self.damp * lj6) / (1.0 + lj12)

            # Angle (base-halogen-neighbor) via rule of cosines. The factor
            # 2 of the denominator is omitted, i.e., ``cosa`` holds twice
            # the cosine, which is accounted for by the prefactor 0.25 in
            # the damping function below.
            cosa = (d2xk + d2xj - d2kj) / xy

            # angle-dependent damping function: (0.5 - 0.5 * cos)^6
            fdamp = torch.pow(0.5 - 0.25 * cosa, 6.0)

            energies[xat] += lj * fdamp * xbond[xat]

        return energies