# Source code for dxtb._src.basis.indexhelper

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Basis: IndexHelper
==================

Index helper utility to create index maps between atom-resolved,
shell-resolved, and orbital-resolved representations of quantities.

Example
-------

.. code-block:: python

    import torch
    from dxtb import IndexHelper

    # Define atomic numbers and angular momentum for each element
    numbers = torch.tensor([6, 1, 1, 1, 1])
    angular = {1: [0], 6: [0, 1]}

    # Create an IndexHelper instance with angular momentum specifications
    ihelp = IndexHelper.from_numbers_angular(numbers, angular)

    # Count the shells with a valid (non-negative) angular momentum
    result = torch.sum(ihelp.angular >= 0)
    print(result)  # tensor(6)
"""

from __future__ import annotations

import torch
from tad_mctc.batch import pack
from tad_mctc.math import einsum

from dxtb._src.typing import Slicers, Tensor, TensorLike, override

from ..param import Param, ParamModule
from ..utils import t2int, wrap_gather, wrap_scatter_reduce

__all__ = ["IndexHelper"]


# Sentinel value written into index tensors for padded entries: it is the
# fill value passed to `pack` for batched maps and the offset assigned to
# atoms/shells that own zero shells/orbitals (see `from_numbers_angular`).
PAD = -999


def _fill(index: Tensor, repeat: Tensor) -> Tensor:
    """
    Build a map that assigns each output position to the entry owning it.

    Entry ``i`` owns the slice ``[index[i], index[i] + repeat[i])`` of the
    result, which is filled with the value ``i``. Positions not covered by
    any slice stay zero, and later entries overwrite earlier ones where
    slices overlap. Zero-length entries (e.g. padded ones whose offset is
    ``PAD``) write nothing.

    Parameters
    ----------
    index : Tensor
        Start offset of every entry (1D).
    repeat : Tensor
        Number of positions owned by every entry (1D). The length of the
        result is ``repeat.sum()``.

    Returns
    -------
    Tensor
        Index map with the dtype and device of ``index``.
    """
    size = int(torch.sum(repeat).item())
    owner = torch.zeros(size, device=index.device, dtype=index.dtype)

    entry_ids = torch.arange(index.shape[-1])
    for entry, start, length in zip(entry_ids, index, repeat):
        stop = start + length
        owner[start:stop] = entry

    return owner


def _expand(index: Tensor, repeat: Tensor) -> Tensor:
    """
    Expand an index map using index offsets and number of repeats.

    Entry ``i`` contributes the consecutive run
    ``index[i], index[i] + 1, ..., index[i] + repeat[i] - 1``; the runs are
    concatenated in order. Entries with a non-positive repeat contribute
    nothing.

    The expansion is vectorized with :func:`torch.repeat_interleave` rather
    than built element by element in Python (one ``.item()`` call per entry
    and one Python object per output element).

    Parameters
    ----------
    index : Tensor
        Start value of every run (1D).
    repeat : Tensor
        Length of every run (1D, same length as ``index``).

    Returns
    -------
    Tensor
        Concatenated runs with the dtype and device of ``index``.
    """
    # `repeat_interleave` rejects negative counts; an empty run is the
    # equivalent of `arange(start, start + count)` with count <= 0.
    counts = repeat.clamp(min=0).to(torch.long)

    # start value of the run that every output element belongs to
    starts = torch.repeat_interleave(index.to(torch.long), counts)

    # position of every output element inside its own run
    run_offsets = torch.cumsum(counts, dim=-1) - counts
    within = torch.arange(
        starts.shape[-1], device=starts.device
    ) - torch.repeat_interleave(run_offsets, counts)

    return (starts + within).to(device=index.device, dtype=index.dtype)


class IndexHelperStore:
    """
    Snapshot of the index tensors of an :class:`IndexHelper`.

    :meth:`IndexHelper.cull` fills this container with the uncut tensors on
    its first call, and :meth:`IndexHelper.restore` reads them back. Only
    references are kept; the tensors are not copied.
    """

    def __init__(
        self,
        unique_angular: Tensor,
        angular: Tensor,
        atom_to_unique: Tensor,
        ushells_to_unique: Tensor,
        shells_to_ushell: Tensor,
        shells_per_atom: Tensor,
        shell_index: Tensor,
        shells_to_atom: Tensor,
        orbitals_per_shell: Tensor,
        orbital_index: Tensor,
        orbitals_to_shell: Tensor,
    ):
        # unique-species resolution
        self.unique_angular = unique_angular
        self.ushells_to_unique = ushells_to_unique

        # atom resolution
        self.atom_to_unique = atom_to_unique
        self.shells_per_atom = shells_per_atom
        self.shell_index = shell_index

        # shell resolution
        self.angular = angular
        self.shells_to_ushell = shells_to_ushell
        self.shells_to_atom = shells_to_atom
        self.orbitals_per_shell = orbitals_per_shell
        self.orbital_index = orbital_index

        # orbital resolution
        self.orbitals_to_shell = orbitals_to_shell


class IndexHelper(TensorLike):
    """
    Index helper for basis set.

    Stores integer index maps between unique species, atoms, shells and
    orbitals, and offers methods to reduce (scatter) or spread (gather)
    tensors between these resolutions.
    """

    unique_angular: Tensor
    """Angular momenta of all unique shells"""

    angular: Tensor
    """Angular momenta for all shells"""

    atom_to_unique: Tensor
    """Mapping of atoms to unique species"""

    ushells_to_unique: Tensor
    """Mapping of unique shells to unique species"""

    ushells_per_unique: Tensor
    """Number of unique shells per unique species."""

    shells_to_ushell: Tensor
    """Mapping of shells to unique shells"""

    shells_per_atom: Tensor  # nsh_at
    """Number of shells for each atom"""

    orbitals_per_shell: Tensor
    """Number of orbitals for each shell"""

    shell_index: Tensor  # ish_at
    """Offset index for starting the next shell block"""

    orbital_index: Tensor
    """Offset index for starting the next orbital block"""

    shells_to_atom: Tensor
    """Mapping of shells to atoms"""

    orbitals_to_shell: Tensor
    """Mapping of orbitals to shells"""

    batch_mode: int
    """
    Whether multiple systems or a single one are handled:

    - 0: Single system
    - 1: Multiple systems with padding
    - 2: Multiple systems with no padding (conformer ensemble)
    """

    store: IndexHelperStore | None
    """Storage to restore from after culling."""

    __slots__ = [
        "unique_angular",
        "angular",
        "atom_to_unique",
        "ushells_to_unique",
        "ushells_per_unique",
        "shells_to_ushell",
        "shells_per_atom",
        "shell_index",
        "shells_to_atom",
        "orbitals_per_shell",
        "orbital_index",
        "orbitals_to_shell",
        "batch_mode",
        "store",
    ]

    def __init__(
        self,
        unique_angular: Tensor,
        angular: Tensor,
        atom_to_unique: Tensor,
        ushells_to_unique: Tensor,
        ushells_per_unique: Tensor,
        shells_to_ushell: Tensor,
        shells_per_atom: Tensor,
        shell_index: Tensor,
        shells_to_atom: Tensor,
        orbitals_per_shell: Tensor,
        orbital_index: Tensor,
        orbitals_to_shell: Tensor,
        batch_mode: int,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.int64,
        *,
        store: IndexHelperStore | None = None,
        **_,
    ):
        """
        Create an index helper from precomputed index tensors.

        Usually, the alternative constructors :meth:`from_numbers` or
        :meth:`from_numbers_angular` should be used instead.

        Raises
        ------
        ValueError
            If the index tensors do not all match ``self.dtype`` or are not
            all on ``self.device``.
        """
        super().__init__(device, dtype)

        # Dependent memoization causes memory leaks. The tensors will remain
        # in the cache object and cannot be garbage collected. Only if the
        # clearing function below is called, the tensors are removed. Note
        # that this works across different instances of the IndexHelper, as
        # the cache for memoization is designed in a cross instance fashion.
        # This might lead to unexpected behavior, which was detected by the
        # memory leak tests: The cache was still populated from another test
        # and only after instantiation of the IndexHelper in the memory leak
        # test, `self.clear_cache()` was called and the tensors were removed.
        # Hence, before instantiation more tensors are in memory than after,
        # which is actually the opposite of what the memory leak tests were
        # designed for. Memoization is therefore disabled.
        # self.clear_cache()

        self.unique_angular = unique_angular
        self.angular = angular
        self.atom_to_unique = atom_to_unique
        self.ushells_to_unique = ushells_to_unique
        self.ushells_per_unique = ushells_per_unique
        self.shells_to_ushell = shells_to_ushell
        self.shells_per_atom = shells_per_atom
        self.shell_index = shell_index
        self.shells_to_atom = shells_to_atom
        self.orbitals_per_shell = orbitals_per_shell
        self.orbital_index = orbital_index
        self.orbitals_to_shell = orbitals_to_shell
        self.batch_mode = batch_mode

        self.store = store

        # all index tensors that must agree in dtype and device
        index_tensors = (
            self.unique_angular,
            self.angular,
            self.atom_to_unique,
            self.ushells_to_unique,
            self.ushells_per_unique,
            self.shells_to_ushell,
            self.shells_per_atom,
            self.shell_index,
            self.shells_to_atom,
            self.orbitals_per_shell,
            self.orbital_index,
            self.orbitals_to_shell,
        )

        if any(tensor.dtype != self.dtype for tensor in index_tensors):
            raise ValueError("All tensors must have same dtype")

        if any(tensor.device != self.device for tensor in index_tensors):
            raise ValueError("All tensors must be on the same device")

    @classmethod
    def from_numbers(
        cls,
        numbers: Tensor,
        par: Param | ParamModule,
        batch_mode: int | None = None,
        move_to_numbers_device: bool = True,
    ) -> IndexHelper:
        """
        Construct an index helper instance from atomic numbers and a
        parametrization.

        Note that this always runs on CPU to avoid inefficient communication
        between devices. Only the resulting tensors are transferred to the
        GPU. This is necessary because of complex data look up that is not
        vectorizable and requires native for-loops. Furthermore, the method
        frequently uses the :meth:`torch.Tensor.item` method, which forces
        CPU-GPU synchronization because it converts a GPU tensor to a Python
        scalar.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        par : Param | ParamModule
            Representation of an extended tight-binding model.
        batch_mode : int
            Whether multiple systems or a single one are handled:

            - 0: Single system
            - 1: Multiple systems with padding
            - 2: Multiple systems with no padding (conformer ensemble)

            Inferred from ``numbers.ndim`` if not given.
        move_to_numbers_device : bool
            Move the resulting tensors to the device of the ``numbers``
            tensor. This should be switched off for GPU calculations that use
            `libcint` for integrals as the :class:`.IndexHelper` has to be on
            the CPU for this step.

        Returns
        -------
        IndexHelper
            Instance of index helper for given basis set.
        """
        if not isinstance(par, ParamModule):
            par = ParamModule(par)

        return cls.from_numbers_angular(
            numbers,
            angular=par.get_elem_angular(),
            batch_mode=batch_mode,
            move_to_numbers_device=move_to_numbers_device,
        )

    @classmethod
    def from_numbers_angular(
        cls,
        numbers: Tensor,
        angular: dict[int, list[int]],
        batch_mode: int | None = None,
        move_to_numbers_device: bool = True,
    ) -> IndexHelper:
        """
        Construct an index helper instance from atomic numbers and their
        angular momenta. If you are not sure about the angular momenta, use
        :meth:`.IndexHelper.from_numbers` instead, which simply takes a
        parametrization.

        Note that this always runs on CPU to avoid inefficient communication
        between devices. Only the resulting tensors are transferred to the
        GPU. This is necessary because of complex data look up that is not
        vectorizable and requires native for-loops. Furthermore, the method
        frequently uses the :meth:`torch.Tensor.item` method, which forces
        CPU-GPU synchronization because it converts a GPU tensor to a Python
        scalar.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        angular : dict[int, list[int]]
            Map between atomic numbers and angular momenta of all shells.
            Atomic numbers missing from the map (e.g. the padding value 0)
            get a single shell with angular momentum -1.
        batch_mode : int
            Whether multiple systems or a single one are handled:

            - 0: Single system
            - 1: Multiple systems with padding
            - 2: Multiple systems with no padding (conformer ensemble)

            Inferred from ``numbers.ndim`` (0 or 1) if not given.
        move_to_numbers_device : bool
            Move the resulting tensors to the device of the ``numbers``
            tensor. This should be switched off for GPU calculations that use
            `libcint` for integrals as the :class:`.IndexHelper` has to be on
            the CPU for this step.

        Returns
        -------
        IndexHelper
            Instance of index helper for given basis set.
        """
        cpu = torch.device("cpu")
        device = numbers.device if move_to_numbers_device else cpu

        # Ensure that all tensors are moved to CPU to avoid inefficient
        # memory transfers between devices (.item() and native for-loops).
        numbers = numbers.to(cpu)

        if batch_mode is None:
            # store an int (not a bool) to honour the `batch_mode: int` type
            batch_mode = int(numbers.ndim > 1)

        unique, atom_to_unique = torch.unique(numbers, return_inverse=True)

        # explicit dtype: an empty list would otherwise become a float tensor
        unique_angular = torch.tensor(
            [l for number in unique for l in angular.get(number.item(), [-1])],
            dtype=torch.long,
            device=cpu,
        )

        # note that padding (i.e., when number = 0) is assigned one shell
        ushells_per_unique = torch.tensor(
            [len(angular.get(number.item(), [-1])) for number in unique],
            dtype=torch.long,
            device=cpu,
        )

        ushell_index = (
            torch.cumsum(ushells_per_unique, dim=-1) - ushells_per_unique
        )
        ushells_to_unique = _fill(ushell_index, ushells_per_unique)

        if batch_mode > 0:
            # remove the single shell assigned to the padding value in order
            # to avoid an additional count in the expansion as this will
            # cause errors in certain situations
            # (see https://github.com/grimme-lab/dxtb/issues/67)
            if (unique == 0.0).any():
                ushells_per_unique[0] = 0

            shells_to_ushell = pack(
                [
                    _expand(
                        ushell_index[atom_to_unique[_batch, :]],
                        ushells_per_unique[atom_to_unique[_batch, :]],
                    )
                    for _batch in range(numbers.shape[0])
                ],
                value=-1,
            )
        else:
            shells_to_ushell = _expand(
                ushell_index[atom_to_unique],
                ushells_per_unique[atom_to_unique],
            )

        shells_per_atom = ushells_per_unique[atom_to_unique]
        shell_index = torch.cumsum(shells_per_atom, -1) - shells_per_atom
        shell_index[shells_per_atom == 0] = PAD

        if batch_mode > 0:
            shells_to_atom = pack(
                [
                    _fill(shell_index[_batch, :], shells_per_atom[_batch, :])
                    for _batch in range(numbers.shape[0])
                ],
                value=PAD,
            )
        else:
            shells_to_atom = _fill(shell_index, shells_per_atom)

        # padded shells (shells_to_ushell == -1) get PAD as angular momentum
        lsh = torch.where(
            shells_to_ushell >= 0,
            unique_angular[shells_to_ushell],
            PAD,
        )
        orbitals_per_shell = torch.where(
            lsh >= 0, 2 * lsh + 1, torch.tensor(0, device=cpu)
        )

        orbital_index = (
            torch.cumsum(orbitals_per_shell, -1) - orbitals_per_shell
        )
        orbital_index[orbitals_per_shell == 0] = PAD

        if batch_mode > 0:
            orbitals_to_shell = pack(
                [
                    _fill(
                        orbital_index[_batch, :], orbitals_per_shell[_batch, :]
                    )
                    for _batch in range(numbers.shape[0])
                ],
                value=PAD,
            )
        else:
            orbitals_to_shell = _fill(orbital_index, orbitals_per_shell)

        return cls(
            unique_angular=unique_angular.to(device),
            angular=lsh.to(device),
            atom_to_unique=atom_to_unique.to(device),
            ushells_to_unique=ushells_to_unique.to(device),
            ushells_per_unique=ushells_per_unique.to(device),
            shells_to_ushell=shells_to_ushell.to(device),
            shells_per_atom=shells_per_atom.to(device),
            shell_index=shell_index.to(device),
            shells_to_atom=shells_to_atom.to(device),
            orbitals_per_shell=orbitals_per_shell.to(device),
            orbital_index=orbital_index.to(device),
            orbitals_to_shell=orbitals_to_shell.to(device),
            batch_mode=batch_mode,
            device=device,
        )

    def reduce_orbital_to_shell(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        reduce: str = "sum",
        extra: bool = False,
    ) -> Tensor:
        """
        Reduce orbital-resolved tensor to shell-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Orbital-resolved tensor.
        dim : int | (int, int)
            Dimension to reduce over, defaults to -1.
        reduce : str
            Reduction method, defaults to "sum".
        extra : bool
            Tensor to reduce contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Shell-resolved tensor.
        """
        return wrap_scatter_reduce(
            x, dim, self.orbitals_to_shell, reduce, extra=extra
        )

    def reduce_shell_to_atom(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        reduce: str = "sum",
        extra: bool = False,
    ) -> Tensor:
        """
        Reduce shell-resolved tensor to atom-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Shell-resolved tensor
        dim : int | (int, int)
            Dimension to reduce over, defaults to -1.
        reduce : str
            Reduction method, defaults to "sum".
        extra : bool
            Tensor to reduce contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Atom-resolved tensor.
        """
        return wrap_scatter_reduce(
            x, dim, self.shells_to_atom, reduce, extra=extra
        )

    def reduce_orbital_to_atom(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        reduce: str = "sum",
        extra: bool = False,
    ) -> Tensor:
        """
        Reduce orbital-resolved tensor to atom-resolved tensor.

        The reduction is applied twice (orbital -> shell, then shell ->
        atom). For a mean-type reduction this yields a mean of shell means,
        not a mean over the orbitals of each atom.

        Parameters
        ----------
        x : Tensor
            Orbital-resolved tensor.
        dim : int | (int, int)
            Dimension to reduce over, defaults to -1.
        reduce : str
            Reduction method, defaults to "sum".
        extra : bool
            Tensor to reduce contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Atom-resolved tensor.
        """
        return self.reduce_shell_to_atom(
            self.reduce_orbital_to_shell(
                x, dim=dim, reduce=reduce, extra=extra
            ),
            dim=dim,
            reduce=reduce,
            extra=extra,
        )

    def spread_atom_to_shell(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        extra: bool = False,
    ) -> Tensor:
        """
        Spread atom-resolved tensor to shell-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Atom-resolved tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Shell-resolved tensor.
        """
        return wrap_gather(x, dim, self.shells_to_atom, extra=extra)

    def spread_shell_to_orbital(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        extra: bool = False,
    ) -> Tensor:
        """
        Spread shell-resolved tensor to orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Shell-resolved tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Orbital-resolved tensor.
        """
        return wrap_gather(x, dim, self.orbitals_to_shell, extra=extra)

    def spread_shell_to_orbital_cart(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        extra: bool = False,
    ) -> Tensor:
        """
        Spread shell-resolved tensor to Cartesian orbital-resolved tensor,
        i.e., using :attr:`orbitals_to_shell_cart` as index map.

        Parameters
        ----------
        x : Tensor
            Shell-resolved tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Cartesian orbital-resolved tensor.
        """
        return wrap_gather(x, dim, self.orbitals_to_shell_cart, extra=extra)

    def spread_atom_to_orbital(
        self,
        x: Tensor,
        dim: int | tuple[int, int] = -1,
        extra: bool = False,
    ) -> Tensor:
        """
        Spread atom-resolved tensor to orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Atom-resolved tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Orbital-resolved tensor.
        """
        return self.spread_shell_to_orbital(
            self.spread_atom_to_shell(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def spread_atom_to_orbital_cart(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread atom-resolved tensor to Cartesian orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Atom-resolved tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Cartesian orbital-resolved tensor.
        """
        return self.spread_shell_to_orbital_cart(
            self.spread_atom_to_shell(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def spread_uspecies_to_atom(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique species tensor to atom-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique species tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Atom-resolved tensor.
        """
        return wrap_gather(x, dim, self.atom_to_unique, extra=extra)

    def spread_uspecies_to_shell(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique species tensor to shell-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique species tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Shell-resolved tensor.
        """
        return self.spread_atom_to_shell(
            self.spread_uspecies_to_atom(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def spread_uspecies_to_orbital(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique species tensor to orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique species tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Orbital-resolved tensor.
        """
        return self.spread_atom_to_orbital(
            self.spread_uspecies_to_atom(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def spread_uspecies_to_orbital_cart(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique species tensor to Cartesian orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique species tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Cartesian orbital-resolved tensor.
        """
        return self.spread_atom_to_orbital_cart(
            self.spread_uspecies_to_atom(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def spread_ushell_to_shell(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique shell tensor to shell-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique shell tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Shell-resolved tensor.
        """
        return wrap_gather(x, dim, self.shells_to_ushell, extra=extra)

    def spread_ushell_to_orbital(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique shell tensor to orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique shell tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Orbital-resolved tensor.
        """
        return self.spread_shell_to_orbital(
            self.spread_ushell_to_shell(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def spread_ushell_to_orbital_cart(
        self, x: Tensor, dim: int | tuple[int, int] = -1, extra: bool = False
    ) -> Tensor:
        """
        Spread unique shell tensor to Cartesian orbital-resolved tensor.

        Parameters
        ----------
        x : Tensor
            Unique shell tensor.
        dim : int | (int, int)
            Dimension to spread over, defaults to -1.
        extra : bool
            Tensor to spread contains an extra dimension of arbitrary size.
            Defaults to ``False``.

        Returns
        -------
        Tensor
            Cartesian orbital-resolved tensor.
        """
        return self.spread_shell_to_orbital_cart(
            self.spread_ushell_to_shell(x, dim=dim, extra=extra),
            dim=dim,
            extra=extra,
        )

    def cull(self, conv: Tensor, slicers: Slicers) -> None:
        """
        Remove systems from the batched index helper.

        On the first call, the current (uncut) tensors are saved in
        :attr:`store` so that :meth:`restore` can undo the culling. The
        unique-species tensors (``unique_angular``, ``ushells_to_unique``,
        ``ushells_per_unique``) are left untouched.

        Parameters
        ----------
        conv : Tensor
            Mask over the batch; systems where ``conv`` is set are removed
            (the code selects ``~conv``). Presumably a boolean mask of
            converged systems -- confirm with callers.
        slicers : Slicers
            Per-resolution slices (keys ``"atom"``, ``"shell"``,
            ``"orbital"``) applied after the batch selection.

        Raises
        ------
        RuntimeError
            If the index helper is not in batch mode.
        """
        if self.batch_mode == 0:
            raise RuntimeError("Culling only possible in batch mode.")

        if self.store is None:
            self.store = IndexHelperStore(
                unique_angular=self.unique_angular,
                angular=self.angular,
                atom_to_unique=self.atom_to_unique,
                ushells_to_unique=self.ushells_to_unique,
                shells_to_ushell=self.shells_to_ushell,
                shells_per_atom=self.shells_per_atom,
                shell_index=self.shell_index,
                shells_to_atom=self.shells_to_atom,
                orbitals_per_shell=self.orbitals_per_shell,
                orbital_index=self.orbital_index,
                orbitals_to_shell=self.orbitals_to_shell,
            )

        at = tuple([~conv, *slicers["atom"]])
        sh = tuple([~conv, *slicers["shell"]])
        orb = tuple([~conv, *slicers["orbital"]])

        self.angular = self.angular[sh]
        self.atom_to_unique = self.atom_to_unique[at]
        self.shell_index = self.shell_index[at]
        self.shells_per_atom = self.shells_per_atom[at]
        self.shells_to_ushell = self.shells_to_ushell[sh]
        self.shells_to_atom = self.shells_to_atom[sh]
        self.orbitals_per_shell = self.orbitals_per_shell[sh]
        self.orbital_index = self.orbital_index[sh]
        self.orbitals_to_shell = self.orbitals_to_shell[orb]

    def restore(self) -> None:
        """
        Restore the original index helper after culling.

        Raises
        ------
        RuntimeError
            If :meth:`cull` has not stored anything yet.
        """
        if self.store is None:
            raise RuntimeError("Nothing to restore. Store is empty.")

        self.angular = self.store.angular
        self.atom_to_unique = self.store.atom_to_unique
        self.shells_to_ushell = self.store.shells_to_ushell
        self.shells_per_atom = self.store.shells_per_atom
        self.shell_index = self.store.shell_index
        self.shells_to_atom = self.store.shells_to_atom
        self.orbitals_per_shell = self.store.orbitals_per_shell
        self.orbital_index = self.store.orbital_index
        self.orbitals_to_shell = self.store.orbitals_to_shell

    # The following derived maps were previously memoized with
    # `dependent_memoize`; this is disabled (see the note in `__init__`).

    @property
    def orbitals_to_atom(self) -> Tensor:
        """Atom index of each orbital (``shells_to_atom`` spread to orbitals)."""
        return self._orbitals_to_atom()

    # @dependent_memoize(lambda self: self.shells_to_atom)
    def _orbitals_to_atom(self) -> Tensor:
        return self.spread_shell_to_orbital(self.shells_to_atom)

    @property
    def orbitals_per_shell_cart(self) -> Tensor:
        """
        Number of Cartesian orbitals of each shell, ``(l+1)(l+2)/2``, and
        zero for shells with negative (padding) angular momentum.
        """
        return self._orbitals_per_shell_cart()

    # @dependent_memoize(lambda self: self.angular)
    def _orbitals_per_shell_cart(self) -> Tensor:
        l = self.angular
        ls = torch.div((l + 1) * (l + 2), 2, rounding_mode="floor")
        return torch.where(l >= 0, ls, torch.tensor(0, device=self.device))

    @property
    def orbital_index_cart(self) -> Tensor:
        """
        Offset of the first Cartesian orbital of each shell (exclusive
        cumulative sum). Unlike :attr:`orbital_index`, offsets of shells
        without orbitals are not replaced by ``PAD``.
        """
        return self._orbital_index_cart()

    # @dependent_memoize(lambda self: self.orbitals_per_shell_cart)
    def _orbital_index_cart(self) -> Tensor:
        orb_per_shell = self.orbitals_per_shell_cart
        return torch.cumsum(orb_per_shell, -1) - orb_per_shell

    @property
    def orbitals_to_shell_cart(self) -> Tensor:
        """Shell index of each Cartesian orbital (padded with ``PAD``)."""
        return self._orbitals_to_shell_cart()

    # @dependent_memoize(
    #     lambda self: self.orbital_index_cart,
    #     lambda self: self.orbitals_per_shell_cart,
    # )
    def _orbitals_to_shell_cart(self) -> Tensor:
        orbital_index = self.orbital_index_cart
        orbitals_per_shell = self.orbitals_per_shell_cart

        if self.batch_mode > 0:
            orbitals_to_shell = pack(
                [
                    _fill(
                        orbital_index[_batch, :], orbitals_per_shell[_batch, :]
                    )
                    for _batch in range(self.angular.shape[0])
                ],
                value=PAD,
            )
        else:
            orbitals_to_shell = _fill(orbital_index, orbitals_per_shell)

        return orbitals_to_shell

    @property
    def orbitals_to_atom_cart(self) -> Tensor:
        """Atom index of each Cartesian orbital."""
        return self._orbitals_to_atom_cart()

    # @dependent_memoize(lambda self: self.shells_to_atom)
    def _orbitals_to_atom_cart(self) -> Tensor:
        return self.spread_shell_to_orbital_cart(self.shells_to_atom)

    # def clear_cache(self) -> None:
    #     """Clear the cross-instance caches of all memoized methods."""
    #     if hasattr(self._orbitals_per_shell_cart, "clear_cache"):
    #         self._orbitals_per_shell_cart.clear_cache()
    #     if hasattr(self._orbital_index_cart, "clear_cache"):
    #         self._orbital_index_cart.clear_cache()
    #     if hasattr(self._orbitals_to_shell_cart, "clear_cache"):
    #         self._orbitals_to_shell_cart.clear_cache()
    #     if hasattr(self._orbitals_to_atom_cart, "clear_cache"):
    #         self._orbitals_to_atom_cart.clear_cache()
    #     if hasattr(self._orbitals_to_atom, "clear_cache"):
    #         self._orbitals_to_atom.clear_cache()

    def get_shell_indices(self, atom_idx: int) -> Tensor:
        """
        Get shell indices belonging to given atom.

        For a batched (2D) ``shells_to_atom``, ``nonzero(...)[0]`` yields
        batch-row indices, so this is presumably meant for a single system.

        Parameters
        ----------
        atom_idx : int
            Index of given atom.

        Returns
        -------
        Tensor
            Index list of shells belonging to given atom.
        """
        return (self.shells_to_atom == atom_idx).nonzero(as_tuple=True)[0]

    def get_orbital_indices(self, shell_idx: int) -> Tensor:
        """
        Get orbital indices belonging to given shell.

        For a batched (2D) ``orbitals_to_shell``, ``nonzero(...)[0]`` yields
        batch-row indices, so this is presumably meant for a single system.

        Parameters
        ----------
        shell_idx : int
            Index of given shell.

        Returns
        -------
        Tensor
            Index list of orbitals belonging to given shell.
        """
        return (self.orbitals_to_shell == shell_idx).nonzero(as_tuple=True)[0]

    def orbital_atom_mapping(self, idx: int) -> Tensor:
        """
        Mapping of atom index to orbital index, i.e., return indices of
        orbitals belonging to given atom. The orbital order is given by
        :meth:`.IndexHelper.orbitals_to_shell`.

        Parameters
        ----------
        idx : int
            Index of target atom.

        Returns
        -------
        Tensor
            1d-Tensor (``torch.long``, on CPU) containing the indices of the
            orbitals. Empty if the atom has no orbitals.

        Raises
        ------
        NotImplementedError
            In batch mode.
        """
        # TODO: support batched mode
        if self.batch_mode > 0:
            raise NotImplementedError(
                "Currently, `orbital_atom_mapping` only supports a single sample."
            )

        # explicit dtype: an empty list would otherwise yield a float tensor,
        # which cannot be used for indexing
        return torch.tensor(
            [
                oidx
                for sidx in self.get_shell_indices(idx)
                for oidx in self.get_orbital_indices(t2int(sidx)).tolist()
            ],
            dtype=torch.long,
        )

    @property
    def orbitals_per_atom(self) -> Tensor:
        """
        Atom index of each orbital.

        Despite its name, this property does not count orbitals; it maps
        every orbital to its atom, i.e., ``shells_to_atom[orbitals_to_shell]``.
        In batch mode, the per-system results are packed with ``PAD``.

        Returns
        -------
        Tensor
            Atom indices for each orbital.
        """
        try:
            # batch mode (`.mT` raises a RuntimeError for 1D tensors)
            pad = torch.nn.utils.rnn.pad_sequence(
                [self.shells_to_atom.mT, self.orbitals_to_shell.T],
                padding_value=PAD,
            )
            pad = einsum("ijk->kji", pad)  # [bs, 2, norb_max]
        except RuntimeError:
            # single mode
            pad = torch.nn.utils.rnn.pad_sequence(
                [self.shells_to_atom, self.orbitals_to_shell], padding_value=PAD
            ).T  # [2, norb_max]

        if len(pad.shape) > 2:
            # gathering over subentries to avoid padded value (PAD) in index
            # tensor
            return pack(
                [torch.gather(a[b != PAD], 0, b[b != PAD]) for a, b in pad],
                value=PAD,
            )
            # TODO: masked_tensor could be a vectorised solution (though only
            # available in pytorch 1.13); alternatively write all values into
            # an extra column

        return torch.gather(pad[0], 0, pad[1])

    @property
    def nat(self) -> int:
        """Number of atoms (last dimension of ``atom_to_unique``)."""
        return self.atom_to_unique.shape[-1]

    @property
    def nsh(self) -> int:
        """Largest total number of shells of any system."""
        return int(self.shells_per_atom.sum(-1).max())

    @property
    def nao(self) -> int:
        """Largest total number of orbitals of any system."""
        return int(self.orbitals_per_shell.sum(-1).max())

    @property
    def nbatch(self) -> int | None:
        """
        Number of systems in the batch, or ``None`` for a single system.

        Previously this returned ``atom_to_unique.ndim``, which is the rank
        of the tensor (always 2 in batch mode), not the batch size.
        """
        return self.atom_to_unique.shape[0] if self.batch_mode > 0 else None

    @property
    def allowed_dtypes(self) -> tuple[torch.dtype, ...]:
        """
        Specification of dtypes that the TensorLike object can take.

        Returns
        -------
        tuple[torch.dtype, ...]
            Collection of allowed dtypes the TensorLike object can take.
        """
        return (torch.int16, torch.int32, torch.int64, torch.long)

    def __str__(self) -> str:  # pragma: no cover
        return (
            f"IndexHelper(\n"
            f"  unique_angular={self.unique_angular},\n"
            f"  angular={self.angular},\n"
            f"  atom_to_unique={self.atom_to_unique},\n"
            f"  ushells_to_unique={self.ushells_to_unique},\n"
            f"  ushells_per_unique={self.ushells_per_unique},\n"
            f"  shells_to_ushell={self.shells_to_ushell},\n"
            f"  shells_per_atom={self.shells_per_atom},\n"
            f"  shell_index={self.shell_index},\n"
            f"  shells_to_atom={self.shells_to_atom},\n"
            f"  orbitals_per_shell={self.orbitals_per_shell},\n"
            f"  orbital_index={self.orbital_index},\n"
            f"  orbitals_to_shell={self.orbitals_to_shell},\n"
            f"  batch_mode={self.batch_mode},\n"
            f"  store={self.store},\n"
            f"  device={self.device},\n"
            f"  dtype={self.dtype}\n"
            ")"
        )

    def __repr__(self) -> str:  # pragma: no cover
        return str(self)
class IndexHelperGFN1(IndexHelper):
    """
    Index helper for GFN1 basis set.
    """

    @override
    @classmethod
    def from_numbers(
        cls,
        numbers: Tensor,
        batch_mode: int | None = None,
        move_to_numbers_device: bool = True,
    ) -> IndexHelper:
        """
        Construct an index helper instance from atomic numbers and their
        angular momenta. The latter are collected from the GFN1
        parametrization.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        batch_mode : int
            Whether multiple systems or a single one are handled:

            - 0: Single system
            - 1: Multiple systems with padding
            - 2: Multiple systems with no padding (conformer ensemble)
        move_to_numbers_device : bool
            Move the resulting tensors to the device of the ``numbers``
            tensor. This should be switched off for GPU calculations that use
            `libcint` for integrals as the :class:`.IndexHelper` has to be on
            the CPU for this step.

        Returns
        -------
        IndexHelper
            Instance of index helper for given basis set.
        """
        # pylint: disable=import-outside-toplevel
        from dxtb import GFN1_XTB

        return super().from_numbers(
            numbers,
            GFN1_XTB,
            batch_mode=batch_mode,
            move_to_numbers_device=move_to_numbers_device,
        )


class IndexHelperGFN2(IndexHelper):
    """
    Index helper for GFN2 basis set.
    """

    @override
    @classmethod
    def from_numbers(
        cls,
        numbers: Tensor,
        batch_mode: int | None = None,
        move_to_numbers_device: bool = True,
    ) -> IndexHelper:
        """
        Construct an index helper instance from atomic numbers and their
        angular momenta. The latter are collected from the GFN2
        parametrization.

        Parameters
        ----------
        numbers : Tensor
            Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
        batch_mode : int
            Whether multiple systems or a single one are handled:

            - 0: Single system
            - 1: Multiple systems with padding
            - 2: Multiple systems with no padding (conformer ensemble)
        move_to_numbers_device : bool
            Move the resulting tensors to the device of the ``numbers``
            tensor. This should be switched off for GPU calculations that use
            `libcint` for integrals as the :class:`.IndexHelper` has to be on
            the CPU for this step.

        Returns
        -------
        IndexHelper
            Instance of index helper for given basis set.
        """
        # pylint: disable=import-outside-toplevel
        from dxtb import GFN2_XTB

        return super().from_numbers(
            numbers,
            GFN2_XTB,
            batch_mode=batch_mode,
            move_to_numbers_device=move_to_numbers_device,
        )