# Source code for dxtb._src.integral.driver.libcint.driver

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Driver: Libcint
===============

Base class for a `libcint`-based integral driver implementation, used for the
calculation and modification of multipole integrals.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from dxtb import IndexHelper
from dxtb._src.basis.bas import Basis
from dxtb._src.typing import Tensor, TypeAlias, overload
from dxtb._src.utils import is_basis_list

from ...base import IntDriver
from .base import LibcintImplementation

if TYPE_CHECKING:
    from dxtb._src.exlibs import libcint

    ListAtomCGTOBasis: TypeAlias = list[libcint.AtomCGTOBasis]


__all__ = ["BaseIntDriverLibcint", "IntDriverLibcint"]


class BaseIntDriverLibcint(LibcintImplementation, IntDriver):
    """
    Base implementation of the `libcint`-based integral driver.

    After :meth:`setup`, the attribute ``drv`` holds the `libcint` driver:
    a single ``libcint.LibcintWrapper`` for an unbatched calculation, or a
    list of wrappers for batched calculations.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # `libcint` wrapper(s); assigned by `setup`
        self.drv = None

    def setup(self, positions: Tensor, **kwargs) -> None:
        """
        Run the `libcint`-specific driver setup.

        Creates the basis if none exists yet, converts it into the `libcint`
        atomic-basis format for the given positions, and builds the matching
        `libcint` driver.

        Parameters
        ----------
        positions : Tensor
            Cartesian coordinates of all atoms (shape: ``(..., nat, 3)``).
        **kwargs
            Only ``mask`` is used. It is popped from ``kwargs`` (default:
            ``None``) and forwarded to ``Basis.create_libcint``. All other
            keyword arguments are ignored.
        """
        # build the `Basis` only if there is none yet
        if self._basis is None:
            self.basis = Basis(self.numbers, self.par, self.ihelp, **self.dd)

        # atomic basis set in `libcint` format -> `libcint` driver
        libcint_bases = self.basis.create_libcint(
            positions, mask=kwargs.pop("mask", None)
        )
        self.drv = self._get_driver(libcint_bases)

        # Keeping a detached copy of the positions marks the setup as done.
        # It also lets later calls detect new positions that require the
        # driver to be set up again.
        self._positions = positions.detach().clone()

    @overload
    def _get_driver(
        self, atombases: ListAtomCGTOBasis
    ) -> libcint.LibcintWrapper: ...

    @overload
    def _get_driver(
        self,
        atombases: list[ListAtomCGTOBasis],
    ) -> list[libcint.LibcintWrapper]: ...

    def _get_driver(
        self,
        atombases: ListAtomCGTOBasis | list[ListAtomCGTOBasis],
    ) -> libcint.LibcintWrapper | list[libcint.LibcintWrapper]:
        """
        Create the `libcint` driver in single or batched mode.

        Parameters
        ----------
        atombases : ListAtomCGTOBasis | list[ListAtomCGTOBasis]
            Atomic basis set(s) in `libcint` format. Pass one list for an
            unbatched calculation (``batch_mode == 0``), otherwise one list
            per system of the batch.

        Returns
        -------
        libcint.LibcintWrapper | list[libcint.LibcintWrapper]
            A single wrapper for ``batch_mode == 0``. Otherwise, a list with
            one wrapper for each batch entry whose basis passes
            ``is_basis_list``.

        Raises
        ------
        ValueError
            If ``self.ihelp.batch_mode`` is not 0, 1 or 2.
        """
        # pylint: disable=import-outside-toplevel
        from dxtb._src.exlibs import libcint

        mode = self.ihelp.batch_mode

        # unbatched: a single wrapper built with the full IndexHelper
        if mode == 0:
            assert is_basis_list(atombases)
            return libcint.LibcintWrapper(atombases, self.ihelp)

        # Integrals do not work with a batched IndexHelper, so a separate
        # IndexHelper is created for every system of the batch.
        if mode == 1:
            # pylint: disable=import-outside-toplevel
            from tad_mctc.batch import deflate

            # presumably `deflate` strips padding from each row of `numbers`
            # (TODO confirm against tad_mctc)
            helpers = [
                IndexHelper.from_numbers(deflate(nums), self.par)
                for nums in self.numbers
            ]
        elif mode == 2:
            helpers = [
                IndexHelper.from_numbers(nums, self.par)
                for nums in self.numbers
            ]
        else:
            raise ValueError(f"Unknown batch mode '{mode}'.")

        assert isinstance(atombases, list)

        # NOTE(review): entries that fail `is_basis_list` are skipped silently,
        # so the result may be shorter than the batch. Confirm this is intended.
        return [
            libcint.LibcintWrapper(system_basis, system_ihelp)
            for system_basis, system_ihelp in zip(atombases, helpers)
            if is_basis_list(system_basis)
        ]
class IntDriverLibcint(BaseIntDriverLibcint):
    """
    Concrete ``libcint``-based integral driver.

    All behavior is inherited from :class:`BaseIntDriverLibcint`. This class
    defines no additional attributes or methods.
    """