ParameterUpdateGeometry
- class monitorch.lens.ParameterUpdateGeometry(optimizer: Optimizer | None, inplace: bool = True, normalize_by_size: bool = False, log_scale: bool = False, compute_correlation: bool = True, parameters: str | Iterable[str] = ('weight', 'bias'), line_aggregation: str | Iterable[str] = 'mean', range_aggregation: str | Iterable[str] | None = ('std', 'min-max'))
Bases: AbstractLens
Lens to examine the geometry of parameter updates.
Computes the L2 norm, or the root mean square (RMS), of parameter updates on every optimizer step call. Optionally computes the correlation between parameter updates from two consecutive steps.
Computing the correlation requires the parameter updates from both steps, so each update is saved after the computation is finished. As a result, memory consumption grows linearly with the size of the studied parameters.
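The per-step quantities described above can be sketched in plain Python. This is a minimal illustration, not the monitorch implementation: it assumes updates arrive as flat lists of floats, and it assumes that "correlation" means the cosine similarity between two consecutive update vectors. The names `update_norm`, `update_correlation` and `UpdateTracker` are hypothetical.

```python
import math
from typing import List, Optional, Sequence


def update_norm(update: Sequence[float], normalize_by_size: bool = False) -> float:
    """L2 norm of a flattened update; with normalize_by_size it becomes the RMS."""
    norm = math.sqrt(sum(x * x for x in update))
    if normalize_by_size and len(update) > 0:
        norm /= math.sqrt(len(update))  # divide by sqrt(number of elements)
    return norm


def update_correlation(prev: Sequence[float], curr: Sequence[float]) -> float:
    """Cosine similarity of two updates (an assumed reading of 'correlation')."""
    dot = sum(a * b for a, b in zip(prev, curr))
    denom = update_norm(prev) * update_norm(curr)
    # A zero update has no direction; returning 0.0 is an arbitrary convention here.
    return dot / denom if denom > 0 else 0.0


class UpdateTracker:
    """Stores the previous update so the next step can be correlated with it.

    The stored copy is as large as the tracked parameter, which is why memory
    grows linearly with parameter size when correlation is enabled.
    """

    def __init__(self) -> None:
        self.prev_update: Optional[List[float]] = None

    def step(self, update: Sequence[float]) -> Optional[float]:
        """Return the correlation with the previous update, or None on the first step."""
        corr = None
        if self.prev_update is not None:
            corr = update_correlation(self.prev_update, update)
        self.prev_update = list(update)  # keep a copy for the next step
        return corr
```

For example, `update_norm([3.0, 4.0])` is 5.0, and with `normalize_by_size=True` it is 5.0 divided by the square root of 2.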
- Parameters:
optimizer (torch.optim.Optimizer | None) – Optimizer whose step calls are used for data collection. If None is provided, the lens must be called explicitly via inspect_update().
inplace (bool = True) – Flag indicating if computation should be done in-place or in-memory.
normalize_by_size (bool = False) – Flag indicating if the output norm should be divided by the square root of the number of elements, thus yielding the RMS of the output.
log_scale (bool = False) – Flag indicating if logarithmic scale should be used.
compute_correlation (bool = True) – Flag indicating if correlation between updates from consecutive optimizer steps should be computed.
parameters (str|Iterable[str] = ('weight', 'bias')) – Parameters whose updates will be studied.
line_aggregation (str|Iterable[str] = 'mean') – Aggregation method for lines in plots.
range_aggregation (str|Iterable[str]|None = ('std', 'min-max')) – Aggregation method for bands in plots.
Examples
Default usage is shown below.
>>> optimizer = torch.optim.AdamW(mynet.parameters())
>>> inspector = PyTorchInspector(
...     lenses = [
...         ParameterUpdateGeometry(optimizer),
...     ],
...     module = mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
- detach_from_module()
Detaches the lens from the module and from the optimizer, setting the latter to None. Detaches gatherers and resets the inner state.
- finalize_epoch()
Finalizes computations done throughout the epoch.
Aggregates the norms of parameter updates, and optionally their correlation, according to line_aggregation and range_aggregation.
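The epoch-level aggregation can be pictured as reducing the per-step values collected for one layer into a line dictionary and a range dictionary, giving one (line, range) pair per layer, similar to the dictionary shown for vizualize(). This is a hypothetical sketch, not monitorch's code: it assumes 'std' means a band of mean plus or minus one population standard deviation and 'min-max' means the observed extremes. The reducer table, the 'median' option and the std-band key names are all illustrative assumptions.

```python
import statistics
from collections import OrderedDict
from typing import Dict, Iterable, List, Tuple, Union

# Hypothetical reducers for line_aggregation; only 'mean' appears in the docs.
LINE_REDUCERS = {"mean": statistics.fmean, "median": statistics.median}


def aggregate(
    values: List[float],
    line_aggregation: Union[str, Iterable[str]] = "mean",
    range_aggregation: Union[str, Iterable[str], None] = ("std", "min-max"),
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Reduce one layer's per-step values into (line, range) dictionaries."""
    if isinstance(line_aggregation, str):
        line_aggregation = (line_aggregation,)
    if isinstance(range_aggregation, str):
        range_aggregation = (range_aggregation,)
    line = {name: LINE_REDUCERS[name](values) for name in line_aggregation}
    band: Dict[str, float] = {}
    for name in range_aggregation or ():  # None disables the band
        if name == "min-max":
            band["min"], band["max"] = min(values), max(values)
        elif name == "std":
            mean, std = statistics.fmean(values), statistics.pstdev(values)
            # The key names of the std band are an assumption.
            band["mean-std"], band["mean+std"] = mean - std, mean + std
        else:
            raise ValueError(f"unknown range aggregation: {name!r}")
    return line, band


# Per-step update norms collected for two hypothetical layers during one epoch.
epoch_values = {"lin1": [0.2, 0.9, 0.8], "lin2": [0.3, 0.7, 0.6]}
report = OrderedDict((name, aggregate(vals)) for name, vals in epoch_values.items())
```

Accepting either a single string or an iterable mirrors the `str | Iterable[str]` type of the constructor arguments.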
- inspect_update()
Utility function to collect updates. Used primarily to collect data if no optimizer is provided.
Calls every gatherer without any arguments, regardless of its underlying type.
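Calling every gatherer without arguments implies that each gatherer is a stateful callable that fetches what it needs on its own. The sketch below is hypothetical: `UpdateNormGatherer`, `get_values` and `call_all` are illustrative names, not monitorch classes, and it assumes a gatherer keeps a snapshot of its parameter and records the L2 norm of the change since the previous call. Plain lists stand in for tensors so the snippet runs without PyTorch.

```python
import math
from typing import Callable, List, Sequence


class UpdateNormGatherer:
    """Hypothetical gatherer for one parameter, invoked with no arguments."""

    def __init__(self, get_values: Callable[[], Sequence[float]]) -> None:
        self.get_values = get_values                  # reads current parameter values
        self.last: List[float] = list(get_values())   # snapshot before any update
        self.norms: List[float] = []                  # L2 norm of each observed update

    def __call__(self) -> None:
        current = list(self.get_values())
        delta = [c - p for c, p in zip(current, self.last)]
        self.norms.append(math.sqrt(sum(d * d for d in delta)))
        self.last = current                           # baseline for the next call


def call_all(gatherers: Sequence[Callable[[], None]]) -> None:
    """Call every gatherer without arguments, whatever its concrete type."""
    for gatherer in gatherers:
        gatherer()


# Usage: a manual update followed by an explicit collection call.
weights = [1.0, 2.0]
gatherer = UpdateNormGatherer(lambda: weights)
weights[0] += 3.0
weights[1] += 4.0
call_all([gatherer])
```

After the call, `gatherer.norms` holds `[5.0]`, the norm of the update `[3.0, 4.0]`.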
- introduce_tags(vizualizer: AbstractVisualizer)
Introduces lens’s plots to visualizer.
For every parameter listed during initialization, creates a small numerical plot ‘#PARAMETER_NAME Update Norm’ and, if compute_correlation is enabled, a big comparison plot ‘#PARAMETER_NAME Update Correlation’.
- Parameters:
vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.
- register_foreign_preprocessor(ext_ppr: AbstractPreprocessor, inspector_state)
Does not interact with the foreign preprocessor.
- register_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) module.
Registers any module that has all of the parameters listed during initialization.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
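The registration rule ("has all of the parameters listed during initialization") can be sketched as an attribute check. This is an assumed way of writing the check, not the library's code. One PyTorch detail matters here: `torch.nn.Linear(..., bias=False)` still has a `bias` attribute whose value is None, so a plain `hasattr` test would accept it; the sketch therefore also requires the attribute to be non-None. The stand-in classes `LinearLike` and `ReLULike` are hypothetical and replace real modules so the snippet runs without PyTorch.

```python
from typing import Iterable, Union


def has_all_parameters(module: object, parameters: Union[str, Iterable[str]]) -> bool:
    """Return True if every named parameter exists on `module` and is not None."""
    if isinstance(parameters, str):
        parameters = (parameters,)
    return all(getattr(module, name, None) is not None for name in parameters)


# Stand-ins for torch.nn modules, so the example runs without PyTorch.
class LinearLike:
    def __init__(self, bias: bool = True) -> None:
        self.weight = [[0.0]]
        self.bias = [0.0] if bias else None  # mirrors nn.Linear(bias=False)


class ReLULike:
    pass  # no parameters at all
```

With the default `('weight', 'bias')`, a bias-free linear layer would be ignored, while `parameters='weight'` would register it.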
- register_non_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) module.
Registers any module that has all of the parameters listed during initialization.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
- reset_epoch()
Resets inner state.
Resets data computed during last epoch and resets preprocessors.
- vizualize(vizualizer: AbstractVisualizer, epoch: int)
Passes computed data to visualizer.
Passes a dictionary of per-layer data to ‘#PARAMETER_NAME Update Norm’; the dictionary may look something like this:
OrderedDict([
    ('lin1', ({'mean': 0.8}, {'min': 0.2, 'max': 0.9})),
    ('lin2', ({'mean': 0.6}, {'min': 0.3, 'max': 0.7})),
])
Update correlation dictionary looks the same.
- Parameters:
vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.
epoch (int) – Computation’s epoch number.