OutputActivation

class monitorch.lens.OutputActivation(inplace: bool = True, record_eval: bool = False, evaluation_from_grad: bool = False, activation: bool = True, dropout: bool = True, include: Iterable[type[Module]] = (), exclude: Iterable[type[Module]] = (), channel_last: bool = False, warning_plot: bool = True, activation_aggregation: str = 'mean', death_aggregation: str = 'mean')

Bases: AbstractLens

A lens that inspects neuron activation through layer outputs.

A neuron is active if it yields a non-zero value. If a neuron is inactive for an entire batch iteration, it is said to be dead. The activation rate over an epoch serves as a measure of a layer's entropy (high activation, high entropy), while the death rate measures overcapacity, since dead neurons go unused.
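As a rough illustration of these definitions, the sketch below computes both rates for a single batch of a fully connected layer's output, given as a [batch, features] nested list. It assumes that each feature is one neuron and that the activation rate is the fraction of non-zero outputs; the helper name batch_activation_stats is hypothetical, and this is not monitorch's own implementation.

```python
from typing import Sequence


def batch_activation_stats(output: Sequence[Sequence[float]]) -> tuple[float, float]:
    """Return (activation_rate, death_rate) for one batch of outputs.

    Hypothetical helper: `output` is a [batch, features] table and each
    feature is treated as one neuron.
    """
    n_samples = len(output)
    n_features = len(output[0])
    # A neuron is active in a sample when its output is non-zero.
    active = [[value != 0 for value in row] for row in output]
    # Activation rate: fraction of (sample, neuron) pairs that are active.
    activation_rate = sum(sum(row) for row in active) / (n_samples * n_features)
    # A neuron is dead when it is inactive for every sample in the batch.
    dead = sum(1 for j in range(n_features) if not any(row[j] for row in active))
    death_rate = dead / n_features
    return activation_rate, death_rate


# Two samples, four neurons: neurons 0 and 2 never fire.
example = [
    [0.0, 1.2, 0.0, 0.5],
    [0.0, 0.0, 0.0, 2.0],
]
print(batch_activation_stats(example))  # (0.375, 0.5)
```

Here 3 of the 8 outputs are non-zero, and 2 of the 4 neurons stay at zero for the whole batch.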

This lens lets you inspect both quantities. In addition, it can collect the worst activation and death rates across the whole model into a single large warning plot.

Parameters:
  • inplace (bool = True) – Flag indicating if computation should be done in-place or in-memory.

  • record_eval (bool = False) – Flag indicating whether data collected during evaluation should be recorded. If False (default), evaluation data is ignored.

  • evaluation_from_grad (bool = False) – Flag indicating how evaluation is detected: if True, from torch.is_grad_enabled(); if False, from module.training.

  • activation (bool = True) – Flag indicating if activation function layers’ data should be collected and displayed.

  • dropout (bool = True) – Flag indicating if dropout layers’ data should be collected and displayed.

  • include (Iterable[Type[Module]] = tuple()) – Additional layer types to include for inspection.

  • exclude (Iterable[Type[Module]] = tuple()) – Layer types to exclude from inspection. Overrides all other settings.

  • channel_last (bool = False) – If True, expects layer outputs in [batch, seq_len, ..., features] format where the feature/channel dimension is last (e.g. transformer outputs). If False (default), expects PyTorch’s standard [batch, features, spatial_dims, ...] format.

  • warning_plot (bool = True) – Flag indicating if the big warning plot should be added.

  • activation_aggregation (str = 'mean') – Aggregation method used to combine per-iteration activation rates over an epoch.

  • death_aggregation (str = 'mean') – Aggregation method used to combine per-iteration death rates over an epoch.
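The channel_last flag changes which axis is treated as the feature (neuron) axis. The sketch below uses a hypothetical helper, feature_alive_flags, on nested lists instead of tensors to check, per feature, whether it produced any non-zero value in a 3-D output, for both layouts. Counting a feature as alive when it is non-zero at any position along the sequence axis is an assumption made for illustration.

```python
from typing import Sequence

Output3D = Sequence[Sequence[Sequence[float]]]


def feature_alive_flags(output: Output3D, channel_last: bool) -> list[bool]:
    """Hypothetical helper: for each feature, report whether it fired anywhere.

    channel_last=True  expects [batch, seq_len, features];
    channel_last=False expects [batch, features, seq_len].
    """
    if channel_last:
        # The feature axis is the innermost one.
        n_features = len(output[0][0])
        return [
            any(step[f] != 0 for sample in output for step in sample)
            for f in range(n_features)
        ]
    # The feature axis comes right after the batch axis.
    n_features = len(output[0])
    return [
        any(value != 0 for sample in output for value in sample[f])
        for f in range(n_features)
    ]


# The same data (one sample, two steps, two features) in both layouts.
channel_last_out = [[[0.0, 1.0], [0.0, 2.0]]]
channel_first_out = [[[0.0, 0.0], [1.0, 2.0]]]
print(feature_alive_flags(channel_last_out, channel_last=True))    # [False, True]
print(feature_alive_flags(channel_first_out, channel_last=False))  # [False, True]
```

Reading the channel-last data with channel_last=False would treat the two sequence steps as features and report both as alive, hiding the dead feature.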

Examples

Default usage with no-grad validation

>>> import torch
>>> from monitorch.lens import OutputActivation
>>> inspector = PyTorchInspector(
...     lenses=[
...         OutputActivation(),
...     ],
...     module=mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     mynet.train()
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     mynet.eval()
...     with torch.no_grad():  # outputs inside this block are not recorded
...         for data, label in val_dataloader:
...             prediction = mynet(data)
...             loss = loss_fn(prediction, label)
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()

detach_from_module()

Detaches the lens from the module.

Detaches the gatherers and resets the inner state.

finalize_epoch()

Finalizes computations done during the epoch.

Aggregates activation and death rates according to activation_aggregation and death_aggregation, and computes the worst activation rate (minimal) and the worst death rate (maximal).
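A minimal sketch of this epoch-level step, assuming the lens keeps one activation rate and one death rate per layer and per training iteration. Only 'mean' is documented as an aggregation method; the other entries in the hypothetical AGGREGATIONS table are illustrative, as is the function name finalize_epoch_sketch.

```python
from statistics import mean, median

# Hypothetical name -> function table; only 'mean' is documented for the lens.
AGGREGATIONS = {"mean": mean, "median": median, "min": min, "max": max}


def finalize_epoch_sketch(per_layer, activation_aggregation="mean", death_aggregation="mean"):
    """per_layer maps a layer name to per-iteration lists:
    {'activation_rate': [...], 'death_rate': [...]}."""
    agg_activation = AGGREGATIONS[activation_aggregation]
    agg_death = AGGREGATIONS[death_aggregation]
    # Collapse each layer's per-iteration values into one number per epoch.
    epoch_stats = {
        name: {
            "activation_rate": agg_activation(rates["activation_rate"]),
            "death_rate": agg_death(rates["death_rate"]),
        }
        for name, rates in per_layer.items()
    }
    worst = {
        # The lowest activation rate over all layers is the worst one.
        "worst activation_rate": min(s["activation_rate"] for s in epoch_stats.values()),
        # The highest death rate over all layers is the worst one.
        "worst death_rate": max(s["death_rate"] for s in epoch_stats.values()),
    }
    return epoch_stats, worst


per_iteration = {
    "relu1": {"activation_rate": [0.75, 0.25], "death_rate": [0.5, 0.0]},
    "relu2": {"activation_rate": [1.0, 0.5], "death_rate": [0.25, 0.75]},
}
stats, worst = finalize_epoch_sketch(per_iteration)
print(worst)  # {'worst activation_rate': 0.5, 'worst death_rate': 0.5}
```

Aggregating per layer first and only then taking the minimum or maximum means a single bad iteration is smoothed by the chosen aggregation before it can dominate the warning values.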

introduce_tags(vizualizer: AbstractVisualizer)

Introduces the lens’s plots to the visualizer.

Introduces one small plot, ‘Output Activations’, of type PROBABILITY, where per-layer data is plotted. If warning_plot is True, it also registers a big PROBABILITY plot, ‘Warning Output Activations’.

Parameters:

vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.

register_foreign_preprocessor(ext_ppr: AbstractPreprocessor, inspector_state)

Does not interact with foreign preprocessor.

register_leaf_module(module: Module, module_name: str, inspector_state)

Registers (or ignores) module.

Registers modules according to the activation and dropout flags set at initialization, and additionally includes all modules whose types are listed in include. Exclusion via the exclude parameter overrides every other setting.

Parameters:
  • module (torch.nn.Module) – The module object to hook gatherers onto.

  • module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
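The registration rule can be sketched as a predicate. To keep it runnable without PyTorch, the sketch uses hypothetical stand-in classes in place of torch.nn.Module, nn.ReLU, nn.Dropout and nn.Linear, and assumes that matching is done with isinstance; the exact set of layer types the lens treats as activations is not specified here.

```python
from typing import Iterable


# Hypothetical stand-ins for torch.nn classes so the sketch runs without PyTorch.
class Module: ...
class ReLU(Module): ...
class Dropout(Module): ...
class Linear(Module): ...


# Assumed type groups; the real lens may recognise many more layer types.
ACTIVATION_TYPES = (ReLU,)
DROPOUT_TYPES = (Dropout,)


def should_register(
    module: Module,
    activation: bool = True,
    dropout: bool = True,
    include: Iterable[type] = (),
    exclude: Iterable[type] = (),
) -> bool:
    # exclude wins over every other setting (isinstance(x, ()) is False).
    if isinstance(module, tuple(exclude)):
        return False
    if activation and isinstance(module, ACTIVATION_TYPES):
        return True
    if dropout and isinstance(module, DROPOUT_TYPES):
        return True
    # Any other layer is registered only if its type is listed in include.
    return isinstance(module, tuple(include))


print(should_register(Linear()))                    # False
print(should_register(Linear(), include=[Linear]))  # True
print(should_register(ReLU(), exclude=[ReLU]))      # False
```

Checking exclude first is what makes it override both the flags and include.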

reset_epoch()

Resets the inner state.

Clears data computed during the last epoch and resets preprocessors.

vizualize(vizualizer: AbstractVisualizer, epoch: int)

Passes computed data to the visualizer.

Passes a dictionary of per-layer data to ‘Output Activations’; it may look something like this:

OrderedDict([
    ('relu1',   {'activation_rate' : 0.8, 'death_rate' : 0.3}),
    ('dropout', {'activation_rate' : 0.9, 'death_rate' : 0.3}),
    ('relu2',   {'activation_rate' : 0.2, 'death_rate' : 0.1})
])

If warning_plot is True, it also passes a dictionary like the following to ‘Warning Output Activation’:

OrderedDict([
    ('Warning Output Activation', {
        'worst activation_rate' : 0.2,
        'worst death_rate'      : 0.3
    })
])

Parameters:
  • vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.

  • epoch (int) – Computation’s epoch number.
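The warning entry can be derived from the per-layer dictionary in the first example: the worst activation rate is the minimum over layers and the worst death rate is the maximum. The sketch below, built around a hypothetical helper build_warning_entry, reproduces the example values under that assumption; it is an illustration, not the lens's own code.

```python
from collections import OrderedDict
from typing import Mapping


def build_warning_entry(per_layer: Mapping[str, Mapping[str, float]]) -> "OrderedDict[str, dict]":
    """Hypothetical helper: collapse per-layer rates into the warning payload."""
    return OrderedDict([
        ("Warning Output Activation", {
            # Lowest activation rate over all layers.
            "worst activation_rate": min(s["activation_rate"] for s in per_layer.values()),
            # Highest death rate over all layers.
            "worst death_rate": max(s["death_rate"] for s in per_layer.values()),
        })
    ])


per_layer = OrderedDict([
    ("relu1",   {"activation_rate": 0.8, "death_rate": 0.3}),
    ("dropout", {"activation_rate": 0.9, "death_rate": 0.3}),
    ("relu2",   {"activation_rate": 0.2, "death_rate": 0.1}),
])
warning = build_warning_entry(per_layer)
print(warning["Warning Output Activation"])
# {'worst activation_rate': 0.2, 'worst death_rate': 0.3}
```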