# Source code for dxtb._src.timing.timer

# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Definition of a timer class that can contain multiple timers.

For developers
--------------
Remember to manually reset the timer in tests that are supposed to fail.
Otherwise, `timer.stop()` may not be called and the next test tries to start
the same timer again, which will throw a (confusing) `TimerError`.
For an example, see `test/test_calculator/test_general.py::test_fail`.
"""

from __future__ import annotations

import time

__all__ = ["timer", "create_timer", "kill_timer"]


class TimerError(Exception):
    """
    Raised when a timer is misused.

    Examples are starting a timer that is already running, stopping a timer
    that is not running, or referring to a timer that does not exist.
    """


def _sync() -> None:
    """
    Wait for all kernels in all streams on a CUDA device to complete.
    """
    import torch

    if torch.cuda.is_available():
        torch.cuda.synchronize()


class _Timers:
    """
    Collection of Timers.
    Upon instantiation, a timer with the label 'total' is started.
    """

    class _Timer:
        """Instance of a Timer."""

        label: str | None
        """Name of the Timer."""

        parent: _Timers
        """Parent Timer collection."""

        _start_time: float | None
        """Time when the timer was started. Should not be accessed directly."""

        elapsed_time: float
        """Elapsed time in seconds."""

        def __init__(
            self,
            parent: _Timers,
            label: str | None = None,
            cuda_sync: bool = False,
        ) -> None:
            self.parent = parent
            self.label = label
            self._start_time = None
            self.elapsed_time = 0.0
            self._cuda_sync = cuda_sync

        @property
        def cuda_sync(self) -> bool:
            """
            Check if CUDA synchronization is enabled.

            Returns
            -------
            bool
                Whether CUDA synchronization is enabled (``True``) or not
                (``False``).
            """
            return self._cuda_sync

        @cuda_sync.setter
        def cuda_sync(self, value: bool) -> None:
            """
            Enable or disable CUDA synchronization.

            Parameters
            ----------
            value : bool
                Whether to enable (``True``) or disable (``False``) CUDA
                synchronization.
            """
            self._cuda_sync = value

        def start(self) -> None:
            """
            Start a new timer.

            Raises
            ------
            TimerError
                If timer is already running.
            """
            if not self.parent.enabled:
                return

            if self._start_time is not None:
                raise TimerError(
                    f"Timer '{self.label}' is running. Use `.stop()` to stop it."
                )

            if self.cuda_sync is True:
                _sync()

            self._start_time = time.perf_counter()

        def stop(self) -> float:
            """
            Stop the timer.

            Returns
            -------
            float
                Elapsed time in seconds.

            Raises
            ------
            TimerError
                If timer is not running.
            """
            if not self.parent.enabled:
                return 0.0

            if self._start_time is None:
                raise TimerError(
                    f"Timer '{self.label}' is not running. Use .start() to "
                    "start it."
                )

            if self.cuda_sync is True:
                _sync()

            self.elapsed_time += time.perf_counter() - self._start_time
            self._start_time = None

            return self.elapsed_time

        def is_running(self) -> bool:
            """
            Check if the timer is running.

            Returns
            -------
            bool
                Whether the timer currently runs (``True``) or not (``False``).
            """
            if self._start_time is not None and self.elapsed_time == 0.0:
                return True
            return False

    timers: dict[str, _Timer]
    """Dictionary of timers."""

    label: str | None
    """Name for the Timer collection."""

    def __init__(
        self,
        label: str | None = None,
        autostart: bool = False,
        cuda_sync: bool = False,
        only_parents: bool = False,
    ) -> None:
        self.label = label
        self.timers = {}
        self._enabled: bool = True
        self._subtimer_parent_map: dict[str, str] = {}
        self._autostart = autostart
        self._cuda_sync = cuda_sync
        self._only_parents = only_parents

        if self._autostart is True:
            self.reset()

    def enable(self) -> None:
        """
        Enable all timers in the collection.
        """
        self._enabled = True

    def disable(self) -> None:
        """
        Disable and reset all timers in the collection.
        """
        self._enabled = False

    @property
    def enabled(self) -> bool:
        """
        Check if the timer is enabled.

        Returns
        -------
        bool
            Whether the timer is enabled (``True``) or not (``False``).
        """
        return self._enabled

    @property
    def cuda_sync(self) -> bool:
        """
        Check if CUDA synchronization is enabled.

        Returns
        -------
        bool
            Whether CUDA synchronization is enabled (``True``) or not
            (``False``).
        """
        return self._cuda_sync

    @cuda_sync.setter
    def cuda_sync(self, value: bool) -> None:
        """
        Enable or disable CUDA synchronization.

        Parameters
        ----------
        value : bool
            Whether to enable (``True``) or disable (``False``) CUDA
            synchronization.
        """
        self._cuda_sync = value

        for t in self.timers.values():
            t.cuda_sync = value

    @property
    def only_parents(self) -> bool:
        """
        Check if only parent timers are enabled.

        Returns
        -------
        bool
            Whether only parent timers are enabled (``True``) or not
            (``False``).
        """
        return self._only_parents

    @only_parents.setter
    def only_parents(self, value: bool) -> None:
        """
        Enable or disable only parent timers.

        Parameters
        ----------
        value : bool
            Whether to enable (``True``) or disable (``False``) only parent
            timers.
        """
        self._only_parents = value

    def start(
        self, uid: str, label: str | None = None, parent_uid: str | None = None
    ) -> None:
        """
        Create a new timer or start an existing timer with `uid`.

        Parameters
        ----------
        uid : str
            ID of the timer.
        label : str | None
            Name of the timer (used for printing). Defaults to ``None``.
            If no `label` is given, the `uid` is used.
        """
        if not self._enabled:
            return

        if self.only_parents is True and parent_uid is not None:
            return

        if uid in self.timers:
            self.timers[uid].start()
            return

        t = self._Timer(self, uid if label is None else label, self.cuda_sync)
        t.start()

        self.timers[uid] = t

        if parent_uid is not None:
            if parent_uid in self.timers:
                self._subtimer_parent_map[uid] = parent_uid

    def stop(self, uid: str) -> float:
        """
        Stop the timer

        Parameters
        ----------
        uid : str
            Unique ID of the timer.

        Returns
        -------
        float
            Elapsed time in seconds.

        Raises
        ------
        TimerError
            If timer dubbed `uid` does not exist.
        """
        if not self.enabled:
            return 0.0

        if uid not in self.timers:
            # If sub timers are disabled, some timers will not exist. So,
            # instead of raising an error, we return just 0.0.
            if self.only_parents is True:
                return 0.0

            raise TimerError(f"Timer '{uid}' does not exist.")

        t = self.timers[uid]
        elapsed_time = t.stop()

        return elapsed_time

    def stop_all(self) -> None:
        """Stop all running timers."""
        for t in self.timers.values():
            if t.is_running():
                t.stop()

    def reset(self) -> None:
        """
        Reset all timers in the collection.

        This method reinitializes the timers dictionary and restarts the
        'total' timer.
        """
        self.timers = {}
        self.start("total")

    def kill(self) -> None:
        """
        Disable, reset and stop all timers.
        """
        self.disable()
        self.reset()
        self.stop_all()

        self.timers.clear()
        self._subtimer_parent_map.clear()

    def get_time(self, uid: str) -> float:
        """
        Get the elapsed time of a timer.

        Parameters
        ----------
        uid : str
            Unique ID of the timer.

        Returns
        -------
        float
            Elapsed time in seconds.

        Raises
        ------
        TimerError
            If timer dubbed `uid` does not exist.
        """
        if not self.enabled:
            return 0.0

        if uid not in self.timers:
            raise TimerError(f"Timer '{uid}' does not exist.")

        return self.timers[uid].elapsed_time

    def get_times(self) -> dict[str, dict[str, float]]:
        """
        Get the elapsed times of all timers,

        Returns
        -------
        dict[str, float]
            Dictionary of timer IDs and elapsed times.
        """
        if self.timers["total"].is_running():
            self.timers["total"].stop()

        KEY = "value"
        times = {}

        # Initialize all parent timers in the times dictionary
        for k in self.timers:
            if k not in self._subtimer_parent_map:
                times[k] = {KEY: None, "sub": {}}

        # Add times for all timers, categorizing based on the parent map
        for uid, t in self.timers.items():
            if uid in self._subtimer_parent_map:
                parent = self._subtimer_parent_map[uid]
                times[parent]["sub"][uid] = t.elapsed_time
            else:
                times[uid][KEY] = t.elapsed_time

        total_time = times["total"][KEY]
        for main_timer, details in times.items():
            if main_timer == "total":
                continue

            # Calculate the percentage of the total time for main timers
            main_time = details[KEY]
            percentage_of_total = (main_time / total_time) * 100
            details["percentage"] = f"{percentage_of_total:.2f}"

            # Calculate the percentage relative to the parent timer for sub
            if details["sub"]:
                for subtimer, sub_time in details["sub"].items():
                    percentage_of_parent = (sub_time / main_time) * 100
                    details["sub"][subtimer] = {
                        KEY: sub_time,
                        "percentage": f"{percentage_of_parent:.2f}",
                    }

        return times

    def print(self, v: int = 5, precision: int = 3) -> None:  # pragma: no cover
        """Print the elapsed times of all timers in a table."""
        if not self._enabled:
            return

        if self.timers["total"].is_running():
            self.timers["total"].stop()

        # pylint: disable=import-outside-toplevel
        from ..io import OutputHandler

        OutputHandler.write_table(
            self.get_times(),
            title="Timings",
            columns=["Objective", "Time (s)", "% Total"],
            v=v,
            precision=precision,
        )

    def __str__(self) -> str:  # pragma: no cover
        """Return a string representation of the :class:`._Timers` instance."""
        timers_repr = ", ".join(
            f"'{label}': {timer.elapsed_time:.3f}s"
            for label, timer in self.timers.items()
        )

        return (
            f"{self.__class__.__name__}("
            f"label={self.label!r}, "
            f"enabled={self._enabled}, "
            f"cuda_sync={self._cuda_sync}, "
            f"only_parents={self._only_parents}, "
            f"timers={{{timers_repr}}}"
            f")"
        )

    def __repr__(self) -> str:  # pragma: no cover
        """Return a string representation of the :class:`._Timers` instance."""
        return str(self)


def create_timer(autostart: bool = True, cuda_sync: bool = False) -> _Timers:
    """
    Build a fresh timer collection and bind it to the module-level ``timer``.

    Parameters
    ----------
    autostart : bool, optional
        Whether the 'total' timer is started right away. Defaults to
        ``True``.
    cuda_sync : bool, optional
        Whether the timers synchronize CUDA (via :func:`torch.cuda.synchronize`)
        before reading the clock. Defaults to ``False``.

    Returns
    -------
    _Timers
        The newly created instance, which is also the new global ``timer``.

    Note
    ----
    Delete the timer instance with :func:`.kill_timer` when it is no longer
    needed or throws errors when reusing it.
    """
    global timer

    new_timer = _Timers(autostart=autostart, cuda_sync=cuda_sync)
    timer = new_timer
    return new_timer


def kill_timer() -> None:
    """
    Delete the global timer instance.

    The instance's timers are disabled, stopped and cleared via
    :meth:`_Timers.kill` before the module-level name ``timer`` is removed.
    Use :func:`create_timer` to create a new global instance afterwards.

    Raises
    ------
    TimerError
        If the global timer does not exist (never created or already
        deleted).
    """
    global timer

    if "timer" not in globals():
        raise TimerError(
            "Cannot delete timer instance; timer was never initialized."
        )

    timer.kill()
    del timer
from os import getenv

# The 'total' timer is started automatically unless DXTB_TIMER_AUTOSTART is
# set to something other than "true", "1" or "yes" (case-insensitive).
timer = create_timer(
    autostart=(
        getenv("DXTB_TIMER_AUTOSTART", "True").casefold()
        in ("true", "1", "yes")
    ),
    cuda_sync=False,
)
"""Global instance of the timer class."""