ParameterNorm
- class monitorch.lens.ParameterNorm(inplace: bool = True, parameters: Iterable[str] = ('weight', 'bias'), normalize_by_size: bool = False, log_scale: bool = False, comparison_plot: bool = True, aggregation_method: str = 'mean')
Bases: AbstractLens
A lens to collect parameter norms.
Computes the L2 norm or root mean square (RMS) of the selected parameters on an explicit lens call or on an epoch tick. The lens draws a small plot for each layer selected during initialization and optionally draws a comparison plot across all layers.
- Parameters:
inplace (bool = True) – Flag indicating whether computation should be done in-place or in-memory.
parameters (Iterable[str] = ('weight', 'bias')) – Names of the parameters for which the norm or RMS will be computed.
normalize_by_size (bool = False) – Flag indicating whether the parameter norm should be divided by the square root of the number of elements, which yields the RMS of the parameter.
log_scale (bool = False) – Flag indicating whether a logarithmic scale should be used.
comparison_plot (bool = True) – Flag indicating whether the big comparison plot should be drawn.
aggregation_method (str = 'mean') – Aggregation method for lines in plots.
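The effect of normalize_by_size can be illustrated with a small sketch. It uses plain Python rather than the library's own code, and the helper names l2_norm and rms are hypothetical: dividing the L2 norm by the square root of the element count gives the root mean square.

```python
import math


def l2_norm(values):
    """Euclidean (L2) norm of a flat list of numbers (hypothetical helper)."""
    return math.sqrt(sum(v * v for v in values))


def rms(values):
    """Root mean square: the L2 norm divided by sqrt(number of elements)."""
    return l2_norm(values) / math.sqrt(len(values))


# A flattened toy "weight" with four elements.
weights = [3.0, 4.0, 0.0, 0.0]

print(l2_norm(weights))  # 5.0
print(rms(weights))      # 2.5
```

Because the RMS does not grow with the number of elements, it is usually easier to compare between layers of different sizes than the raw norm.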
Examples
Default usage is shown below.
>>> inspector = PyTorchInspector(
...     lenses = [
...         ParameterNorm(),
...     ],
...     module = mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
To collect data more often, use collect_data().
>>> pnorm_lens = ParameterNorm()
>>> inspector = PyTorchInspector(
...     lenses = [
...         pnorm_lens,
...     ],
...     module = mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...         pnorm_lens.collect_data()
...
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
- finalize_epoch()
Finalizes computations done throughout the epoch.
Aggregates parameter norms according to aggregation_method and computes comparison values.
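As an illustration of per-epoch aggregation, the sketch below reduces the norms collected during one epoch to a single value plus a min/max range, mirroring the shape of the example data listed under vizualize(). It is not the library's implementation: the helper aggregate_epoch is hypothetical, and only 'mean', the documented default, is handled.

```python
from statistics import mean


def aggregate_epoch(values, method="mean"):
    """Reduce the norms collected during one epoch (hypothetical helper)."""
    if not values:
        raise ValueError("no values were collected during the epoch")
    if method != "mean":
        # Other method names are not documented here, so this sketch rejects them.
        raise ValueError(f"unsupported aggregation method: {method!r}")
    main_line = {"mean": mean(values)}
    value_range = {"min": min(values), "max": max(values)}
    return main_line, value_range


# Norms gathered by three collect_data() calls within one epoch (made-up numbers).
collected = [1.0, 2.0, 3.0]
main_line, value_range = aggregate_epoch(collected)
print(main_line, value_range)
```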
- introduce_tags(vizualizer: AbstractVisualizer)
Introduces lens’s plots to visualizer.
For every parameter listed during initialization, creates a small numerical plot ‘#PARAMETER_NAME Norm’ and optionally a big comparison plot ‘#PARAMETER_NAME [Log] Norm Comparison’.
- Parameters:
vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.
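The tag naming pattern above can be sketched as follows. The helpers small_plot_tag and comparison_tag are hypothetical, and the capitalization of the parameter name is an assumption inferred from the 'Weight Norm Comparison' example; the library may build its tags differently.

```python
def small_plot_tag(parameter_name):
    """Tag of the small per-parameter plot, e.g. 'Weight Norm' (assumed format)."""
    return f"{parameter_name.capitalize()} Norm"


def comparison_tag(parameter_name, log_scale=False):
    """Tag of the comparison plot; 'Log' is included only when log_scale is set."""
    log_part = "Log " if log_scale else ""
    return f"{parameter_name.capitalize()} {log_part}Norm Comparison"


for name in ("weight", "bias"):
    print(small_plot_tag(name), "|", comparison_tag(name, log_scale=True))
```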
- register_foreign_preprocessor(ext_ppr: AbstractPreprocessor, inspector_state)
Does not interact with foreign preprocessor.
- register_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) module.
Registers any module that has all of the parameters listed during initialization.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
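The registration rule ("has all of the parameters listed during initialization") can be sketched as a simple membership check. This is an assumption about the behaviour, not the library's code: has_all_parameters is a hypothetical helper, and FakeLinear is a stand-in that only mimics the named_parameters method of torch.nn.Module so the example runs without PyTorch.

```python
class FakeLinear:
    """Stand-in for a layer; not a real torch.nn.Module."""

    def __init__(self, with_bias=True):
        self._params = {"weight": [[0.1, 0.2]]}
        if with_bias:
            self._params["bias"] = [0.0]

    def named_parameters(self, recurse=True):
        # Mimics the (name, parameter) pairs yielded by torch.nn.Module.named_parameters.
        return iter(self._params.items())


def has_all_parameters(module, parameter_names):
    """True if the module itself owns every listed parameter (hypothetical check)."""
    own = {name for name, _ in module.named_parameters(recurse=False)}
    return all(name in own for name in parameter_names)


print(has_all_parameters(FakeLinear(), ("weight", "bias")))                 # True
print(has_all_parameters(FakeLinear(with_bias=False), ("weight", "bias")))  # False
print(has_all_parameters(FakeLinear(with_bias=False), ("weight",)))         # True
```

Under this reading, a layer created without a bias would be ignored by the default parameters=('weight', 'bias') and picked up only if parameters is narrowed to ('weight',).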
- register_non_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) module.
Registers any module that has all of the parameters listed during initialization.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
- reset_epoch()
Resets inner state.
Resets data computed during the last epoch and resets preprocessors.
- vizualize(vizualizer: AbstractVisualizer, epoch: int)
Passes computed data to visualizer.
Passes a dictionary of per-layer data to ‘#PARAMETER_NAME Norm’; the dictionary may look something like this.
OrderedDict([
    ('lin1', {'mean' : 0.8}, {'min' : 0.2, 'max' : 0.9}),
    ('lin2', {'mean' : 0.6}, {'min' : 0.3, 'max' : 0.7}),
])
If a comparison plot is to be drawn, passes a dictionary like the one below to ‘#PARAMETER_NAME [Log] Norm Comparison’.
OrderedDict([
    ('Weight Norm Comparison', {
        'lin1' : 0.7,
        'lin2' : 0.3
    })
])
- Parameters:
vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.
epoch (int) – Computation’s epoch number.