Source code for dxtb._src.param.module.param

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Parametrization: Differentiable Parametrization
===============================================

This module provides a differentiable representation of the
extended tight-binding parametrization using PyTorch.

The :class:`.ParamModule` class automatically converts a Pydantic
model into a hierarchical :class:`~torch.nn.Module` tree.
"""

from __future__ import annotations

import torch
from torch import nn

from dxtb._src.constants.defaults import DEFAULT_BASIS_INT
from dxtb._src.typing import DD, Any, PathLike
from dxtb._src.utils import is_float, is_float_list, is_int_list, is_integer

from ..base import Param
from .types import NonNumericValue, ParameterModule
from .utils import ParamElementsPairsMixin

__all__ = ["NonNumericValue", "ParamModule", "ParameterModule"]


[docs] class ParamModule(nn.Module, ParamElementsPairsMixin): """ Automatically converts a Pydantic model into a hierarchical :class:`~torch.nn.Module` tree. All numeric values (or lists of numbers) are wrapped in :class:`ParameterModule` instances (with :attr:`requires_grad` set to False by default) so that they are registered as parameters. Nonnumeric values are wrapped in :class:`NonNumericValue`. """ def __init__( self, par: Param, device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> None: """ Parameters ---------- par : Param The validated Pydantic model. device : torch.device | None, optional Device of the tensors. Defaults to ``None``. dtype : torch.dtype | None, optional Data type of the tensors. If ``None``, the default data type from ``get_default_dtype`` is used. Defaults to ``None``. """ super().__init__() # Recursively convert the dictionary into a parameter tree. self.parameter_tree = _convert(par.clean_model_dump(), device, dtype) # Dummy tensor to get the device and dtype. self.register_buffer( "dummy", torch.empty(0, device=device, dtype=dtype) ) @property def device(self) -> torch.device: """Returns the device where the first parameter/buffer is located.""" return self.dummy.device # type: ignore @property def dtype(self) -> torch.dtype: """Returns the data type of the parameters.""" return self.dummy.dtype # type: ignore @property def dd(self) -> DD: """Returns the dictionary of device and data type.""" return {"device": self.device, "dtype": self.dtype}
[docs] def forward(self) -> nn.Module: """ Returns the internal parameter tree. This tree contains all numeric values wrapped in :class:`ParameterModule` and nonnumeric values wrapped in :class:`NonNumericValue`. """ return self.parameter_tree
# Conversion
[docs] def to_dict(self) -> dict[str, Any]: """ Revert the differentiable parameter tree into a plain Python dictionary. This method recursively unwraps all parameter wrappers (e.g. :class:`ParameterModule` and :class:`NonNumericValue`) and converts any tensors to Python scalars (if zero-dimensional) or lists (if higher-dimensional). The resulting dictionary mirrors the original model structure, and can be serialized to different formats without loss of information. Returns ------- dict[str, Any] A dictionary representation of the differentiable parameters. """ return _revert(self.parameter_tree)
[docs] def to_pydantic(self) -> Param: """ Converts the parameter tree back to a Pydantic model. Returns ------- Param A Pydantic model representation of the differentiable parameters. """ return Param(**self.to_dict())
# Convert to file formats
[docs] def to_file(self, filepath: PathLike, **kwargs) -> None: """ Save the parametrization to a file. The file format is determined by the file extension. Supported formats are JSON, TOML, and YAML. Parameters ---------- filepath : PathLike The file path to save the parametrization data. Raises ------ ValueError If the file format is not supported. """ self.to_pydantic().to_file(filepath, **kwargs)
[docs] def to_json_file(self, filename: PathLike, **kwargs: Any) -> None: """ Converts the parameter tree to a JSON file. Parameters ---------- filename : PathLike The name of the JSON file to save the parameters. kwargs : dict Additional keyword arguments for the dump function of the JSON writer. """ self.to_pydantic().to_json_file(filename, **kwargs)
[docs] def to_toml_file(self, filename: PathLike, **kwargs: Any) -> None: """ Converts the parameter tree to a TOML file. Parameters ---------- filename : PathLike The name of the TOML file to save the parameters. kwargs : dict Additional keyword arguments for the dump function of the TOML writer. """ self.to_pydantic().to_toml_file(filename, **kwargs)
[docs] def to_yaml_file(self, filename: PathLike, **kwargs: Any) -> None: """ Converts the parameter tree to a YAML file. Parameters ---------- filename : PathLike The name of the YAML file to save the parameters. kwargs : dict Additional keyword arguments for the dump function of the YAML writer. """ self.to_pydantic().to_yaml_file(filename, **kwargs)
# Pretty-printing def __str__(self) -> str: # pragma: no cover return f"{self.__class__.__name__}({self.parameter_tree})" def __repr__(self) -> str: # pragma: no cover return str(self)
def _convert( value: Any, device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> nn.Module: """ Recursively converts an input value into a :class:`~torch.nn.Module` tree. * If the input is a dictionary, returns a :class:`~torch.nn.ModuleDict` with values recursively converted. * If the input is a list of numbers, converts it into a tensor wrapped in a :class:`ParameterModule`. * For a nonnumeric list, each element is converted; if all are modules, they are put into a :class:`~torch.nn.ModuleList`. Otherwise the list is wrapped in a :class:`NonNumericValue`. * If the input is a numeric value, converts it into a tensor wrapped in a :class:`ParameterModule`. * All nonnumeric values are wrapped in a :class:`NonNumericValue`. Parameters ---------- value : Any The value to convert. device : torch.device | None, optional Device of the tensors. If ``None``, the device of `freqs` is used. Defaults to ``None``. dtype : torch.dtype | None, optional Data type of the tensors. If ``None``, the data type of `freqs` is used. Defaults to ``None``. Returns ------- nn.Module The converted value as a :class:`~torch.nn.Module`. """ # If value is a dictionary, convert each key recursively. if isinstance(value, dict): out = nn.ModuleDict() for key, v in value.items(): out[key] = _convert(v, device, dtype) return out # If value is a list... if isinstance(value, list): # If it's a list of numbers, convert it directly. _is_int = is_int_list(value) _is_float = is_float_list(value) if value and (_is_int or _is_float): _dtype = dtype if _is_float else DEFAULT_BASIS_INT tensor_value = torch.tensor(value, device=device, dtype=_dtype) return ParameterModule(tensor_value) # Otherwise, process each item recursively. converted: list[nn.Module] = [] all_modules = True for item in value: conv_item = _convert(item, device, dtype) if not isinstance(conv_item, nn.Module): all_modules = False converted.append(conv_item) if all_modules: return nn.ModuleList(converted) return NonNumericValue(converted) # If value is a numeric leaf, convert it to a ParameterModule. if is_float(value): tensor_value = torch.tensor(value, device=device, dtype=dtype) return ParameterModule(tensor_value) if is_integer(value): tensor_value = torch.tensor(value, device=device, dtype=torch.int8) return ParameterModule(tensor_value) # For any other type, wrap in NonNumericValue. return NonNumericValue(value) def _revert(module: Any) -> Any: """ Recursively reverts a differentiable parameter tree into a plain Python dictionary. This function unwraps all parameter wrappers (e.g. :class:`ParameterModule` and :class:`NonNumericValue`) and converts any tensors to Python scalars (if zero-dimensional) or lists (if higher-dimensional). The resulting dictionary mirrors the original model structure, and can be serialized to different formats without loss of information. Parameters ---------- module : Any The input value to revert. Returns ------- Any The value reverted by one level. """ # If module is a ModuleDict, convert it to a dict. if isinstance(module, nn.ModuleDict): return {k: _revert(child) for k, child in module.items()} # If module is a ModuleList, convert it to a list. if isinstance(module, nn.ModuleList): return [_revert(child) for child in module] # If module is a ParameterModule, return its underlying tensor # as a scalar or list. if isinstance(module, ParameterModule): if module.param.dim() == 0: return module.param.item() return module.param.tolist() # If module is NonNumericValue, try to recursively revert its value. if isinstance(module, NonNumericValue): value = module.value # If underlying value is nn.Module, process it recursively. if isinstance( value, ( nn.ModuleDict, nn.ModuleList, ParameterModule, NonNumericValue, ), ): return _revert(value) return value # If it is tensor (unexpected), convert it. if torch.is_tensor(module): if module.dim() == 0: return module.item() return module.tolist() # Otherwise, assume it is already a plain Python type. return module