OutputNorm

class monitorch.lens.OutputNorm(inplace: bool = True, record_eval: bool = False, evaluation_from_grad: bool = False, normalize_by_size: bool = False, log_scale: bool = False, activation: bool = True, include: Iterable[type[Module]] = (), exclude: Iterable[type[Module]] = (), channel_last: bool = False, comparison_plot: bool = True, comparison_aggregation: str | None = None, line_aggregation: str | Iterable[str] = 'mean', range_aggregation: str | Iterable[str] | None = ('std', 'min-max'))

Bases: AbstractLens

A lens for examining the norm of layer outputs.

Computes the L2 norm, or optionally the root mean square (RMS), of the outputs each module produces during the forward pass. The lens draws a small plot for each layer selected during initialization and, optionally, a comparison plot across all layers.
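
The relation between the two quantities can be sketched in plain Python. This is an illustration only, not the lens's implementation: the helper name output_norm is hypothetical, and the output is treated as a flat list of numbers rather than a tensor.

```python
import math


def output_norm(values, normalize_by_size=False):
    """Hypothetical helper: L2 norm of a flat output, or its RMS.

    Dividing the L2 norm by sqrt(number of elements) gives the
    root mean square, mirroring the normalize_by_size flag.
    """
    norm = math.sqrt(sum(v * v for v in values))
    if normalize_by_size:
        norm /= math.sqrt(len(values))
    return norm


output = [3.0, 4.0, 0.0, 0.0]
l2 = output_norm(output)                           # sqrt(9 + 16)
rms = output_norm(output, normalize_by_size=True)  # l2 / sqrt(4)
print(l2, rms)
```

Unlike the raw L2 norm, the RMS does not grow with the number of output elements, which makes layers of different widths easier to compare.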

Parameters:
  • inplace (bool = True) – Flag indicating whether computation is done in place (True) or in memory (False).

  • record_eval (bool = False) – Flag indicating whether data collected during evaluation should be recorded. By default it is ignored.

  • evaluation_from_grad (bool = False) – Flag indicating whether evaluation mode is detected from torch.is_grad_enabled() (True) or from module.training (False).

  • normalize_by_size (bool = False) – Flag indicating whether the output norm should be divided by the square root of the number of elements, yielding the RMS of the output.

  • log_scale (bool = False) – Flag indicating if logarithmic scale should be used.

  • activation (bool = True) – Flag indicating if activation function layers’ data should be collected and displayed.

  • include (Iterable[Type[Module]] = tuple()) – Additional layer types to include for inspection.

  • exclude (Iterable[Type[Module]] = tuple()) – Layer types to exclude from inspection. Overrides all other settings.

  • channel_last (bool = False) – If True, expects layer outputs in [batch, seq_len, ..., features] format where the feature/channel dimension is last (e.g. transformer outputs). If False (default), expects PyTorch’s standard [batch, features, spatial_dims, ...] format.

  • comparison_plot (bool = True) – Flag indicating if big comparison plot should be drawn.

  • comparison_aggregation (str|None = None) – Epoch-level aggregation applied to every output sequence for the comparison plot. Defaults to the first line_aggregation.

  • line_aggregation (str|Iterable[str] = 'mean') – Aggregation method for lines in plots.

  • range_aggregation (str|Iterable[str]|None = ('std', 'min-max')) – Aggregation method for bands in plots.

Examples

Default usage with no-grad validation

>>> inspector = PyTorchInspector(
...     lenses=[
...         OutputNorm(),
...     ],
...     module=mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     with torch.no_grad(): # outputs inside this block are not recorded
...         for data, label in val_dataloader:
...             prediction = mynet(data)
...             loss = loss_fn(prediction, label)
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()

detach_from_module()

Detaches lens from module.

Detaches gatherers and resets inner state.

finalize_epoch()

Finalizes computations done during the epoch.

Aggregates the recorded output norms according to line_aggregation and range_aggregation.
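
As a rough sketch of what epoch-level aggregation with the default methods ('mean' for lines, 'std' and 'min-max' for bands) could produce, the snippet below reduces a list of per-batch norms to summary values. The function name aggregate_epoch and the choice of population standard deviation are assumptions made for illustration; the library may compute these differently.

```python
import statistics


def aggregate_epoch(norms):
    """Hypothetical reduction of per-batch norms to epoch-level values."""
    return {
        "mean": statistics.fmean(norms),  # line aggregation
        "std": statistics.pstdev(norms),  # band aggregation (population std assumed)
        "min": min(norms),                # 'min-max' band, lower edge
        "max": max(norms),                # 'min-max' band, upper edge
    }


# Norms of one layer's output, one value per batch (made-up numbers).
batch_norms = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
summary = aggregate_epoch(batch_norms)
print(summary)
```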

introduce_tags(vizualizer: AbstractVisualizer)

Introduces the lens’s plots to the visualizer.

Introduces one small plot ‘Output Norm’ of type NUMERICAL, where per-layer data is plotted. If comparison_plot is True, also registers a big RELATIONS plot ‘#AGGREGATION_METHOD Output [Log] Norm Comparison’, whose name depends on the initialization parameters.

Parameters:
  • vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.

register_foreign_preprocessor(_: AbstractPreprocessor, inspector_state)

Does not interact with foreign preprocessor.

register_leaf_module(module: Module, module_name: str, inspector_state)

Registers (or ignores) module.

Registration is guided by the activation flag set during initialization, and all modules of the types listed in include are registered. Exclusion via the exclude parameter overrides every other setting.

Parameters:
  • module (torch.nn.Module) – The module object to hook gatherers onto.

  • module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
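
The precedence between exclude, include and activation can be sketched as follows. Everything here is hypothetical: the stand-in classes replace torch.nn modules, ACTIVATION_TYPES is an invented set, matching by isinstance is a guess, and skipping modules that are neither activations nor included is an assumption of this sketch rather than documented behaviour.

```python
class Module:            # stand-in for torch.nn.Module
    pass


class Linear(Module):    # stand-in for a non-activation layer
    pass


class ReLU(Module):      # stand-in for an activation layer
    pass


ACTIVATION_TYPES = (ReLU,)  # hypothetical set of activation layer types


def should_register(module, activation=True, include=(), exclude=()):
    """Hypothetical selection rule following the documented precedence."""
    if isinstance(module, tuple(exclude)):
        return False        # exclude overrides every other setting
    if isinstance(module, tuple(include)):
        return True         # explicitly included types are registered
    if isinstance(module, ACTIVATION_TYPES):
        return activation   # activation layers follow the flag
    return False            # assumption: anything else is skipped


print(should_register(ReLU()))
print(should_register(Linear(), include=(Linear,)))
print(should_register(ReLU(), include=(ReLU,), exclude=(ReLU,)))
```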

reset_epoch()

Resets inner state.

Resets data computed during last epoch and resets preprocessors.

vizualize(vizualizer: AbstractVisualizer, epoch: int)

Passes computed data to the visualizer.

Passes a dictionary of per-layer data to ‘Output Norm’. The dictionary may look something like this:

OrderedDict([
    ('relu1',   {'mean' : 0.8}, {'min' : 0.2, 'max' : 0.9}),
    ('relu2',   {'mean' : 0.6}, {'min' : 0.3, 'max' : 0.7}),
])

If the comparison plot is enabled, passes a dictionary like the one below to ‘#AGGREGATION_METHOD Output [Log] Norm Comparison’:

OrderedDict([
    ('Mean Output Log Norm Comparison', {
        'relu1' : 0.7,
        'relu2' : 0.3
    })
])

Parameters:
  • vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.

  • epoch (int) – Computation’s epoch number.
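
To make the naming pattern ‘#AGGREGATION_METHOD Output [Log] Norm Comparison’ concrete, the sketch below builds a comparison dictionary of the shape shown above from per-layer aggregates. The helpers comparison_tag and comparison_payload are hypothetical and only mirror the documented example; they are not part of the library.

```python
from collections import OrderedDict


def comparison_tag(aggregation, log_scale=False):
    """Hypothetical helper: fill the tag name pattern."""
    log_part = "Log " if log_scale else ""
    return f"{aggregation.capitalize()} Output {log_part}Norm Comparison"


def comparison_payload(per_layer, aggregation="mean", log_scale=False):
    """Hypothetical helper: pick one aggregate per layer for the big plot."""
    values = {name: stats[aggregation] for name, stats in per_layer.items()}
    return OrderedDict([(comparison_tag(aggregation, log_scale), values)])


# Per-layer line aggregates for one epoch (made-up numbers).
per_layer = OrderedDict([
    ("relu1", {"mean": 0.7}),
    ("relu2", {"mean": 0.3}),
])
payload = comparison_payload(per_layer, aggregation="mean", log_scale=True)
print(payload)
```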