# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import sys
from argparse import Namespace
from pathlib import Path
import torch
from dxtb._src.constants import defaults, labels
from dxtb._src.typing import (
Any,
PathLike,
Self,
get_default_device,
get_default_dtype,
)
from .cache import ConfigCache
from .integral import ConfigIntegrals
from .scf import ConfigSCF
__all__ = ["Config"]
class Config:
    """
    Configuration of the calculation.

    Holds the general settings (input file, xTB method, excluded components,
    gradient flag, PyTorch device and dtype) and builds the three
    sub-configurations :class:`ConfigCache`, :class:`ConfigIntegrals` and
    :class:`ConfigSCF` from the keyword arguments given to the constructor.
    """

    file: PathLike | None
    """The input file or directory."""

    strict: bool = False
    """Strict mode for SCF configuration. Always throws errors if ``True``."""

    exclude: str | list[str]
    """The tight-binding components to exclude from the calculation."""

    method: int
    """The xTB method to use."""

    grad: bool
    """Whether to compute the gradient."""

    max_element: int
    """The maximum element number in the system."""

    # PyTorch

    anomaly: bool
    """Whether to run PyTorch in anomaly detection mode."""

    device: torch.device
    """The device to use for the calculation."""

    dtype: torch.dtype
    """The data type to use for the calculation."""

    # configs

    cache: ConfigCache
    """The cache configuration."""

    ints: ConfigIntegrals
    """The integral configuration."""

    scf: ConfigSCF
    """The SCF configuration."""

    def __init__(
        self,
        *,
        file: PathLike | None = None,
        strict: bool = defaults.STRICT,
        exclude: str | list[str] = defaults.EXCLUDE,
        method: str | int = defaults.METHOD,
        grad: bool = False,
        batch_mode: int = defaults.BATCH_MODE,
        # integrals
        int_cutoff: float = defaults.INTCUTOFF,
        int_driver: str | int = defaults.INTDRIVER,
        int_level: int = defaults.INTLEVEL,
        int_uplo: str = defaults.INTUPLO,
        # PyTorch
        anomaly: bool = False,
        device: torch.device = get_default_device(),
        dtype: torch.dtype = get_default_dtype(),
        # SCF
        maxiter: int = defaults.MAXITER,
        mixer: str | int = defaults.MIXER,
        mix_guess: bool = defaults.MIX_GUESS,
        damp: float = defaults.DAMP,
        damp_init: float = defaults.DAMP_INIT,
        damp_dynamic: bool = defaults.DAMP_DYNAMIC,
        damp_dynamic_factor: float = defaults.DAMP_DYNAMIC_FACTOR,
        damp_soft_start: bool = defaults.DAMP_SOFT_START,
        damp_generations: int = defaults.DAMP_GENERATIONS,
        damp_diagonal_offset: float = defaults.DAMP_DIAGONAL_OFFSET,
        guess: str | int = defaults.GUESS,
        scf_mode: str | int = defaults.SCF_MODE,
        scp_mode: str | int = defaults.SCP_MODE,
        x_atol: float = defaults.X_ATOL,
        x_atol_max: float = defaults.X_ATOL_MAX,
        f_atol: float = defaults.F_ATOL,
        force_convergence: bool = False,
        fermi_etemp: float = defaults.FERMI_ETEMP,
        fermi_maxiter: int = defaults.FERMI_MAXITER,
        fermi_thresh: float | int | None = defaults.FERMI_THRESH,
        fermi_partition: str | int = defaults.FERMI_PARTITION,
        # cache
        cache_enabled: bool = defaults.CACHE_ENABLED,
        cache_hcore: bool = defaults.CACHE_STORE_HCORE,
        cache_overlap: bool = defaults.CACHE_STORE_OVERLAP,
        cache_dipole: bool = defaults.CACHE_STORE_DIPOLE,
        cache_quadrupole: bool = defaults.CACHE_STORE_QUADRUPOLE,
        cache_charges: bool = defaults.CACHE_STORE_CHARGES,
        cache_coefficients: bool = defaults.CACHE_STORE_COEFFICIENTS,
        cache_density: bool = defaults.CACHE_STORE_DENSITY,
        cache_fock: bool = defaults.CACHE_STORE_FOCK,
        cache_iterations: bool = defaults.CACHE_STORE_ITERATIONS,
        cache_mo_energies: bool = defaults.CACHE_STORE_MO_ENERGIES,
        cache_occupation: bool = defaults.CACHE_STORE_OCCUPATIONS,
        cache_potential: bool = defaults.CACHE_STORE_POTENTIAL,
        # misc
        max_element: int = defaults.MAX_ELEMENT,
        skip_compat_checks: bool = False,
    ) -> None:
        """
        Set up the configuration. All arguments are keyword-only.

        The arguments fall into the following groups:

        - general: ``file``, ``strict``, ``exclude``, ``method``, ``grad``,
          ``batch_mode``, ``max_element`` (stored on this object)
        - PyTorch: ``anomaly``, ``device``, ``dtype`` (stored on this object;
          ``device`` and ``dtype`` are also forwarded to the SCF config)
        - integrals: ``int_*`` (forwarded to :class:`ConfigIntegrals`)
        - SCF: everything from ``maxiter`` to ``fermi_partition``, plus
          ``strict``, ``batch_mode`` and the resolved ``method`` (forwarded
          to :class:`ConfigSCF`)
        - cache: ``cache_*`` (forwarded to :class:`ConfigCache`)

        Note that the defaults of ``device`` and ``dtype`` are evaluated once,
        when the class is defined, not on every call.

        Parameters
        ----------
        method : str | int
            The xTB method, either as a label constant
            (``labels.GFN1_XTB``/``labels.GFN2_XTB``) or as a string that
            matches, case-insensitively, one of the names in
            ``labels.GFN1_XTB_STRS``/``labels.GFN2_XTB_STRS``.
        skip_compat_checks : bool
            If ``False`` (default), check that the method and the integral
            driver are compatible.

        Raises
        ------
        ValueError
            If ``method`` is a string or integer that is not a known method.
        TypeError
            If ``method`` is neither a string nor an integer.
        RuntimeError
            If the compatibility checks are active and GFN2-xTB is requested
            with an integral driver other than ``libcint``.
        """
        self.file = file
        self.strict = strict
        self.exclude = exclude
        self.grad = grad
        self.anomaly = anomaly
        self.device = device
        self.dtype = dtype

        # Write the private attribute directly: the ``batch_mode`` property
        # setter also writes to ``self.scf``, which does not exist yet. The
        # SCF config receives the same value through its constructor below.
        self._batch_mode = batch_mode

        self.max_element = max_element

        # Normalize the method to its integer label.
        if isinstance(method, str):
            method_name = method.casefold()
            if method_name in labels.GFN1_XTB_STRS:
                self.method = labels.GFN1_XTB
            elif method_name in labels.GFN2_XTB_STRS:
                self.method = labels.GFN2_XTB
            else:
                raise ValueError(f"Unknown xtb method '{method}'.")
        elif isinstance(method, int):
            if method not in (labels.GFN1_XTB, labels.GFN2_XTB):
                raise ValueError(f"Unknown xtb method '{method}'.")
            self.method = method
        else:
            raise TypeError(
                "The method must be of type 'int' or 'str', but "
                f"'{type(method)}' was given."
            )

        self.cache = ConfigCache(
            enabled=cache_enabled,
            #
            hcore=cache_hcore,
            overlap=cache_overlap,
            dipole=cache_dipole,
            quadrupole=cache_quadrupole,
            #
            charges=cache_charges,
            coefficients=cache_coefficients,
            density=cache_density,
            fock=cache_fock,
            iterations=cache_iterations,
            mo_energies=cache_mo_energies,
            occupation=cache_occupation,
            potential=cache_potential,
        )

        self.ints = ConfigIntegrals(
            level=int_level,
            cutoff=int_cutoff,
            driver=int_driver,
            uplo=int_uplo,
        )

        self.scf = ConfigSCF(
            strict=strict,
            method=self.method,
            guess=guess,
            maxiter=maxiter,
            mixer=mixer,
            mix_guess=mix_guess,
            damp=damp,
            damp_init=damp_init,
            damp_dynamic=damp_dynamic,
            damp_dynamic_factor=damp_dynamic_factor,
            damp_soft_start=damp_soft_start,
            damp_generations=damp_generations,
            damp_diagonal_offset=damp_diagonal_offset,
            scf_mode=scf_mode,
            scp_mode=scp_mode,
            x_atol=x_atol,
            x_atol_max=x_atol_max,
            f_atol=f_atol,
            force_convergence=force_convergence,
            batch_mode=batch_mode,
            # SCF: Fermi
            fermi_etemp=fermi_etemp,
            fermi_maxiter=fermi_maxiter,
            fermi_thresh=fermi_thresh,
            fermi_partition=fermi_partition,
            # SCF: PyTorch
            device=device,
            dtype=dtype,
        )

        # compatibility checks (only need to be skipped for some tests)
        if skip_compat_checks is False:
            if (
                self.method == labels.GFN2_XTB
                and self.ints.driver != labels.INTDRIVER_LIBCINT
            ):
                raise RuntimeError(
                    "Multipole integrals not available in PyTorch integral "
                    "drivers. Use `libcint` as backend."
                )

    @classmethod
    def from_args(cls, args: Namespace) -> Self:
        """
        Create a configuration from command-line arguments.

        Only the attributes of ``args`` listed below are read. Constructor
        options without a counterpart here (for example ``batch_mode``,
        ``max_element``, ``cache_charges`` or ``cache_iterations``) keep
        their defaults.

        Parameters
        ----------
        args : Namespace
            The parsed command-line arguments.

        Returns
        -------
        Self
            The configuration object.
        """
        return cls(
            # general
            file=args.file,
            strict=args.strict,
            exclude=args.exclude,
            method=args.method,
            grad=args.grad,
            # integrals
            int_cutoff=args.int_cutoff,
            int_driver=args.int_driver,
            int_level=args.int_level,
            int_uplo=args.int_uplo,
            # PyTorch
            anomaly=args.detect_anomaly,
            device=args.device,
            dtype=args.dtype,
            # SCF
            maxiter=args.maxiter,
            mixer=args.mixer,
            damp=args.damp,
            guess=args.guess,
            scf_mode=args.scf_mode,
            scp_mode=args.scp_mode,
            x_atol=args.xtol,
            f_atol=args.ftol,
            force_convergence=args.force_convergence,
            # SCF: Fermi
            fermi_etemp=args.fermi_etemp,
            fermi_maxiter=args.fermi_maxiter,
            fermi_thresh=args.fermi_thresh,
            fermi_partition=args.fermi_partition,
            # Cache
            cache_enabled=args.cache_enabled,
            cache_hcore=args.cache_hcore,
            cache_overlap=args.cache_overlap,
            cache_dipole=args.cache_dipole,
            cache_quadrupole=args.cache_quadrupole,
            cache_coefficients=args.cache_coefficients,
            cache_density=args.cache_density,
            cache_fock=args.cache_fock,
            cache_mo_energies=args.cache_mo_energies,
            cache_occupation=args.cache_occupation,
            cache_potential=args.cache_potential,
        )

    @classmethod
    def from_json(cls, path: PathLike) -> Self:
        """
        Create a configuration from a JSON file.

        The decoded JSON content is handed to :meth:`from_dict`, i.e., its
        keys must be keyword arguments of the constructor.

        Parameters
        ----------
        path : PathLike
            The path to the JSON file.

        Returns
        -------
        Self
            The configuration object.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"File '{path}' does not exist.")

        # pylint: disable=import-outside-toplevel
        import json

        with open(path, encoding="utf-8") as json_file:
            cfg = json.load(json_file)

        return cls.from_dict(cfg)

    @classmethod
    def from_dict(cls, cfg: dict[str, Any]) -> Self:
        """
        Create a configuration from a dictionary.

        The dictionary is unpacked as keyword arguments of the constructor
        without further validation.

        Parameters
        ----------
        cfg : dict[str, Any]
            The configuration dictionary.

        Returns
        -------
        Self
            The configuration object.
        """
        # TODO: More sophisticated validation
        return cls(**cfg)

    @property
    def batch_mode(self) -> int:
        """
        Whether multiple systems or a single one are handled.

        The following batch modes are available:

        - 0: Single system
        - 1: Multiple systems with padding
        - 2: Multiple systems with no padding (conformer ensemble)

        Returns
        -------
        int
            The batch mode.
        """
        return self._batch_mode

    @batch_mode.setter
    def batch_mode(self, value: int) -> None:
        """
        Set the batch mode, here and in the SCF configuration.

        Parameters
        ----------
        value : int
            The batch mode.

        Raises
        ------
        ValueError
            If the batch mode is invalid.
        """
        if value not in (0, 1, 2):
            raise ValueError(
                f"Invalid batch mode '{value}'. Must be one of [0, 1, 2]."
            )

        self._batch_mode = value
        # keep the SCF configuration in sync
        self.scf.batch_mode = value

    def info(self) -> dict[str, dict[str, Any]]:
        """
        Return a dictionary with the configuration information.

        The result contains a "Calculation Configuration" section built here,
        merged with the sections returned by the SCF configuration's
        ``info()``.

        Returns
        -------
        dict[str, dict[str, Any]]
            The configuration information.
        """
        return {
            "Calculation Configuration": {
                "Program Call": " ".join(sys.argv),
                "Input File(s)": self.file,
                "Method": labels.GFN_XTB_MAP[self.method],
                # report ``False`` instead of an empty string/list
                "Excluded": self.exclude if self.exclude else False,
                "Gradient": self.grad,
                "Integral driver": labels.INTDRIVER_MAP[self.ints.driver],
                "FP accuracy": str(self.dtype),
                "Device": str(self.device),
            },
            **self.scf.info(),
        }

    def to_json(self, path: PathLike | None = None) -> str:
        """
        Serialize the configuration information to a JSON-formatted string.

        The serialized content is the dictionary returned by :meth:`info`.
        Values that the ``json`` module cannot encode directly
        (``torch.device``, ``torch.dtype`` and ``pathlib.Path``) are
        converted to strings.

        Parameters
        ----------
        path : PathLike | None, optional
            If given, the JSON string is additionally written to this file.
            An existing file at this location is removed first.

        Returns
        -------
        str
            A JSON-formatted string representing the configuration.
        """
        # pylint: disable=import-outside-toplevel
        import json

        def serialize(value: Any) -> Any:
            """Recursively convert ``value`` into JSON-encodable objects."""
            # ``Path`` must be handled here: ``file`` may be a path object,
            # which ``json.dumps`` rejects with a ``TypeError``.
            if isinstance(value, (torch.device, torch.dtype, Path)):
                return str(value)
            if isinstance(value, (list, tuple)):
                return [serialize(item) for item in value]
            if isinstance(value, dict):
                return {key: serialize(item) for key, item in value.items()}
            return value

        json_string = json.dumps(serialize(self.info()), indent=4)

        if path is not None:
            path = Path(path)
            if path.exists():
                path.unlink()

            with open(path, "w", encoding="utf-8") as json_file:
                json_file.write(json_string)

        return json_string

    def __str__(self) -> str:  # pragma: no cover
        """Return the class name with the "SCF Options" as key-value pairs."""
        info = self.info()["SCF Options"]
        info_str = ", ".join(f"{key}={value}" for key, value in info.items())
        return f"{self.__class__.__name__}({info_str})"

    def __repr__(self) -> str:  # pragma: no cover
        """Return the same representation as :meth:`__str__`."""
        return str(self)