# Source code for dxtb._src.integral.factory

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Factories
=========

Factory functions for integral classes.
"""

# pylint: disable=import-outside-toplevel

from __future__ import annotations

import torch

from dxtb import IndexHelper
from dxtb._src.constants import labels
from dxtb._src.param import Param, ParamModule
from dxtb._src.typing import TYPE_CHECKING, Any, Tensor

if TYPE_CHECKING:
    from dxtb._src.xtb.gfn1 import GFN1Hamiltonian
    from dxtb._src.xtb.gfn2 import GFN2Hamiltonian

    from .driver.libcint import DipoleLibcint, OverlapLibcint, QuadrupoleLibcint
    from .driver.pytorch import DipolePytorch, OverlapPytorch, QuadrupolePytorch
    from .types import DipoleIntegral, OverlapIntegral, QuadrupoleIntegral

__all__ = ["new_hcore", "new_overlap", "new_dipint", "new_quadint"]


################################################################################


def new_hcore(
    numbers: Tensor,
    par: Param | ParamModule,
    ihelp: IndexHelper,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> GFN1Hamiltonian | GFN2Hamiltonian:
    """
    Build the core Hamiltonian that matches a parametrization.

    The method is selected from the case-folded ``meta.name`` entry of the
    parametrization: ``"gfn1-xtb"``/``"gfn1"`` dispatch to
    :func:`new_hcore_gfn1`, ``"gfn2-xtb"``/``"gfn2"`` dispatch to
    :func:`new_hcore_gfn2`.

    Parameters
    ----------
    numbers : Tensor
        Forwarded unchanged to the selected Hamiltonian factory.
    par : Param | ParamModule
        Parametrization. A value that is not already a ``ParamModule`` is
        wrapped in one (using ``device`` and ``dtype``) before inspection.
    ihelp : IndexHelper
        Forwarded unchanged to the selected Hamiltonian factory.
    device : torch.device | None, optional
        Forwarded to ``ParamModule`` (if wrapping is needed) and to the
        selected factory. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to ``ParamModule`` (if wrapping is needed) and to the
        selected factory. Defaults to ``None``.

    Returns
    -------
    GFN1Hamiltonian | GFN2Hamiltonian
        Whatever the selected factory returns.

    Raises
    ------
    ValueError
        If the ``meta`` field or its ``name`` entry is reported as missing by
        ``par.is_none``, or if the name matches none of the known methods.
    """
    if not isinstance(par, ParamModule):
        par = ParamModule(par, device=device, dtype=dtype)

    # Without the meta information there is nothing to dispatch on.
    if par.is_none("meta"):
        raise ValueError(
            "The `meta` information field is missing in the parametrization. "
            "No xTB core Hamiltonian can be selected and instantiated."
        )
    if par.is_none("meta.name"):
        raise ValueError(
            "The `name` field of the meta information is missing in the "
            "parametrization. No xTB core Hamiltonian can be selected and "
            "instantiated."
        )

    method = par.meta.name.casefold()
    if method in ("gfn1-xtb", "gfn1"):
        factory = new_hcore_gfn1
    elif method in ("gfn2-xtb", "gfn2"):
        factory = new_hcore_gfn2
    else:
        raise ValueError(f"Unsupported Hamiltonian type: {par.meta.name}")

    return factory(numbers, ihelp, par, device=device, dtype=dtype)
def new_hcore_gfn1(
    numbers: Tensor,
    ihelp: IndexHelper,
    par: Param | ParamModule | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> GFN1Hamiltonian:
    """
    Instantiate the GFN1 core Hamiltonian.

    Parameters
    ----------
    numbers : Tensor
        Passed as first positional argument to ``GFN1Hamiltonian``.
    ihelp : IndexHelper
        Passed as third positional argument to ``GFN1Hamiltonian``.
    par : Param | ParamModule | None, optional
        Parametrization. When ``None``, ``dxtb.GFN1_XTB`` is used instead.
    device : torch.device | None, optional
        Forwarded to the Hamiltonian constructor. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the Hamiltonian constructor. Defaults to ``None``.

    Returns
    -------
    GFN1Hamiltonian
        The newly constructed Hamiltonian.
    """
    # Imported lazily, inside the function (see the module-level pylint pragma).
    from dxtb._src.xtb.gfn1 import GFN1Hamiltonian

    if par is None:
        from dxtb import GFN1_XTB

        par = GFN1_XTB

    return GFN1Hamiltonian(numbers, par, ihelp, device=device, dtype=dtype)


def new_hcore_gfn2(
    numbers: Tensor,
    ihelp: IndexHelper,
    par: Param | ParamModule | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> GFN2Hamiltonian:
    """
    Instantiate the GFN2 core Hamiltonian.

    Parameters
    ----------
    numbers : Tensor
        Passed as first positional argument to ``GFN2Hamiltonian``.
    ihelp : IndexHelper
        Passed as third positional argument to ``GFN2Hamiltonian``.
    par : Param | ParamModule | None, optional
        Parametrization. When ``None``, ``dxtb.GFN2_XTB`` is used instead.
    device : torch.device | None, optional
        Forwarded to the Hamiltonian constructor. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the Hamiltonian constructor. Defaults to ``None``.

    Returns
    -------
    GFN2Hamiltonian
        The newly constructed Hamiltonian.
    """
    # Imported lazily, inside the function (see the module-level pylint pragma).
    from dxtb._src.xtb.gfn2 import GFN2Hamiltonian

    if par is None:
        from dxtb import GFN2_XTB

        par = GFN2_XTB

    return GFN2Hamiltonian(numbers, par, ihelp, device=device, dtype=dtype)


################################################################################
def new_overlap(
    driver: int = labels.INTDRIVER_LIBCINT,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> OverlapIntegral:
    """
    Build an overlap integral for the requested integral driver.

    Parameters
    ----------
    driver : int, optional
        Integral driver label. ``labels.INTDRIVER_LIBCINT`` (the default)
        selects :func:`new_overlap_libcint`; the analytical, autograd and
        legacy labels select :func:`new_overlap_pytorch`.
    device : torch.device | None, optional
        Forwarded to the selected factory. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the selected factory. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments, forwarded to the selected factory.

    Returns
    -------
    OverlapIntegral
        Whatever the selected factory returns.

    Raises
    ------
    ValueError
        If ``driver`` is none of the labels handled above.
    """
    if driver == labels.INTDRIVER_LIBCINT:
        factory = new_overlap_libcint
    elif driver in (
        labels.INTDRIVER_ANALYTICAL,
        labels.INTDRIVER_AUTOGRAD,
        labels.INTDRIVER_LEGACY,
    ):
        # All PyTorch-based drivers share one implementation class.
        factory = new_overlap_pytorch
    else:
        raise ValueError(f"Unknown integral driver '{driver}'.")

    return factory(device=device, dtype=dtype, **kwargs)
def new_overlap_libcint(
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> OverlapLibcint:
    """
    Instantiate the libcint-based overlap integral.

    Parameters
    ----------
    device : torch.device | None, optional
        Device handed to the integral class. It is replaced by the CPU
        device unless ``force_cpu_for_libcint`` is passed as a falsy value.
    dtype : torch.dtype | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments for the integral class. The key
        ``force_cpu_for_libcint`` (default ``True``) is consumed here and
        not forwarded.

    Returns
    -------
    OverlapLibcint
        The newly constructed integral object.
    """
    from .driver.libcint import OverlapLibcint

    # The flag is removed from kwargs so the constructor never sees it.
    force_cpu = kwargs.pop("force_cpu_for_libcint", True)
    target = torch.device("cpu") if force_cpu else device

    return OverlapLibcint(device=target, dtype=dtype, **kwargs)


def new_overlap_pytorch(
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> OverlapPytorch:
    """
    Instantiate the PyTorch-based overlap integral.

    Parameters
    ----------
    device : torch.device | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments, forwarded to the integral class.

    Returns
    -------
    OverlapPytorch
        The newly constructed integral object.
    """
    from .driver.pytorch import OverlapPytorch

    return OverlapPytorch(device=device, dtype=dtype, **kwargs)


################################################################################
def new_dipint(
    driver: int = labels.INTDRIVER_LIBCINT,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> DipoleIntegral:
    """
    Build a dipole integral for the requested integral driver.

    Parameters
    ----------
    driver : int, optional
        Integral driver label. ``labels.INTDRIVER_LIBCINT`` (the default)
        selects :func:`new_dipint_libcint`; the analytical, autograd and
        legacy labels select :func:`new_dipint_pytorch`.
    device : torch.device | None, optional
        Forwarded to the selected factory. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the selected factory. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments, forwarded to the selected factory.

    Returns
    -------
    DipoleIntegral
        Whatever the selected factory returns.

    Raises
    ------
    ValueError
        If ``driver`` is none of the labels handled above.
    """
    if driver == labels.INTDRIVER_LIBCINT:
        factory = new_dipint_libcint
    elif driver in (
        labels.INTDRIVER_ANALYTICAL,
        labels.INTDRIVER_AUTOGRAD,
        labels.INTDRIVER_LEGACY,
    ):
        # All PyTorch-based drivers share one implementation class.
        factory = new_dipint_pytorch
    else:
        raise ValueError(f"Unknown integral driver '{driver}'.")

    return factory(device=device, dtype=dtype, **kwargs)
def new_dipint_libcint(
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> DipoleLibcint:
    """
    Instantiate the libcint-based dipole integral.

    Parameters
    ----------
    device : torch.device | None, optional
        Device handed to the integral class. It is replaced by the CPU
        device unless ``force_cpu_for_libcint`` is passed as a falsy value.
    dtype : torch.dtype | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments for the integral class. The key
        ``force_cpu_for_libcint`` (default ``True``) is consumed here and
        not forwarded.

    Returns
    -------
    DipoleLibcint
        The newly constructed integral object.
    """
    from .driver.libcint import DipoleLibcint

    # The flag is removed from kwargs so the constructor never sees it.
    force_cpu = kwargs.pop("force_cpu_for_libcint", True)
    target = torch.device("cpu") if force_cpu else device

    return DipoleLibcint(device=target, dtype=dtype, **kwargs)


def new_dipint_pytorch(
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> DipolePytorch:
    """
    Instantiate the PyTorch-based dipole integral.

    Parameters
    ----------
    device : torch.device | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments, forwarded to the integral class.

    Returns
    -------
    DipolePytorch
        The newly constructed integral object.
    """
    from .driver.pytorch import DipolePytorch

    return DipolePytorch(device=device, dtype=dtype, **kwargs)


################################################################################
def new_quadint(
    driver: int = labels.INTDRIVER_LIBCINT,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> QuadrupoleIntegral:
    """
    Build a quadrupole integral for the requested integral driver.

    Parameters
    ----------
    driver : int, optional
        Integral driver label. ``labels.INTDRIVER_LIBCINT`` (the default)
        selects :func:`new_quadint_libcint`; the analytical, autograd and
        legacy labels select :func:`new_quadint_pytorch`.
    device : torch.device | None, optional
        Forwarded to the selected factory. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the selected factory. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments, forwarded to the selected factory.

    Returns
    -------
    QuadrupoleIntegral
        Whatever the selected factory returns.

    Raises
    ------
    ValueError
        If ``driver`` is none of the labels handled above.
    """
    if driver == labels.INTDRIVER_LIBCINT:
        factory = new_quadint_libcint
    elif driver in (
        labels.INTDRIVER_ANALYTICAL,
        labels.INTDRIVER_AUTOGRAD,
        labels.INTDRIVER_LEGACY,
    ):
        # All PyTorch-based drivers share one implementation class.
        factory = new_quadint_pytorch
    else:
        raise ValueError(f"Unknown integral driver '{driver}'.")

    return factory(device=device, dtype=dtype, **kwargs)
def new_quadint_libcint(
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> QuadrupoleLibcint:
    """
    Instantiate the libcint-based quadrupole integral.

    Parameters
    ----------
    device : torch.device | None, optional
        Device handed to the integral class. It is replaced by the CPU
        device unless ``force_cpu_for_libcint`` is passed as a falsy value.
    dtype : torch.dtype | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments for the integral class. The key
        ``force_cpu_for_libcint`` (default ``True``) is consumed here and
        not forwarded.

    Returns
    -------
    QuadrupoleLibcint
        The newly constructed integral object.
    """
    from .driver.libcint import QuadrupoleLibcint

    # The flag is removed from kwargs so the constructor never sees it.
    force_cpu = kwargs.pop("force_cpu_for_libcint", True)
    target = torch.device("cpu") if force_cpu else device

    return QuadrupoleLibcint(device=target, dtype=dtype, **kwargs)


def new_quadint_pytorch(
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> QuadrupolePytorch:
    """
    Instantiate the PyTorch-based quadrupole integral.

    Parameters
    ----------
    device : torch.device | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    dtype : torch.dtype | None, optional
        Forwarded to the integral class. Defaults to ``None``.
    **kwargs : Any
        Extra keyword arguments, forwarded to the integral class.

    Returns
    -------
    QuadrupolePytorch
        The newly constructed integral object.
    """
    from .driver.pytorch import QuadrupolePytorch

    return QuadrupolePytorch(device=device, dtype=dtype, **kwargs)