OutputNorm
- class monitorch.lens.OutputNorm(inplace: bool = True, record_eval: bool = False, evaluation_from_grad: bool = False, normalize_by_size: bool = False, log_scale: bool = False, activation: bool = True, include: Iterable[type[Module]] = (), exclude: Iterable[type[Module]] = (), channel_last: bool = False, comparison_plot: bool = True, comparison_aggregation: str | None = None, line_aggregation: str | Iterable[str] = 'mean', range_aggregation: str | Iterable[str] | None = ('std', 'min-max'))
Bases: AbstractLens
A lens that examines the norms of layer outputs.
Computes the L2 norm or root mean square (RMS) of the outputs produced during each module’s forward pass. The lens draws a small plot for each layer selected during initialization and can optionally draw a comparison plot across all layers.
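As a rough illustration of what normalize_by_size changes, the pure-Python sketch below computes both quantities for a single flattened output x with n elements: the L2 norm ||x||_2 = sqrt(sum of x_i^2), and the RMS, ||x||_2 / sqrt(n). The function name output_norm is hypothetical, and the sketch assumes the norm is taken over all elements of one output; how the lens actually reduces batched tensors (per sample or per batch, and along which dimensions) is not specified here.

```python
import math
from typing import Sequence


def output_norm(values: Sequence[float], normalize_by_size: bool = False) -> float:
    """Return the L2 norm of a flattened output, or its RMS if requested."""
    # L2 norm: square root of the sum of squared elements.
    norm = math.sqrt(sum(v * v for v in values))
    if normalize_by_size:
        # Dividing by sqrt(n) turns the L2 norm into the root mean square.
        norm /= math.sqrt(len(values))
    return norm


# A flattened output with 4 elements.
out = [3.0, 4.0, 0.0, 0.0]
print(output_norm(out))                          # 5.0
print(output_norm(out, normalize_by_size=True))  # 2.5
```

The RMS form makes layers with very different output sizes easier to compare, since the raw L2 norm grows with the number of elements.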
- Parameters:
inplace (bool = True) – Flag indicating if computation should be done in-place or in-memory.
record_eval (bool = False) – Flag indicating whether data collected during evaluation should be recorded. If False (default), it is ignored.
evaluation_from_grad (bool = False) – Flag indicating if evaluation should be decided from torch.is_grad_enabled() or module.training.
normalize_by_size (bool = False) – Flag indicating whether the output norm should be divided by the square root of the number of elements, yielding the RMS of the output.
log_scale (bool = False) – Flag indicating if logarithmic scale should be used.
activation (bool = True) – Flag indicating if activation function layers’ data should be collected and displayed.
include (Iterable[Type[Module]] = tuple()) – Additional layer types to include for inspection.
exclude (Iterable[Type[Module]] = tuple()) – Layer types to exclude from inspection. Overrides all other settings.
channel_last (bool = False) – If True, expects layer outputs in [batch, seq_len, ..., features] format, where the feature/channel dimension is last (e.g. transformer outputs). If False (default), expects PyTorch’s standard [batch, features, spatial_dims, ...] format.
comparison_plot (bool = True) – Flag indicating whether a large comparison plot across all layers should be drawn.
comparison_aggregation (str | None = None) – Epoch-level aggregation applied to every layer’s output sequence for the comparison plot. Defaults to the first entry of line_aggregation.
line_aggregation (str | Iterable[str] = 'mean') – Aggregation method for lines in plots.
range_aggregation (str|Iterable[str]|None = ('std', 'min-max')) – Aggregation method for bands in plots.
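The default for comparison_aggregation can be sketched as follows, assuming a single string passed as line_aggregation is treated like a one-element collection. resolve_aggregations is a hypothetical helper, not part of monitorch, and 'median' and 'max' are placeholder method names; the accepted names beyond 'mean' are not listed here.

```python
from typing import Iterable, Optional, Tuple, Union


def resolve_aggregations(
    line_aggregation: Union[str, Iterable[str]] = "mean",
    comparison_aggregation: Optional[str] = None,
) -> Tuple[Tuple[str, ...], str]:
    """Normalize line_aggregation and fill in the comparison default."""
    # A single string is treated as a one-element collection.
    if isinstance(line_aggregation, str):
        lines: Tuple[str, ...] = (line_aggregation,)
    else:
        lines = tuple(line_aggregation)
    if not lines:
        raise ValueError("line_aggregation must name at least one method")
    # comparison_aggregation defaults to the first line aggregation.
    if comparison_aggregation is None:
        comparison_aggregation = lines[0]
    return lines, comparison_aggregation


print(resolve_aggregations())                    # (('mean',), 'mean')
print(resolve_aggregations(["median", "mean"]))  # (('median', 'mean'), 'median')
print(resolve_aggregations("mean", "max"))       # (('mean',), 'max')
```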
Examples
Default usage with no-grad validation
>>> inspector = PyTorchInspector(
...     lenses=[
...         OutputNorm(),
...     ],
...     module=mynet,
...     visualizer='matplotlib',
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     with torch.no_grad():  # outputs inside this block are not recorded
...         for data, label in val_dataloader:
...             prediction = mynet(data)
...             loss = loss_fn(prediction, label)
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
- finalize_epoch()
Finalizes computations done during the epoch.
Aggregates the output norms collected during the epoch according to line_aggregation and range_aggregation.
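The following sketch shows one plausible way to reduce a layer’s per-step norms to line values and band bounds. It is an assumption-laden illustration: aggregate_epoch is hypothetical, only 'mean' is supported as a line aggregation, and the 'std' band is assumed to span the mean plus or minus one population standard deviation, stored under the made-up keys '-std' and '+std'. The 'min-max' band uses the 'min' and 'max' keys that appear in the vizualize example.

```python
import statistics
from typing import Dict, Iterable, List, Tuple


def aggregate_epoch(
    values: List[float],
    line_aggregation: Iterable[str] = ("mean",),
    range_aggregation: Iterable[str] = ("std", "min-max"),
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Reduce one layer's per-step norms to line values and band bounds."""
    lines: Dict[str, float] = {}
    for method in line_aggregation:
        if method != "mean":
            raise ValueError(f"unsupported line aggregation: {method}")
        lines["mean"] = statistics.fmean(values)

    ranges: Dict[str, float] = {}
    for method in range_aggregation:
        if method == "min-max":
            ranges["min"] = min(values)
            ranges["max"] = max(values)
        elif method == "std":
            # Assumed band: mean minus/plus one population standard deviation.
            mean = statistics.fmean(values)
            std = statistics.pstdev(values)
            ranges["-std"] = mean - std
            ranges["+std"] = mean + std
        else:
            raise ValueError(f"unsupported range aggregation: {method}")
    return lines, ranges


lines, ranges = aggregate_epoch([1.0, 2.0, 3.0])
print(lines)                         # {'mean': 2.0}
print(ranges["min"], ranges["max"])  # 1.0 3.0
```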
- introduce_tags(vizualizer: AbstractVisualizer)
Introduces the lens’s plots to the visualizer.
Introduces one small plot, ‘Output Norm’, where per-layer data is plotted; its type is NUMERICAL. If comparison_plot is True, it also registers a large RELATIONS plot, ‘#AGGREGATION_METHOD Output [Log] Norm Comparison’, whose title is adjusted by the initialization parameters.
- Parameters:
vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.
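Judging from the placeholder title and the ‘Mean Output Log Norm Comparison’ example given for vizualize, the comparison plot title appears to combine the capitalized comparison aggregation with an optional ‘Log’. The sketch below reproduces only that pattern; comparison_tag is a hypothetical helper, and any other way the initialization parameters adjust the title is not modeled.

```python
def comparison_tag(aggregation: str, log_scale: bool) -> str:
    """Build a comparison plot title following the documented pattern."""
    # '#AGGREGATION_METHOD' becomes the capitalized method name, e.g. 'Mean'.
    parts = [aggregation.capitalize(), "Output"]
    # '[Log]' is present only when logarithmic scale is enabled.
    if log_scale:
        parts.append("Log")
    parts += ["Norm", "Comparison"]
    return " ".join(parts)


print(comparison_tag("mean", log_scale=True))   # Mean Output Log Norm Comparison
print(comparison_tag("mean", log_scale=False))  # Mean Output Norm Comparison
```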
- register_foreign_preprocessor(_: AbstractPreprocessor, inspector_state)
Does not interact with foreign preprocessor.
- register_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) a module.
Registers modules according to the activation flag set during initialization, and includes all modules whose types are listed in include. Exclusion via the exclude parameter overrides every other setting.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information is passed to the visualizer under this name.
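The selection rules for register_leaf_module can be summarized as a small predicate. This is a sketch under assumptions: should_register is a hypothetical name, and the stand-in classes ReLU, Linear, and Dropout replace real torch.nn layer types so the example runs without PyTorch. Which classes the lens actually treats as activation layers is not listed here; ACTIVATION_TYPES is an illustrative placeholder.

```python
from typing import Iterable, Tuple


# Stand-ins for torch.nn layer classes (assumption: used only so the sketch
# runs without PyTorch installed).
class ReLU: ...
class Linear: ...
class Dropout: ...


# Placeholder for the set of types treated as activation layers.
ACTIVATION_TYPES: Tuple[type, ...] = (ReLU,)


def should_register(
    module: object,
    activation: bool = True,
    include: Iterable[type] = (),
    exclude: Iterable[type] = (),
) -> bool:
    """Decide whether a leaf module's outputs should be inspected."""
    # exclude overrides every other setting.
    if isinstance(module, tuple(exclude)):
        return False
    # Types listed in include are always inspected.
    if isinstance(module, tuple(include)):
        return True
    # Otherwise only activation layers qualify, and only if the flag is set.
    return activation and isinstance(module, ACTIVATION_TYPES)


print(should_register(ReLU()))                      # True
print(should_register(ReLU(), activation=False))    # False
print(should_register(Linear(), include=[Linear]))  # True
print(should_register(ReLU(), exclude=[ReLU]))      # False
```

Checking exclude first is what makes it override both include and the activation flag.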
- reset_epoch()
Resets inner state.
Resets the data computed during the last epoch and resets the preprocessors.
- vizualize(vizualizer: AbstractVisualizer, epoch: int)
Passes computed data to visualizer.
Passes a dictionary of per-layer data to ‘Output Norm’; for example:
OrderedDict([
    ('relu1', ({'mean': 0.8}, {'min': 0.2, 'max': 0.9})),
    ('relu2', ({'mean': 0.6}, {'min': 0.3, 'max': 0.7})),
])
If comparison_plot is True, it also passes a dictionary like the following to ‘#AGGREGATION_METHOD Output [Log] Norm Comparison’:
OrderedDict([
    ('Mean Output Log Norm Comparison', {'relu1': 0.7, 'relu2': 0.3}),
])
- Parameters:
vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.
epoch (int) – Epoch number of the computation.
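The comparison payload passed by vizualize can be sketched as reducing every layer’s recorded sequence with the comparison aggregation and grouping the results under one title. comparison_payload and its reduce argument are hypothetical; the sketch uses statistics.fmean as the ‘mean’ aggregation and leaves out log scaling.

```python
from collections import OrderedDict
from statistics import fmean
from typing import Callable, Dict, List


def comparison_payload(
    epoch_norms: Dict[str, List[float]],
    title: str,
    reduce: Callable[[List[float]], float] = fmean,
) -> "OrderedDict[str, Dict[str, float]]":
    """Group one aggregated value per layer under a single comparison title."""
    # Layer order follows the insertion order of epoch_norms.
    per_layer = {name: reduce(values) for name, values in epoch_norms.items()}
    return OrderedDict([(title, per_layer)])


payload = comparison_payload(
    {"relu1": [0.25, 0.75], "relu2": [0.125, 0.375]},
    "Mean Output Norm Comparison",
)
print(payload["Mean Output Norm Comparison"])  # {'relu1': 0.5, 'relu2': 0.25}
```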