OutputGradientGeometry
- class monitorch.lens.OutputGradientGeometry(inplace: bool = True, normalize_by_size: bool = False, log_scale: bool = False, compute_correlation: bool = True, skip_activation: bool = True, line_aggregation: str | Iterable[str] = 'mean', range_aggregation: str | Iterable[str] | None = ('std', 'min-max'))
Bases: AbstractLens
Lens to examine the geometry of gradients with respect to layer outputs.
Computes the L2-norm or root-mean-square of the output gradient on every backward pass through the layer. Optionally computes the correlation between gradients from two consecutive backward passes.
Computing the correlation requires the gradients from both passes, so each gradient is kept in memory after its computation finishes. Memory consumption therefore grows linearly with the size of the observed outputs.
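The per-pass computation described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not monitorch's implementation: gradients are represented as flattened lists of floats, `OutputGradientTracker` and its methods are hypothetical names, and "correlation" is assumed to mean the cosine similarity of the two flattened gradients, which may differ from the definition the library uses.

```python
import math
from typing import List, Optional, Sequence


def l2_norm(grad: Sequence[float]) -> float:
    """Euclidean norm of a flattened gradient."""
    return math.sqrt(sum(g * g for g in grad))


def cosine_correlation(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two flattened gradients (assumed definition)."""
    denom = l2_norm(a) * l2_norm(b)
    if denom == 0.0:
        # Assumption: report 0 when either gradient is identically zero.
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / denom


class OutputGradientTracker:
    """Hypothetical per-layer tracker; not part of monitorch."""

    def __init__(self, compute_correlation: bool = True) -> None:
        self.compute_correlation = compute_correlation
        self.norms: List[float] = []
        self.correlations: List[float] = []
        # The stored copy of the previous gradient is what makes memory
        # use grow linearly with the size of the layer output.
        self._previous: Optional[List[float]] = None

    def on_backward(self, grad: Sequence[float]) -> None:
        """Record statistics for one backward pass through the layer."""
        self.norms.append(l2_norm(grad))
        if self.compute_correlation:
            if self._previous is not None:
                self.correlations.append(cosine_correlation(self._previous, grad))
            # Keep this pass's gradient for comparison with the next pass.
            self._previous = list(grad)
```

Feeding the tracker `[3.0, 4.0]` and then `[6.0, 8.0]` records norms `5.0` and `10.0` and a single correlation of `1.0`, since the second gradient is a scaled copy of the first. With `compute_correlation=False` nothing is stored between passes, which is the memory saving the flag buys.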
- Parameters:
inplace (bool = True) – Flag indicating if computation should be done in-place or in-memory.
normalize_by_size (bool = False) – Flag indicating if the gradient norm should be divided by the square root of the number of elements, thus obtaining the RMS of the gradient.
log_scale (bool = False) – Flag indicating if logarithmic scale should be used.
compute_correlation (bool = True) – Flag indicating if correlation between gradients from consecutive backward passes should be computed.
skip_activation (bool = True) – Flag indicating if the lens should skip activation layers, i.e. NOT register them.
line_aggregation (str|Iterable[str] = 'mean') – Aggregation method for lines in plots.
range_aggregation (str|Iterable[str]|None = ('std', 'min-max')) – Aggregation method for bands in plots.
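A hedged sketch of how the numeric flags and the aggregation arguments could be interpreted. The helper names `scaled_norm` and `as_method_tuple` are hypothetical; the sketch assumes a base-10 logarithm for `log_scale` (the library may use another base) and that a single string passed as `line_aggregation` or `range_aggregation` behaves like a one-element collection, with `None` meaning no bands.

```python
import math
from typing import Iterable, Sequence, Tuple, Union


def scaled_norm(grad: Sequence[float], normalize_by_size: bool = False,
                log_scale: bool = False) -> float:
    """L2-norm of a flattened gradient, optionally as RMS and/or in log scale."""
    norm = math.sqrt(sum(g * g for g in grad))
    if normalize_by_size:
        # Dividing the L2-norm by sqrt(n) yields the root-mean-square.
        norm /= math.sqrt(len(grad))
    if log_scale:
        # Assumption: base-10 logarithm; undefined for an all-zero gradient.
        norm = math.log10(norm)
    return norm


def as_method_tuple(spec: Union[str, Iterable[str], None]) -> Tuple[str, ...]:
    """Normalize a str | Iterable[str] | None aggregation argument to a tuple."""
    if spec is None:
        return ()
    if isinstance(spec, str):
        # A bare string must not be iterated character by character.
        return (spec,)
    return tuple(spec)
```

For example, `scaled_norm([3.0, 4.0], normalize_by_size=True)` returns `sqrt(12.5)`, the RMS of the two entries, and `as_method_tuple('mean')` returns `('mean',)`.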
Examples
Default usage is shown below.
>>> inspector = PyTorchInspector(
...     lenses = [
...         OutputGradientGeometry(),
...     ],
...     module = mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
- finalize_epoch()
Finalizes the computations done during the epoch.
Aggregates output gradient norms and, optionally, correlations according to line_aggregation and range_aggregation.
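The end-of-epoch aggregation can be sketched as reducing one layer's per-pass values to a line dictionary and a range dictionary. The method names 'mean', 'std' and 'min-max' come from the defaults in the signature; support for 'median', the use of the population standard deviation, and the `'std'` output key are assumptions, and `aggregate_epoch` is a hypothetical helper rather than monitorch code.

```python
import statistics
from typing import Dict, Iterable, Sequence, Tuple

# Line aggregations; 'median' is an assumed extra option.
LINE_METHODS = {
    "mean": statistics.mean,
    "median": statistics.median,
}


def aggregate_epoch(values: Sequence[float],
                    line_methods: Iterable[str] = ("mean",),
                    range_methods: Iterable[str] = ("std", "min-max"),
                    ) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Reduce one layer's per-pass values to (line dict, range dict)."""
    line = {name: float(LINE_METHODS[name](values)) for name in line_methods}
    band: Dict[str, float] = {}
    for name in range_methods:
        if name == "min-max":
            band["min"] = min(values)
            band["max"] = max(values)
        elif name == "std":
            # Assumption: population standard deviation under an 'std' key.
            band["std"] = statistics.pstdev(values)
        else:
            raise ValueError(f"unknown range aggregation: {name!r}")
    return line, band
```

With per-pass norms `[1.0, 2.0, 3.0, 6.0]`, the line dictionary is `{'mean': 3.0}` and the range dictionary holds `min` 1.0, `max` 6.0 and a standard deviation of `sqrt(3.5)`.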
- introduce_tags(vizualizer: AbstractVisualizer)
Introduces the lens’s plots to the visualizer.
Registers a small numerical plot ‘Output Gradient Norm’ and, if compute_correlation is True, a small numerical plot ‘Output Gradient Correlation’.
- Parameters:
vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.
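To illustrate the conditional registration, the sketch below uses a hypothetical `RecordingVisualizer` whose `register_tags` method and attribute keys are invented for this example; they are not the `AbstractVisualizer` interface. It assumes the correlation plot is registered exactly when `compute_correlation` is `True`.

```python
from typing import Dict, List


class RecordingVisualizer:
    """Hypothetical stand-in that records which plots were registered."""

    def __init__(self) -> None:
        self.tags: Dict[str, Dict[str, str]] = {}

    def register_tags(self, name: str, attributes: Dict[str, str]) -> None:
        self.tags[name] = attributes


def introduce_output_gradient_tags(visualizer: RecordingVisualizer,
                                   compute_correlation: bool) -> List[str]:
    """Register the norm plot and, optionally, the correlation plot."""
    names = ["Output Gradient Norm"]
    if compute_correlation:
        names.append("Output Gradient Correlation")
    for name in names:
        # 'small' and 'numerical' mirror the plot kind named in the docs;
        # the dictionary keys themselves are invented.
        visualizer.register_tags(name, {"size": "small", "type": "numerical"})
    return names
```

Calling it with `compute_correlation=False` registers only ‘Output Gradient Norm’, so a visualizer never allocates a correlation panel that would stay empty.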
- register_foreign_preprocessor(ext_ppr: AbstractPreprocessor, inspector_state)
Does not interact with foreign preprocessor.
- register_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) module.
If skip_activation is True, activation modules are not registered; otherwise every module is registered.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
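The `skip_activation` check could be sketched as a filter on module types. Matching by class name and the particular set of activation names below are assumptions made for illustration; monitorch's own test for an activation module may differ. Plain Python classes stand in for `torch.nn.Linear` and `torch.nn.ReLU` so the sketch runs without PyTorch.

```python
from typing import Dict, List

# Assumed set of activation class names; the library's list may differ.
ACTIVATION_NAMES = {"ReLU", "LeakyReLU", "GELU", "Sigmoid", "Tanh", "Softmax"}


def should_register(module: object, skip_activation: bool = True) -> bool:
    """Return False for activation modules when skip_activation is True."""
    is_activation = type(module).__name__ in ACTIVATION_NAMES
    return not (skip_activation and is_activation)


class Linear:
    """Stand-in for torch.nn.Linear."""


class ReLU:
    """Stand-in for torch.nn.ReLU."""


def registered_names(modules: Dict[str, object],
                     skip_activation: bool = True) -> List[str]:
    """Names of the modules a lens would hook, in insertion order."""
    return [name for name, m in modules.items()
            if should_register(m, skip_activation)]
```

For `{'lin1': Linear(), 'relu1': ReLU()}`, only `'lin1'` is registered by default, while `skip_activation=False` registers both.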
- reset_epoch()
Resets the inner state.
Resets the data computed during the last epoch and resets the preprocessors.
- vizualize(vizualizer: AbstractVisualizer, epoch: int)
Passes computed data to visualizer.
Passes a dictionary of per-layer data to ‘Output Gradient Norm’; the dictionary may look something like this:
OrderedDict([
    ('lin1', ({'mean': 0.8}, {'min': 0.2, 'max': 0.9})),
    ('relu1', ({'mean': 0.6}, {'min': 0.3, 'max': 0.7})),
])
The gradient correlation dictionary has the same structure.
- Parameters:
vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.
epoch (int) – Epoch number of the computation.
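A sketch of how a consumer might read the per-layer dictionary, assuming each layer name maps to a pair of (line values, range values). The `describe` function is hypothetical and only formats the data as text; it does not reproduce how a real visualizer draws plots.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple

# (line values, range values) for one layer, e.g. ({'mean': ...}, {'min': ..., 'max': ...}).
LayerData = Tuple[Dict[str, float], Dict[str, float]]


def describe(data: "OrderedDict[str, LayerData]") -> List[str]:
    """Format each layer's line and band values as one text line."""
    lines = []
    for layer, (line, band) in data.items():
        line_part = ", ".join(f"{k}={v:.2f}" for k, v in line.items())
        band_part = ", ".join(f"{k}={v:.2f}" for k, v in band.items())
        lines.append(f"{layer}: {line_part} | {band_part}")
    return lines


norms = OrderedDict([
    ("lin1", ({"mean": 0.8}, {"min": 0.2, "max": 0.9})),
    ("relu1", ({"mean": 0.6}, {"min": 0.3, "max": 0.7})),
])
```

`describe(norms)` yields `'lin1: mean=0.80 | min=0.20, max=0.90'` followed by the matching line for `relu1`; the OrderedDict keeps layers in the order they were registered.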