# Source code for dxtb._src.components.classicals.repulsion.rep

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Repulsion: Classes
==================

This module implements the classical repulsion energy term in two flavors. The
first class, :class:`.Repulsion`, provides the gradient using PyTorch's
autograd. The second class, :class:`.RepulsionAnalytical`, provides a custom
backward with an analytical derivative with respect to the atomic positions
(derivatives with respect to the parameters still use autograd).

Note
----
The Repulsion class has a cache scope that goes beyond single-point
calculations (geometry optimization, numerical gradients). The atomic numbers
are set upon instantiation (``numbers`` is a property), and the parameters in
the cache are created for only those atomic numbers. The positions, however,
must be supplied to the ``get_energy`` method. Hence, the cache does not become
invalid for different geometries, but only for different atomic numbers.
"""

from __future__ import annotations

import torch
from tad_mctc._version import __tversion__
from tad_mctc.math import einsum

from dxtb._src.typing import Any, Tensor, override

from .base import (
    BaseRepulsion,
    BaseRepulsionCache,
    repulsion_energy,
    repulsion_gradient,
)

__all__ = [
    "LABEL_REPULSION",
    "Repulsion",
    "RepulsionAnalytical",
]


LABEL_REPULSION = "Repulsion"
"""
Label of the :class:`.Repulsion` component; identical to the class name.
"""


class Repulsion(BaseRepulsion):
    """
    Representation of the classical repulsion. Derivatives of the energy are
    obtained with PyTorch's autograd.
    """

    @override
    def get_energy(
        self, positions: Tensor, cache: BaseRepulsionCache, **kwargs: Any
    ) -> Tensor:
        """
        Compute the repulsion energy from the cached parameters.

        Parameters
        ----------
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        cache : Repulsion.Cache
            Cache for repulsion; its ``mask``, ``arep``, ``kexp`` and ``zeff``
            attributes are read.
        **kwargs : Any
            Only ``atom_resolved`` is read. If it is exactly ``True`` (the
            default), the atom-resolved energy is returned; any other value
            yields the full matrix.

        Returns
        -------
        Tensor
            (Atom-resolved) repulsion energy.
        """
        pair_energy = repulsion_energy(
            positions,
            cache.mask,
            cache.arep,
            cache.kexp,
            cache.zeff,
            self.cutoff,
        )

        # identity check: only the literal ``True`` triggers the reduction
        if kwargs.get("atom_resolved", True) is not True:
            return pair_energy

        # NOTE(review): the factor 0.5 presumably compensates for each pair
        # being counted twice in a symmetric matrix -- confirm.
        return 0.5 * torch.sum(pair_energy, dim=-1)
class RepulsionAnalytical(Repulsion):
    """
    Representation of the classical repulsion.

    The energy is evaluated through a custom autograd function that supplies
    an analytical derivative with respect to the positions.
    """

    @override
    def get_energy(
        self,
        positions: Tensor,
        cache: BaseRepulsionCache,
        atom_resolved: bool = True,
        **kwargs: Any,
    ) -> Tensor:
        """
        Get repulsion energy.

        This function employs the custom autograd class to provide an
        analytical first derivative.

        Parameters
        ----------
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        cache : Repulsion.Cache
            Cache for repulsion.
        atom_resolved : bool, optional
            Whether to return atom-resolved energy (True) or full matrix
            (False). Only the literal ``True`` selects the atom-resolved
            energy. Defaults to ``True``.
        **kwargs : Any
            Ignored. Accepted so that the signature stays compatible with
            :meth:`Repulsion.get_energy`, which takes arbitrary keyword
            arguments.

        Returns
        -------
        Tensor
            (Atom-resolved) repulsion energy.
        """
        # the ``setup_context`` API of autograd functions exists from
        # PyTorch 2.0 on; older versions need the ``ctx``-in-forward variant
        _RepulsionAG = (
            RepulsionAG_V1 if __tversion__ < (2, 0, 0) else RepulsionAG_V2
        )  # pragma: no cover

        e = _RepulsionAG.apply(
            positions,
            cache.mask,
            cache.arep,
            cache.kexp,
            cache.zeff,
            self.cutoff,
        )
        assert e is not None

        if atom_resolved is True:
            return 0.5 * torch.sum(e, dim=-1)
        return e


class RepulsionAGBase(torch.autograd.Function):
    """
    Base class for the version-specific autograd function for repulsion energy.
    Different PyTorch versions only require different `forward()` signatures.

    Subclasses must save ``(erep, positions, mask, arep, kexp, zeff)`` via
    ``ctx.save_for_backward`` and store the cutoff as ``ctx.cutoff``.
    """

    @staticmethod
    def backward(
        ctx, grad_out: Tensor
    ) -> tuple[
        None | Tensor,  # positions
        None,  # mask
        None | Tensor,  # arep
        None | Tensor,  # kexp
        None | Tensor,  # zeff
        None,  # cutoff
    ]:
        """
        Vector-Jacobian product of the repulsion energy.

        The derivative with respect to the positions is analytical (via
        ``repulsion_gradient``); the derivatives with respect to ``arep``,
        ``kexp`` and ``zeff`` are obtained with autograd.

        Parameters
        ----------
        ctx : Any
            Autograd context populated in the forward pass.
        grad_out : Tensor
            Incoming gradient with respect to the output of ``forward()``.

        Returns
        -------
        tuple
            Gradients for ``(positions, mask, arep, kexp, zeff, cutoff)``;
            ``None`` for inputs that do not require gradients.
        """
        # initialize gradients with ``None``
        positions_bar = arep_bar = kexp_bar = zeff_bar = None

        # check which of the input variables of `forward()` requires gradients
        grad_positions, _, grad_arep, grad_kexp, grad_zeff, _ = (
            ctx.needs_input_grad
        )

        erep, positions, mask, arep, kexp, zeff = ctx.saved_tensors

        # analytical gradient for positions
        if grad_positions:
            # (n_batch, n_atoms, n_atoms, 3)
            g = repulsion_gradient(
                erep,
                positions,
                mask,
                arep,
                kexp,
                reduced=False,
            )

            # vjp: (nb, na, na) * (nb, na, na, 3) -> (nb, na, 3)
            _gi = einsum("...ij,...ijd->...id", grad_out, g)
            _gj = einsum("...ij,...ijd->...jd", grad_out, g)
            positions_bar = _gi - _gj

        # automatic gradient for parameters
        wanted = [
            tensor
            for tensor, needed in (
                (arep, grad_arep),
                (kexp, grad_kexp),
                (zeff, grad_zeff),
            )
            if needed
        ]
        if wanted:
            # Do not differentiate the saved ``erep``: in the ``setup_context``
            # variant (V2) it is the Function's own output, so its ``grad_fn``
            # is this backward node and ``torch.autograd.grad`` would re-enter
            # this method endlessly. Rebuild the energy graph from the saved
            # inputs instead (redundant but harmless for V1).
            with torch.enable_grad():
                erep_graph = repulsion_energy(
                    positions, mask, arep, kexp, zeff, ctx.cutoff
                )

            grads = iter(
                torch.autograd.grad(
                    erep_graph,
                    wanted,
                    grad_outputs=grad_out,
                    create_graph=True,
                )
            )
            # results come back in the order of ``wanted``
            if grad_arep:
                arep_bar = next(grads)
            if grad_kexp:
                kexp_bar = next(grads)
            if grad_zeff:
                zeff_bar = next(grads)

        return positions_bar, None, arep_bar, kexp_bar, zeff_bar, None


class RepulsionAG_V1(RepulsionAGBase):
    """
    Autograd function for repulsion energy.
    """

    @staticmethod
    def forward(
        ctx,
        positions: Tensor,
        mask: Tensor,
        arep: Tensor,
        kexp: Tensor,
        zeff: Tensor,
        cutoff: float,
    ) -> Tensor:
        # ``erep`` keeps a graph to the inputs; the saved copy is reused in
        # ``backward()`` for the position gradient
        with torch.enable_grad():
            erep = repulsion_energy(positions, mask, arep, kexp, zeff, cutoff)

        # needed to rebuild the energy graph for the parameter derivatives
        ctx.cutoff = cutoff
        ctx.save_for_backward(erep, positions, mask, arep, kexp, zeff)

        # return a copy so the saved ``erep`` is not rewired to this
        # Function's backward node
        return erep.clone()


class RepulsionAG_V2(RepulsionAGBase):
    """
    Autograd function for repulsion energy.
    """

    generate_vmap_rule = True
    # https://pytorch.org/docs/master/notes/extending.func.html#automatically-generate-a-vmap-rule
    # should work since we only use PyTorch operations

    @staticmethod
    def forward(
        positions: Tensor,
        mask: Tensor,
        arep: Tensor,
        kexp: Tensor,
        zeff: Tensor,
        cutoff: float,
    ) -> Tensor:
        with torch.enable_grad():
            erep = repulsion_energy(positions, mask, arep, kexp, zeff, cutoff)

        return erep.clone()

    @staticmethod
    def setup_context(ctx, inputs: tuple[Any, ...], output: Tensor) -> None:
        positions, mask, arep, kexp, zeff, cutoff = inputs

        # ``backward()`` rebuilds the energy graph for the parameter
        # derivatives and hence needs the cutoff
        ctx.cutoff = cutoff
        ctx.save_for_backward(output, positions, mask, arep, kexp, zeff)