Source code for kiui.timer
import copy
import os
import time
from functools import wraps

import torch
[docs]
class sync_timer:
    """Synchronized timer to measure the run time of ``nn.Module.forward`` or any other code.

    :class:`sync_timer` can be used as a context manager or a decorator.

    Example as context manager:

    .. code-block:: python

        with sync_timer('name'):
            run()

    Example as decorator:

    .. code-block:: python

        @sync_timer('name')
        def run():
            pass

    When CUDA is available, timing uses ``torch.cuda.Event`` and calls
    ``torch.cuda.synchronize()`` before reading the result, so queued GPU work
    is included. Without CUDA, it falls back to ``time.perf_counter``.

    Args:
        name (str, optional): name of the timer. If None, the time is still
            measured (see ``elapsed``) but nothing is logged. Defaults to None.
        flag_env (str, optional): name of the environment variable that enables
            timing. Defaults to "TIMER".
        logger_func (Callable, optional): function called with the formatted
            result string. Defaults to ``print``.

    Attributes:
        elapsed (float or None): seconds measured by the most recent ``with``
            block on this instance, or None if timing was disabled. Decorated
            calls are timed on a private copy and do not update it.

    Note:
        Timing is enabled only when the environment variable named by
        ``flag_env`` is set to ``1`` (e.g. ``TIMER=1`` with the default).
        The flag is read once on entry, so changing it inside the timed
        region does not affect that measurement.
    """

    def __init__(self, name=None, flag_env="TIMER", logger_func=print):
        self.name = name
        self.flag_env = flag_env
        self.logger_func = logger_func
        self.elapsed = None
        # State of the current measurement; (re)initialized by __enter__.
        self._enabled = False
        self._use_cuda = False
        self._cpu_start = None

    def __enter__(self):
        """Start timing if enabled by the environment flag; return ``self``."""
        # Read the flag once, so __exit__ never runs against a timer whose
        # events were not created because the env changed mid-block.
        self._enabled = os.environ.get(self.flag_env, "0") == "1"
        self.elapsed = None
        if self._enabled:
            self._use_cuda = torch.cuda.is_available()
            if self._use_cuda:
                self.start = torch.cuda.Event(enable_timing=True)
                self.end = torch.cuda.Event(enable_timing=True)
                self.start.record()
            else:
                self._cpu_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Stop timing, store ``elapsed`` and log it if named.

        Returns None, so any exception raised in the block propagates.
        """
        if not self._enabled:
            return None
        if self._use_cuda:
            self.end.record()
            # Events are recorded asynchronously; wait before reading them.
            torch.cuda.synchronize()
            # elapsed_time() reports milliseconds; convert to seconds.
            self.elapsed = self.start.elapsed_time(self.end) / 1000
        else:
            self.elapsed = time.perf_counter() - self._cpu_start
        if self.name is not None:
            self.logger_func(f"{self.name} takes {self.elapsed:.3f} s")
        return None

    def __call__(self, func):
        """Decorate ``func`` so that every call is timed."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Time each call on its own shallow copy so recursive, nested or
            # concurrent calls don't overwrite each other's start/end state.
            with copy.copy(self):
                return func(*args, **kwargs)

        return wrapper