ParameterNorm

class monitorch.lens.ParameterNorm(inplace: bool = True, parameters: Iterable[str] = ('weight', 'bias'), normalize_by_size: bool = False, log_scale: bool = False, comparison_plot: bool = True, aggregation_method: str = 'mean')

Bases: AbstractLens

A lens that collects parameter norms.

Computes the L2 norm or the root mean square (RMS) of the selected parameters, either on an explicit call to collect_data() or on an epoch tick. The lens draws a small plot for each layer selected during initialization and can optionally draw a comparison plot across all layers.
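The relation between the two reported quantities can be sketched in plain Python, with a list standing in for a flattened parameter tensor. The helpers `l2_norm` and `rms` are hypothetical and not part of monitorch; the sketch only illustrates that the RMS is the L2 norm divided by the square root of the element count, which is what normalize_by_size toggles.

```python
import math
from typing import Sequence

def l2_norm(values: Sequence[float]) -> float:
    """L2 (Euclidean) norm: square root of the sum of squared elements."""
    return math.sqrt(sum(v * v for v in values))

def rms(values: Sequence[float]) -> float:
    """Root mean square: the L2 norm divided by sqrt(number of elements)."""
    return l2_norm(values) / math.sqrt(len(values))

# A flattened stand-in for a 2x2 weight tensor.
weight = [3.0, 4.0, 0.0, 0.0]
print(l2_norm(weight))  # 5.0  (the value reported with normalize_by_size=False)
print(rms(weight))      # 2.5  (the value reported with normalize_by_size=True)
```

For a PyTorch parameter `p`, the corresponding expressions would be `p.norm()` and `p.norm() / p.numel() ** 0.5`; this page does not state whether the lens computes them exactly this way.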

Parameters:
  • inplace (bool = True) – Flag indicating whether computation should be done in place or in memory.

  • parameters (Iterable[str] = ('weight', 'bias')) – Names of the parameters for which the norm or RMS is computed.

  • normalize_by_size (bool = False) – Flag indicating whether the parameter norm should be divided by the square root of the number of elements, yielding the RMS of the parameter.

  • log_scale (bool = False) – Flag indicating whether a logarithmic scale should be used.

  • comparison_plot (bool = True) – Flag indicating whether the big comparison plot should be drawn.

  • aggregation_method (str = 'mean') – Method used to aggregate the parameter norms collected during an epoch into the lines drawn in the plots.

Examples

Default usage is shown below.

>>> inspector = PyTorchInspector(
...     lenses=[
...         ParameterNorm(),
...     ],
...     module=mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()

To collect data more often than once per epoch, call collect_data() explicitly, for example after every optimizer step.

>>> pnorm_lens = ParameterNorm()
>>> inspector = PyTorchInspector(
...     lenses=[
...         pnorm_lens,
...     ],
...     module=mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...         pnorm_lens.collect_data()
...
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()

collect_data()

Calls gatherers to collect data.

detach_from_module()

Detaches the lens from the module.

Detaches the gatherers and resets the inner state.

finalize_epoch()

Finalizes the computations done during the epoch.

Aggregates parameter norms according to aggregation_method and computes comparison values.
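A minimal sketch of this aggregation step, assuming the lens keeps a list of norms per layer for the current epoch and that aggregation_method names a statistic. `finalize_epoch_sketch` and the `AGGREGATORS` table are hypothetical; only 'mean' is documented above, and 'median' is an illustrative extra. Each layer ends up with a `{method: value}` point paired with a `{'min': ..., 'max': ...}` range, like the per-layer entries shown under vizualize().

```python
import statistics
from collections import OrderedDict
from typing import Callable, Dict, List, Tuple

# Hypothetical mapping from aggregation_method names to statistics.
# Only 'mean' is documented for ParameterNorm; 'median' is an illustrative extra.
AGGREGATORS: Dict[str, Callable[[List[float]], float]] = {
    "mean": statistics.fmean,
    "median": statistics.median,
}

def finalize_epoch_sketch(
    samples: Dict[str, List[float]], method: str = "mean"
) -> Dict[str, Tuple[Dict[str, float], Dict[str, float]]]:
    """Collapse each layer's per-step norms into one aggregated point plus a min/max range."""
    aggregate = AGGREGATORS[method]
    result: Dict[str, Tuple[Dict[str, float], Dict[str, float]]] = OrderedDict()
    for layer, values in samples.items():
        point = {method: aggregate(values)}
        value_range = {"min": min(values), "max": max(values)}
        result[layer] = (point, value_range)
    return result

# Norms gathered by three collect_data() calls during one epoch (made-up numbers).
epoch_samples = {"lin1": [1.0, 2.0, 3.0], "lin2": [0.5, 0.5, 2.0]}
for layer, (point, value_range) in finalize_epoch_sketch(epoch_samples).items():
    print(layer, point, value_range)
```

Keeping the min/max range next to the aggregated value would let a plot draw a band around each line, which matches the shape of the per-layer dictionary passed to the visualizer.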

introduce_tags(vizualizer: AbstractVisualizer)

Introduces the lens’s plots to the visualizer.

For every parameter listed during initialization, creates a small numerical plot ‘#PARAMETER_NAME Norm’ and, optionally, a big comparison plot ‘#PARAMETER_NAME [Log] Norm Comparison’.
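The resulting tag names can be sketched as follows. This assumes the parameter name is capitalized (as in the ‘Weight Norm Comparison’ key shown under vizualize()) and that the optional ‘Log’ word appears only when log_scale is set; `tag_names` is a hypothetical helper, not part of monitorch.

```python
from typing import Iterable, List

def tag_names(
    parameters: Iterable[str], comparison_plot: bool = True, log_scale: bool = False
) -> List[str]:
    """List the plot tags a ParameterNorm-like lens would introduce (assumed naming)."""
    tags = []
    for name in parameters:
        title = name.capitalize()
        tags.append(f"{title} Norm")  # small per-parameter plot
        if comparison_plot:
            log = " Log" if log_scale else ""
            tags.append(f"{title}{log} Norm Comparison")  # big comparison plot
    return tags

print(tag_names(("weight", "bias"), log_scale=True))
# ['Weight Norm', 'Weight Log Norm Comparison', 'Bias Norm', 'Bias Log Norm Comparison']
```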

Parameters:
  • vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.

register_foreign_preprocessor(ext_ppr: AbstractPreprocessor, inspector_state)

Does not interact with foreign preprocessor.

register_leaf_module(module: Module, module_name: str, inspector_state)

Registers (or ignores) a module.

Registers any module that has all of the parameters listed during initialization.

Parameters:
  • module (torch.nn.Module) – The module object to hook gatherers onto.

  • module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
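The selection rule "has all of the parameters listed during initialization" can be sketched with simple stand-in objects instead of real torch modules. `should_register` is a hypothetical helper; treating a parameter that is set to None as missing is an assumption, motivated by the fact that in PyTorch a layer built with bias=False (for example `nn.Linear(4, 2, bias=False)`) still has a `bias` attribute, whose value is None.

```python
from types import SimpleNamespace
from typing import Iterable

def should_register(module: object, parameters: Iterable[str]) -> bool:
    """Return True only if every requested parameter exists on the module and is not None."""
    return all(getattr(module, name, None) is not None for name in parameters)

# Stand-ins for torch modules: a Linear layer with and without bias, and an activation.
linear = SimpleNamespace(weight=[[0.1, 0.2]], bias=[0.0])
linear_no_bias = SimpleNamespace(weight=[[0.1, 0.2]], bias=None)
relu = SimpleNamespace()

print(should_register(linear, ("weight", "bias")))          # True
print(should_register(linear_no_bias, ("weight", "bias")))  # False
print(should_register(relu, ("weight",)))                   # False
```

Under this reading, the default parameters=('weight', 'bias') would skip bias-free layers, while parameters=('weight',) would include them.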

register_non_leaf_module(module: Module, module_name: str, inspector_state)

Registers (or ignores) a module.

Registers any module that has all of the parameters listed during initialization.

Parameters:
  • module (torch.nn.Module) – The module object to hook gatherers onto.

  • module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.

reset_epoch()

Resets the inner state.

Resets the data computed during the last epoch and resets the preprocessors.

vizualize(vizualizer: AbstractVisualizer, epoch: int)

Passes the computed data to the visualizer.

Passes a dictionary of per-layer data to ‘#PARAMETER_NAME Norm’; the dictionary may look something like this:

OrderedDict([
    ('lin1',   ({'mean' : 0.8}, {'min' : 0.2, 'max' : 0.9})),
    ('lin2',   ({'mean' : 0.6}, {'min' : 0.3, 'max' : 0.7})),
])

If the comparison plot is enabled, passes a dictionary like the one below to ‘#PARAMETER_NAME [Log] Norm Comparison’:

OrderedDict([
    ('Weight Norm Comparison', {
        'lin1' : 0.7,
        'lin2' : 0.3
    })
])

Parameters:
  • vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.

  • epoch (int) – Computation’s epoch number.