# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Wavefunction: Filling
=====================
Handle the occupation of the orbitals with electrons.
Parts of the Fermi smearing are taken from https://github.com/tbmalt/tbmalt
"""
from __future__ import annotations
import torch
from tad_mctc.convert import any_to_tensor
from dxtb._src.typing import DD, Tensor
__all__ = [
"get_alpha_beta_occupation",
"get_aufbau_occupation",
"get_fermi_energy",
"get_fermi_occupation",
]
[docs]
def get_alpha_beta_occupation(
    nel: Tensor, uhf: Tensor | float | int | list[int] | None = None
) -> Tensor:
    """
    Split a total electron count into alpha and beta electron counts.

    Parameters
    ----------
    nel : Tensor
        Total number of electrons.
    uhf : Tensor | float | int | list[int] | None
        Number of unpaired electrons. Plain Python numbers and lists are
        converted to a tensor with the dtype and device of ``nel``. If
        ``None``, the spin is derived from the parity of the (rounded)
        number of electrons.

    Returns
    -------
    Tensor
        Alpha electrons followed by beta electrons, concatenated along the
        last dimension. For a 0-d ``nel`` this is a tensor of shape ``(2,)``
        with alpha at index 0 and beta at index 1.

    Raises
    ------
    RuntimeError
        Shapes of ``uhf`` and ``nel`` differ.
    ValueError
        More unpaired electrons than electrons were requested, or the parity
        of ``uhf`` does not match the parity of the number of electrons.

    Note
    ----
    The number of electrons is rounded to integers via `torch.round` for
    the parity checks, i.e., non-integer electrons are not supported.

    NOTE(review): the result is built with `torch.cat` along the last
    dimension, so a 1-D ``nel`` with ``b > 1`` entries yields a tensor of
    length ``2 * b`` (all alpha values, then all beta values). Confirm
    that batched callers pass a trailing singleton dimension.
    """
    if uhf is None:
        # No spin given: start from zero, the parity of `nel` decides below.
        uhf = torch.zeros_like(nel)
    else:
        if isinstance(uhf, (list, int, float)):
            uhf = torch.tensor(uhf, device=nel.device, dtype=nel.dtype)
        else:
            uhf = uhf.type(nel.dtype).to(nel.device)

        if uhf.shape != nel.shape:
            raise RuntimeError(
                f"Shape mismatch for unpaired electrons ({uhf.shape}) and "
                f"number of electrons ({nel.shape})."
            )

        nel_rounded = nel.round()

        if torch.any(uhf > nel_rounded):
            raise ValueError(
                f"Number of unpaired electrons ({uhf}) larger than "
                f"number of electrons ({nel})."
            )

        # Odd spin requires an odd electron count and vice versa.
        parity_mismatch = torch.remainder(uhf, 2) != torch.remainder(
            nel_rounded, 2
        )
        if torch.any(parity_mismatch):
            raise ValueError(
                f"Odd (even) number of unpaired electrons ({uhf}) but even "
                f"(odd) number of electrons ({nel}) given."
            )

    # `torch.cat` cannot handle 0-d tensors, hence promote scalars to 1-d.
    nel = torch.atleast_1d(nel)
    uhf = torch.atleast_1d(uhf)
    assert isinstance(uhf, Tensor)

    # Keep the requested spin where its parity fits; otherwise fall back to
    # the parity of the electron count (0 or 1 unpaired electron).
    nel_parity = torch.remainder(nel.round(), 2)
    parity_matches = torch.remainder(uhf, 2) == nel_parity
    unpaired = torch.where(parity_matches, uhf, nel_parity)

    # Never assign more unpaired electrons than electrons available.
    unpaired = torch.minimum(unpaired, nel)

    n_beta = (nel - unpaired) / 2.0
    n_alpha = n_beta + unpaired

    return torch.cat([n_alpha, n_beta], dim=-1)
[docs]
def get_aufbau_occupation(norb: Tensor, nel: Tensor) -> Tensor:
    """
    Fill orbitals from the bottom up (aufbau principle).

    The number of electrons is a real number and may be fractional; the
    fractional remainder ends up in the highest (partially) occupied orbital.

    Parameters
    ----------
    norb : Tensor
        Number of available orbitals. Only the largest entry is used; it
        determines the length of the orbital dimension of the result.
    nel : Tensor
        Number of electrons.

    Returns
    -------
    Tensor
        Occupation numbers. The orbital dimension is appended as last
        dimension; for a 0-d ``nel`` a 1-d tensor is returned.

    Examples
    --------
    >>> get_aufbau_occupation(torch.tensor(5), torch.tensor(1.))
    tensor([1., 0., 0., 0., 0.])
    >>> get_aufbau_occupation(torch.tensor([8, 8, 5]), torch.tensor([2., 3., 1.]))
    tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
            [1., 1., 1., 0., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0., 0., 0., 0.]])
    >>> nel, norb = torch.tensor([2.0, 3.5, 1.5]), torch.tensor([4, 4, 2])
    >>> occ = get_aufbau_occupation(norb, nel)
    >>> occ
    tensor([[1.0000, 1.0000, 0.0000, 0.0000],
            [1.0000, 1.0000, 1.0000, 0.5000],
            [1.0000, 0.5000, 0.0000, 0.0000]])
    >>> all(nel == occ.sum(-1))
    True
    """
    # One-based orbital indices 1, 2, ..., max(norb).
    n_max = torch.max(norb).item()
    orbital_index = torch.arange(1, 1 + n_max, device=nel.device)

    # Electrons left after filling orbital `i`. Positive values mark fully
    # occupied orbitals, negative values empty ones. Rounding up maps the
    # partially filled orbital (difference between -1 and 0) onto zero.
    remaining = torch.ceil(nel.unsqueeze(-1) - orbital_index.unsqueeze(-2))

    # Value used by the heaviside function where `remaining` is exactly zero:
    # the negative-divisor remainder lies in (-1, 0], so adding one gives the
    # fractional occupation, or 1 for an integer number of electrons.
    boundary_occupation = torch.remainder(nel, -1).unsqueeze(-1) + 1

    occupation = torch.heaviside(remaining, boundary_occupation)

    if nel.dim() == 0:
        return occupation.flatten()
    return occupation
[docs]
def get_fermi_energy(
    nel: Tensor, emo: Tensor, mask: Tensor | None = None
) -> tuple[Tensor, Tensor]:
    """
    Compute the Fermi energy as midpoint between HOMO and LUMO energy.

    The orbital energies `emo` and the `mask` must already have the correct
    shape for using alpha/beta electron channels. Spreading to the channels
    can be done with `x.unsqueeze(-2).expand([*nel.shape, -1])`.

    Parameters
    ----------
    nel : Tensor
        Number of electrons.
    emo : Tensor
        Orbital energies (orbitals along the last dimension).
    mask : Tensor | None, optional
        Orbital mask (zero marks padding). Used to avoid reading padding as
        LUMO for systems without a LUMO due to the minimal basis.

    Returns
    -------
    tuple[Tensor, Tensor]
        Fermi energy and index of the HOMO (with a trailing singleton
        dimension). Entries with zero electrons get a Fermi energy of 0.0.
    """
    # Small negative tolerance so that round-off in `nel` does not shift the
    # HOMO by one orbital.
    tolerance = torch.finfo(emo.dtype).resolution * 5

    # Running orbital count (1, 2, ..., n) minus the number of electrons:
    # negative entries belong to orbitals below the HOMO.
    ones = torch.ones_like(emo)
    excess = ones.cumsum(-1) - nel.unsqueeze(-1)
    reached = excess >= -tolerance

    # `argmax` returns the first maximal entry, i.e., the first orbital for
    # which the count reaches the number of electrons. The trailing
    # dimension is the one along which the (HOMO, LUMO) pair is joined.
    homo = reached.type(torch.long).argmax(dim=-1).unsqueeze(-1)

    # Some atoms (e.g., He) have no LUMO in the valence basis: the HOMO is
    # already the last orbital and HOMO + 1 would be out of range.
    last_orbital = ones.sum(-1, keepdim=True) - 1
    no_lumo = last_orbital <= homo

    homo_homo = torch.cat((homo, homo), -1)  # Fermi energy = HOMO energy
    homo_lumo = torch.cat((homo, homo + 1), -1)
    gap = torch.where(no_lumo, homo_homo, homo_lumo)

    # Midpoint of the selected pair; channels without electrons (e.g., an
    # empty beta channel) are set to zero.
    midpoint = torch.gather(emo, -1, gap).mean(-1)
    no_electrons = torch.tensor(0.0, device=emo.device, dtype=emo.dtype)
    e_fermi = torch.where(nel != 0, midpoint, no_electrons)

    # NOTE:
    # In batched calculations, a missing LUMO is replaced by padding, which
    # is not detected by `no_lumo`. The padded LUMO energy is 0.0 and the
    # midpoint is exactly half of the HOMO energy. The mask is gathered like
    # the energies and reduced with `prod(-1)`; wherever the pair touches
    # padding, multiplying by two restores the HOMO energy.
    if mask is not None:
        binary = torch.where(mask == 0, mask, torch.ones_like(mask))
        pair_is_real = torch.gather(binary, -1, gap).prod(-1)
        e_fermi = torch.where(pair_is_real != 0, e_fermi, e_fermi * 2.0)

    return e_fermi, homo
[docs]
def get_fermi_occupation(
    nel: Tensor,
    emo: Tensor,
    kt: Tensor,
    mask: Tensor | None = None,
    thr: Tensor | float | int | None = None,
    maxiter: int = 200,
) -> Tensor:
    """
    Set occupation numbers according to the Fermi distribution.

    Starting from the HOMO/LUMO midpoint, the Fermi energy is refined with
    Newton steps until the summed occupation of each channel matches its
    electron count (taken as HOMO index + 1).

    The orbital energies `emo` must already have the correct shape for using
    alpha/beta electron channels. Spreading to the channels can be done with
    `emo.unsqueeze(-2).expand([*nel.shape, -1])`.

    Parameters
    ----------
    nel : Tensor
        Number of electrons.
    emo : Tensor
        Orbital energies.
    kt : Tensor
        Electronic temperature in atomic units.
    mask : Tensor | None, optional
        Mask for Fermi energy. Just passed through to `get_fermi_energy`.
    thr : Tensor | float | int | None, optional
        Threshold for converging the Fermi energy. If ``None``, the square
        root of the machine epsilon of ``emo.dtype`` is used.
    maxiter : int, optional
        Maximum number of iterations for converging Fermi energy.
        Defaults to 200.

    Returns
    -------
    Tensor
        Occupation numbers (at most one electron per orbital and channel).
        If the electron count summed over the last dimension of `nel`
        vanishes for any entry, a zero tensor shaped like `emo` is returned.

    Raises
    ------
    RuntimeError
        Fermi energy fails to converge within `maxiter` iterations.
    TypeError
        Electronic temperature is neither a `Tensor` nor ``None``.
    ValueError
        Electronic temperature is negative.
    """
    if kt is not None:
        # wrong type of kt
        if not isinstance(kt, Tensor):
            raise TypeError(
                "Electronic temperature must be `Tensor` or ``None``."
            )

        # negative etemp
        if torch.any(kt < 0.0):
            raise ValueError(
                f"Electronic Temperature must be positive or None ({kt})."
            )

    dd: DD = {"device": emo.device, "dtype": emo.dtype}
    eps = torch.tensor(torch.finfo(emo.dtype).eps, **dd)
    zero = torch.tensor(0.0, **dd)

    # no valence electrons: nothing to occupy
    n_total = nel.sum(-1)
    if (torch.abs(n_total) < eps).any():
        return torch.zeros_like(emo)

    if thr is None:
        thr = torch.tensor(torch.finfo(emo.dtype).eps, **dd) ** 0.5
    thresh = any_to_tensor(thr, **dd)

    e_fermi, homo = get_fermi_energy(nel, emo, mask=mask)

    # `emo` carries an extra dimension for the alpha/beta channels (e.g.
    # [b, 2, n]); `e_fermi` needs a trailing singleton dimension (e.g.
    # [b, 2, 1]) to broadcast against it.
    e_fermi = e_fermi.view([*nel.shape, -1])

    # channels without electrons (e.g., empty beta channel) are zeroed
    not_empty = nel.unsqueeze(-1) != 0
    emo = torch.where(not_empty, emo, zero)

    # Newton iterations on the Fermi energy
    for _ in range(maxiter):
        exponent = (emo - e_fermi) / kt

        # Large exponents are cut off to avoid overflow in `exp`; the
        # corresponding orbitals are treated as unoccupied.
        in_range = exponent < 50
        eterm = torch.exp(torch.where(in_range, exponent, zero))
        denom = eterm + 1.0

        # occupation per spin orbital (single occupation only) and its
        # derivative with respect to the Fermi energy
        fermi = torch.where(in_range, 1.0 / denom, zero)
        dfermi = torch.where(in_range, eterm / (kt * denom**2), eps)

        # missing electrons w.r.t. the target count `homo + 1`
        _nel = torch.sum(fermi, dim=-1, keepdim=True)
        deficit = homo - _nel + 1

        e_fermi += deficit / torch.sum(dfermi, dim=-1, keepdim=True)

        if torch.all(torch.abs(deficit) <= thresh):
            # make sure empty channels stay unoccupied
            return torch.where(not_empty, fermi, torch.zeros_like(fermi))

    raise RuntimeError("Fermi energy failed to converge.")