# Source code for dxtb._src.integral.driver.libcint.overlap

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implementation: Overlap
=======================

Overlap implementation based on `libcint`.
"""

from __future__ import annotations

import torch
from tad_mctc.batch import pack
from tad_mctc.math import einsum

from dxtb._src.typing import Any, Tensor

from ...types import OverlapIntegral
from ...utils import snorm
from .base import IntegralLibcint
from .driver import IntDriverLibcint

__all__ = ["OverlapLibcint"]


class OverlapLibcint(OverlapIntegral, IntegralLibcint):
    """
    Overlap integral from atomic orbitals.

    Use the :meth:`build` method to calculate the overlap integral. The
    returned matrix uses a custom autograd function to calculate the
    backward pass with the analytical gradient.

    For the full gradient, i.e., a matrix of shape ``(..., norb, norb, 3)``,
    the :meth:`get_gradient` method should be used.
    """

    def build(self, driver: IntDriverLibcint, **_: Any) -> Tensor:
        """
        Calculation of overlap integral using libcint.

        The result is stored in ``self.matrix``; the norm obtained from
        ``snorm`` is stored in ``self.norm``.

        Parameters
        ----------
        driver : IntDriverLibcint
            The integral driver for the calculation.

        Returns
        -------
        Tensor
            Overlap integral matrix of shape ``(..., norb, norb)``.

        Raises
        ------
        RuntimeError
            If the batch mode of the driver's IndexHelper does not match the
            type of ``driver.drv`` (list of wrappers in batched mode, single
            ``LibcintWrapper`` otherwise).
        """
        super().checks(driver)

        # pylint: disable=import-outside-toplevel
        from dxtb._src.exlibs import libcint

        # batched mode
        if driver.ihelp.batch_mode > 0:
            # Explicit raise instead of ``assert``: assertions are removed
            # under ``python -O`` and `get_gradient` reports the same
            # mismatch with a RuntimeError.
            if not isinstance(driver.drv, list):
                raise RuntimeError(
                    "IndexHelper on integral driver is batched, but the driver "
                    "instance itself not."
                )

            # One overlap matrix (and its norm) per driver in the batch.
            slist = [libcint.overlap(d) for d in driver.drv]
            nlist = [snorm(s) for s in slist]

            self.norm = pack(nlist)
            self.matrix = pack(slist)
            return self.matrix

        # single mode
        if not isinstance(driver.drv, libcint.LibcintWrapper):
            raise RuntimeError(
                "IndexHelper on integral driver is not batched, but the "
                "driver instance itself seems to be batched."
            )

        self.matrix = libcint.overlap(driver.drv)
        self.norm = snorm(self.matrix)
        return self.matrix

    def get_gradient(self, driver: IntDriverLibcint, **_: Any) -> Tensor:
        """
        Overlap gradient calculation using libcint.

        The result is stored in ``self.gradient``. If ``self.norm`` has not
        been set yet, :meth:`build` is called first.

        Parameters
        ----------
        driver : IntDriverLibcint
            The integral driver for the calculation.

        Returns
        -------
        Tensor
            Overlap gradient of shape ``(..., norb, norb, 3)``.

        Raises
        ------
        RuntimeError
            If the batch mode of the driver's IndexHelper does not match the
            type of ``driver.drv``.
        ValueError
            If the IndexHelper reports a positive batch mode other than 1
            or 2.
        """
        super().checks(driver)

        # pylint: disable=import-outside-toplevel
        from dxtb._src.exlibs import libcint

        # build norm if not already available
        if self.norm is None:
            self.build(driver)

        def fcn(drv: libcint.LibcintWrapper) -> Tensor:
            # (3, norb, norb)
            grad = libcint.int1e("ipovlp", drv)

            # Move xyz dimension to last, which is required for the
            # reduction (only works with extra dimension in last).
            # The result is returned with a flipped sign.
            return -einsum("...xij->...ijx", grad)

        # batched mode
        if driver.ihelp.batch_mode > 0:
            if not isinstance(driver.drv, list):
                raise RuntimeError(
                    "IndexHelper on integral driver is batched, but the driver "
                    "instance itself not."
                )

            # Mode 1 combines the per-system gradients with ``pack``.
            if driver.ihelp.batch_mode == 1:
                self.gradient = pack([fcn(d) for d in driver.drv])
                return self.gradient

            # Mode 2 combines them with ``torch.stack``.
            if driver.ihelp.batch_mode == 2:
                self.gradient = torch.stack([fcn(d) for d in driver.drv])
                return self.gradient

            raise ValueError(f"Unknown batch mode '{driver.ihelp.batch_mode}'.")

        # single mode
        if not isinstance(driver.drv, libcint.LibcintWrapper):
            raise RuntimeError(
                "IndexHelper on integral driver is not batched, but the "
                "driver instance itself seems to be batched."
            )

        self.gradient = fcn(driver.drv)
        return self.gradient