Source code for monitorch.preprocessor.output.loss_module

from typing import Any

from torch import is_grad_enabled

from monitorch.numerical import RunningMeanVar
from monitorch.preprocessor.abstract import AbstractForwardPreprocessor


class LossModule(AbstractForwardPreprocessor):
    """
    Forward preprocessor that collects scalar loss values.

    Every forward pass of a hooked loss module (for instance
    ``torch.nn.MSELoss`` or ``torch.nn.NLLLoss``) contributes one number,
    which is appended to either the training aggregator or the
    non-training aggregator. The collected data is exposed through
    :attr:`value`.

    Parameters
    ----------
    inplace : bool
        If true, losses are aggregated with :class:`RunningMeanVar`;
        otherwise they are collected in a plain ``list``.
    evaluation_from_grad : bool
        If true, a pass counts as training whenever
        ``torch.is_grad_enabled()`` is true; otherwise the hooked
        module's ``training`` attribute decides.
    """

    def __init__(self, inplace: bool, evaluation_from_grad: bool):
        # Stays empty until set_loss_strs() (or reset()) populates it.
        self._value = {}
        self._train_str_loss = ''
        self._non_train_str_loss = ''

        if inplace:
            self._agg_class = RunningMeanVar
        else:
            self._agg_class = list

        # Both predicates take the module so process_fw can call them
        # uniformly; the gradient-based one simply ignores its argument.
        if evaluation_from_grad:
            self._is_train = lambda _module: is_grad_enabled()
        else:
            self._is_train = lambda module: module.training

    def _fresh_store(self, train_key, non_train_key):
        """Build a mapping holding one new, empty aggregator per key."""
        return {key: self._agg_class() for key in (train_key, non_train_key)}

    def set_loss_strs(self, train_loss_str: str, non_train_loss_str: str):
        """
        Choose the keys under which losses are stored.

        The two names become the keys of :attr:`value`. Calling this
        method replaces any previously stored data with fresh
        aggregators.

        Parameters
        ----------
        train_loss_str : str
            Key for losses recorded during training passes.
        non_train_loss_str : str
            Key for losses recorded during test/validation/development
            passes.
        """
        self._value = self._fresh_store(train_loss_str, non_train_loss_str)
        self._train_str_loss, self._non_train_str_loss = (
            train_loss_str,
            non_train_loss_str,
        )

    def process_fw(self, name: str, module, layer_input, layer_output):
        """
        Record the loss produced by one forward pass.

        Parameters
        ----------
        name : str
            Name of the module. Not used.
        module : torch.nn.Module
            The hooked module. Only consulted to decide between the
            training and non-training aggregator when
            ``evaluation_from_grad`` was false.
        layer_input : torch.Tensor
            Input given to the loss module. Not used.
        layer_output : torch.Tensor
            The loss tensor; it must contain exactly one element.

        Raises
        ------
        AttributeError
            If ``layer_output`` does not hold exactly one element.
        """
        if layer_output.numel() != 1:
            raise AttributeError('Only single item loss can be preprocessed')

        # Pick the destination first, then convert the tensor to a number.
        if self._is_train(module):
            key = self._train_str_loss
        else:
            key = self._non_train_str_loss
        self._value[key].append(layer_output.item())

    @property
    def value(self) -> dict[str, Any]:
        """Mapping from loss name to the aggregator holding its values."""
        return self._value

    def reset(self) -> None:
        """Discard collected losses, keeping the configured loss names."""
        self._value = self._fresh_store(
            self._train_str_loss, self._non_train_str_loss
        )