Source code for monitorch.preprocessor.ExplicitCall
from typing import Any
from monitorch.numerical import RunningMeanVar
from .abstract.abstract_preprocessor import AbstractPreprocessor
[docs]
class ExplicitCall(AbstractPreprocessor):
    """
    Class for accumulating data passed by explicit call.

    Objects of the ``ExplicitCall`` class are provided by :class:`PyTorchInspector`
    to lenses as a foreign preprocessor. ``ExplicitCall`` implements methods to
    interact directly with its data. Its primary usage is to track loss and other
    performance metrics for the :class:`LossMetrics` lens.

    Parameters
    ----------
    train_loss_str : str
        String to save training loss under.
    non_train_loss_str : str
        String to save development, validation or test loss under.

    Attributes
    ----------
    state : dict[str, Any]
        Aggregated data indexed by their names. Entries are either lists
        (filled by :meth:`push_memory`) or :class:`RunningMeanVar` objects
        (filled by :meth:`push_running`).
    train_loss_str : str
        String to save training loss under.
    non_train_loss_str : str
        String to save non-training loss under.
    """

    def __init__(self, train_loss_str: str, non_train_loss_str: str):
        self.state: dict[str, Any] = {}
        self.train_loss_str = train_loss_str
        self.non_train_loss_str = non_train_loss_str

    def push_memory(self, name: str, value) -> None:
        """
        Appends value to container under name and creates a list if there is none.

        Both push methods share the same ``state`` namespace. If ``name`` already
        holds a container created by :meth:`push_running`, ``value`` is passed to
        that container's ``append`` instead.

        Parameters
        ----------
        name : str
            Name under which the value will be saved.
        value
            The value to be saved.
        """
        self.state.setdefault(name, []).append(value)

    def push_running(self, name: str, value: float) -> None:
        """
        Appends value to container under name and creates a :class:`RunningMeanVar` if there is none.

        Both push methods share the same ``state`` namespace. If ``name`` already
        holds a list created by :meth:`push_memory`, ``value`` is appended to
        that list instead.

        Parameters
        ----------
        name : str
            Name under which the value will be saved.
        value : float
            The value to be saved.
        """
        # Construct the aggregator only when the key is missing. Using
        # ``setdefault(name, RunningMeanVar())`` would build a new, discarded
        # aggregator on every call, because default arguments are evaluated eagerly.
        if name not in self.state:
            self.state[name] = RunningMeanVar()
        self.state[name].append(value)

    def push_loss(self, value: float, *, train: bool, running: bool = True) -> None:
        """
        A utility function to save loss.

        A shorthand to choose whether loss is running and what name to push it under.

        Parameters
        ----------
        value : float
            Value of loss to be saved.
        train : bool
            Whether loss should be saved under :attr:`train_loss_str` or :attr:`non_train_loss_str`.
        running : bool
            Indicates if :meth:`push_running` or :meth:`push_memory` should be used.
        """
        name = self.train_loss_str if train else self.non_train_loss_str
        if running:
            self.push_running(name, value)
        else:
            self.push_memory(name, value)

    @property
    def value(self) -> dict[str, Any]:
        """
        Aggregated data indexed by name.

        Returns the live :attr:`state` dictionary itself, not a copy. After
        :meth:`reset`, previously returned dictionaries are no longer updated.
        """
        return self.state

    def reset(self) -> None:
        """Discard all accumulated data by replacing :attr:`state` with a new empty dictionary."""
        self.state = {}