Source code for dxtb._src.integral.driver.pytorch.overlap
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation: Overlap
=======================
PyTorch-based overlap implementations.
"""
from __future__ import annotations
from tad_mctc.convert import symmetrize
from dxtb._src.typing import Tensor
from ...types import OverlapIntegral
from ...utils import snorm
from .base import IntegralPytorch
from .driver import BaseIntDriverPytorch
from .impls import OverlapFunction
__all__ = ["OverlapPytorch"]
[docs]
class OverlapPytorch(OverlapIntegral, IntegralPytorch):
"""
Overlap integral from atomic orbitals.
Use the :meth:`.build` method to calculate the overlap integral. The
returned matrix uses a custom autograd function to calculate the
backward pass with the analytical gradient.
For the full gradient, i.e., a matrix of shape ``(..., norb, norb, 3)``,
the :meth:`.get_gradient` method should be used.
"""
[docs]
def build(self, driver: BaseIntDriverPytorch) -> Tensor:
"""
Overlap calculation of unique shells pairs, using the
McMurchie-Davidson algorithm.
Parameters
----------
driver : BaseIntDriverPytorch
The integral driver for the calculation.
Returns
-------
Tensor
Overlap integral matrix of shape ``(..., norb, norb)``.
"""
super().checks(driver)
if driver.ihelp.batch_mode > 0:
self.matrix = self._batch(driver.eval_ovlp, driver)
else:
self.matrix = self._single(driver.eval_ovlp, driver)
# force symmetry to avoid problems through numerical errors
if self.uplo == "n":
return symmetrize(self.matrix, force=False)
self.norm = snorm(self.matrix)
return self.matrix
[docs]
def get_gradient(self, driver: BaseIntDriverPytorch) -> Tensor:
"""
Overlap gradient calculation of unique shells pairs, using the
McMurchie-Davidson algorithm.
Parameters
----------
driver : BaseIntDriverPytorch
Integral driver for the calculation.
Returns
-------
Tensor
Overlap gradient of shape ``(..., norb, norb, 3)``.
"""
super().checks(driver)
# build norm if not already available
if self.norm is None:
self.build(driver)
if driver.ihelp.batch_mode > 0:
self.gradient = self._batch(driver.eval_ovlp_grad, driver)
else:
self.gradient = self._single(driver.eval_ovlp_grad, driver)
return self.gradient
def _single(
self, fcn: OverlapFunction, driver: BaseIntDriverPytorch
) -> Tensor:
if not isinstance(driver, BaseIntDriverPytorch):
raise RuntimeError("Wrong integral driver selected.")
return fcn(
driver._positions_single,
driver.basis,
driver.ihelp,
self.uplo,
self.cutoff,
)
def _batch(
self, fcn: OverlapFunction, driver: BaseIntDriverPytorch
) -> Tensor:
if not isinstance(driver, BaseIntDriverPytorch):
raise RuntimeError("Wrong integral driver selected.")
# pylint: disable=import-outside-toplevel
from tad_mctc.batch import pack
return pack(
[
fcn(
driver._positions_batch[_batch],
driver._basis_batch[_batch],
driver._ihelp_batch[_batch],
self.uplo,
self.cutoff,
)
for _batch in range(driver.numbers.shape[0])
]
)