Source code for monitorch.preprocessor.output.output_activation
from typing import Any
from torch import Tensor, is_grad_enabled, no_grad
from torch import abs as tabs
from torch import float32 as tfloat32
from monitorch.numerical import RunningMeanVar, reduce_activation_to_activation_rates
from monitorch.preprocessor.abstract.abstract_forward_preprocessor import AbstractForwardPreprocessor
[docs]
class OutputActivation(AbstractForwardPreprocessor):
    """
    Forward preprocessor that records how often layer outputs are activated.

    An output neuron or channel counts as activated when its value is non-zero,
    i.e. when its absolute value exceeds ``eps``. A neuron that is not activated
    for any sample of a batch is called dead, and the death rate is the share of
    dead neurons relative to the layer size.

    Parameters
    ----------
    death : bool
        Whether the death rate is collected alongside the activation rate.
    inplace : bool
        If ``True``, values are aggregated with :class:`RunningMeanVar`;
        otherwise they are collected in a plain ``list``.
    record_eval : bool
        Whether outputs produced during evaluation passes are preprocessed too.
    evaluation_from_grad : bool
        If ``True``, a pass is treated as training when gradient computation is
        enabled; otherwise ``module.training`` decides.
    eps : float
        Numerical threshold; absolute values not greater than it count as zero.
    channel_last : bool
        If ``True``, expects data in ``[batch, seq_len, ..., features]`` format
        where the feature/channel dimension is last (e.g. transformer outputs).
        If ``False`` (default), expects PyTorch's standard
        ``[batch, features, spatial_dims, ...]`` format.
    """

    def __init__(self, death: bool, inplace: bool, record_eval: bool, evaluation_from_grad: bool, eps: float = 1e-8, channel_last: bool = False):
        self._death = death
        # Maps name -> activation aggregator, or name -> (activation, death) aggregators.
        self._value = {}
        self._thresholds: dict[str, tuple[float, float]] = {}
        self._agg_class = RunningMeanVar if inplace else list
        self._record_eval = record_eval
        # Predicate taking the module and telling whether the current pass is a training pass.
        if evaluation_from_grad:
            self._is_train = lambda mod: is_grad_enabled()
        else:
            self._is_train = lambda mod: mod.training
        self._eps = eps
        self._channel_last = channel_last

    def process_fw(self, name: str, module, layer_input, layer_output) -> None:
        """
        Record the activation rate (and optionally the death rate) of a layer output.

        The output is turned into a boolean mask (absolute value above ``eps``),
        the channel dimension is moved to position 1 when ``channel_last=True``,
        and the mask is reduced with ``reduce_activation_to_activation_rates``
        using ``batch=True``. The mean of the resulting rates is appended to the
        aggregator stored under ``name``. If ``death=True`` was set, the share of
        rates equal to zero is appended to a second aggregator.

        Nothing is recorded for passes that are not training passes unless
        ``record_eval=True`` was set.

        Parameters
        ----------
        name : str
            Name of the module whose outputs are processed.
        module : torch.nn.Module
            Module object; only used to decide whether this is a training pass.
        layer_input : torch.Tensor
            Input of the layer; ignored by this method.
        layer_output : torch.Tensor
            Outputs to compute activations from.
        """
        # Skip evaluation passes unless they were explicitly requested.
        if not self._record_eval and not self._is_train(module):
            return

        # Create the aggregator(s) lazily on the first call for this module name.
        if name not in self._value:
            self._value[name] = (
                (self._agg_class(), self._agg_class()) if self._death else self._agg_class()
            )

        with no_grad():
            is_active: Tensor = tabs(layer_output) > self._eps
            if self._channel_last:
                # Bring the trailing feature dimension to the channel position.
                is_active = is_active.movedim(-1, 1)
            rates: Tensor = reduce_activation_to_activation_rates(is_active, batch=True)

        stored = self._value[name]
        if self._death:
            activation_agg, death_agg = stored
            # A rate of exactly zero means the neuron never fired in this batch.
            death_agg.append(rates.eq(0).mean(dtype=tfloat32))
        else:
            activation_agg = stored
        activation_agg.append(rates.mean(dtype=tfloat32))

    @property
    def value(self) -> dict[str, Any]:
        """See base class."""
        return self._value

    def reset(self) -> None:
        """See base class."""
        self._value = {}