Source code for monitorch.preprocessor.gradient.gradient_activation

from typing import Any

from torch import abs as tabs
from torch import float32 as tfloat32
from torch import no_grad

from monitorch.numerical import RunningMeanVar, reduce_activation_to_activation_rates
from monitorch.preprocessor.abstract.abstract_tensor_preprocessor import AbstractTensorPreprocessor


class GradientActivation(AbstractTensorPreprocessor):
    """
    Preprocessor that tracks gradient activations and, optionally, death rates.

    Each processed gradient is thresholded by absolute value against ``eps``
    and the resulting boolean mask is reduced with
    :func:`reduce_activation_to_activation_rates`. The mean of the resulting
    rates is recorded as the activation; when ``death`` is enabled, the
    fraction of rates that are exactly zero is recorded as the death rate.
    Records accumulate per gradient source until :meth:`reset` is called and
    can be aggregated further (e.g. mean or median across batch iterations).

    Parameters
    ----------
    death : bool
        Flag indicating if death rate should be computed.
    inplace : bool
        Flag indicating whether to collect data inplace using
        :class:`RunningMeanVar` or to stack them into a list.
    eps : float
        Numerical constant under which an absolute value is regarded as zero.
    """

    def __init__(self, death: bool, inplace: bool, eps: float = 1e-8):
        self._death = death
        self._eps = eps
        # Factory for per-source aggregators; both choices are used only
        # through ``append`` in this class.
        self._agg_class = RunningMeanVar if inplace else list
        self._value = {}

    def process_tensor(self, name: str, grad):
        """
        Record activation (and optionally death rate) for one gradient.

        The gradient is turned into a boolean mask ``|grad| > eps`` which is
        passed to :func:`reduce_activation_to_activation_rates` with
        ``batch=False``. The mean of the returned rates is appended to the
        activation aggregator of ``name``; if death tracking is enabled, the
        share of rates equal to zero is appended to the death-rate aggregator.

        Parameters
        ----------
        name : str
            Name of a source of gradient.
        grad : torch.Tensor
            Gradient tensor to compute activations from.
        """
        # Lazily create the storage slot the first time a source is seen:
        # an (activations, death_rates) pair with death tracking, otherwise
        # a single activations aggregator.
        if name not in self._value:
            self._value[name] = (
                (self._agg_class(), self._agg_class())
                if self._death
                else self._agg_class()
            )

        with no_grad():
            is_active = tabs(grad) > self._eps
            rates = reduce_activation_to_activation_rates(is_active, batch=False)

            if not self._death:
                self._value[name].append(rates.mean(dtype=tfloat32))
                return

            activation_agg, death_agg = self._value[name]
            # ``eq`` yields a boolean tensor, so the mean needs an explicit
            # floating dtype.
            death_agg.append(rates.eq(0.0).mean(dtype=tfloat32))
            activation_agg.append(rates.mean(dtype=tfloat32))

    @property
    def value(self) -> dict[str, Any]:
        """
        Collected data keyed by gradient source name.

        Each entry is a single activations aggregator, or an
        ``(activations, death_rates)`` tuple when death tracking is enabled.
        See base class.
        """
        return self._value

    def reset(self) -> None:
        """Discard all collected data by starting a fresh mapping. See base class."""
        self._value = dict()