Source code for monitorch.preprocessor.output.output_norm
from math import sqrt
from typing import Any
from torch import is_grad_enabled, no_grad
from torch.linalg import vector_norm
from monitorch.numerical import RunningMeanVar
from monitorch.preprocessor.abstract.abstract_forward_preprocessor import AbstractForwardPreprocessor
[docs]
class OutputNorm(AbstractForwardPreprocessor):
"""
Preprocessor to compute norms of outputs.
Flattens spatial and channel/neuron dimensions of output,
computes L2 norm or RMS (if normalized) of flattened vectors and takes mean over a batch.
Parameters
----------
normalize : bool
Indicator if output norm should be normalized by square root of number of elements in single sample output.
inplace : bool
Indicator if :class:`RunningMeanVar` or ``list`` should be used for aggregation.
record_eval : bool
Indicator if outputs during evaluation must be preprocessed.
evaluation_from_grad : bool
Flag indicating if evaluation passes should be considered from gradient or modele.training
channel_last : bool
If ``True``, expects data in ``[batch, seq_len, ..., features]`` format where the feature/channel
dimension is last (e.g. transformer outputs). If ``False`` (default), expects PyTorch's standard
``[batch, features, spatial_dims, ...]`` format. The norm computation is equivalent in both cases
since all non-batch dimensions are flattened before computing the L2 norm.
"""
def __init__(self, normalize: bool, inplace: bool, record_eval: bool, evaluation_from_grad: bool, channel_last: bool = False):
self._normalize = normalize
self._value = {}
self._agg_class = RunningMeanVar if inplace else list
self._record_eval = record_eval
self._is_train = (lambda m: is_grad_enabled()) if evaluation_from_grad else (lambda m: m.training)
self._channel_last = channel_last
[docs]
def process_fw(self, name: str, module, layer_input, layer_output):
"""
Computes mean output norm.
Flattens spatial and channel dimensions, computes (normalized) norm of individual samples
and saves their average.
Parameters
----------
name : str
Name of the module which outputs are processed.
module : torch.nn.Module
Module object, its outputs are processed.
layer_input : torch.Tensor
Should be input of layer, but it is ignored in this method.
layer_output : torch.Tensor
Outputs to compute norm from.
"""
if not (self._record_eval or self._is_train(module)):
return
norm_container = self._value.setdefault(name, self._agg_class())
with no_grad():
norm_mean = vector_norm(layer_output.flatten(1, -1), dim=-1).mean()
if self._normalize:
norm_mean /= sqrt(layer_output[0].numel())
norm_container.append(norm_mean)
@property
def value(self) -> dict[str, Any]:
return self._value
[docs]
def reset(self) -> None:
self._value = {}