ParameterGradientActivation
- class monitorch.lens.ParameterGradientActivation(inplace: bool = True, warning_plot: bool = True, parameters: str | Iterable[str] = ('weight', 'bias'), activation_aggregation: str = 'mean', death_aggregation: str = 'mean')
Bases: AbstractLens
A lens to inspect neuron activation through parameter gradients.
A neuron is active if its gradient is non-zero. If a neuron is inactive for the whole batch iteration, we say that it is dead. The activation rate over an epoch is a measure of a layer's entropy (high activation means high entropy), while the death rate is a measure of overcapacity, because dead neurons are not used.
This lens lets you investigate both values. In addition, it can plot the worst activation and death rates across the whole model in one big warning plot.
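To make the definitions concrete, here is a minimal sketch that computes both rates for a single iteration from one weight gradient. It is not monitorch's implementation: it assumes that each row of the gradient matrix is one neuron (one output unit), that the activation rate is the fraction of non-zero gradient entries, and that a neuron is dead when its whole row is zero. The helper `step_rates` is hypothetical, and plain lists stand in for a `torch.Tensor`.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def step_rates(grad: Matrix) -> Tuple[float, float]:
    """Return (activation_rate, death_rate) for one gradient matrix.

    Hypothetical sketch: each row is one neuron, an entry is active if it is
    non-zero, and a neuron is dead if its entire row is zero.
    """
    total = sum(len(row) for row in grad)
    active = sum(1 for row in grad for g in row if g != 0.0)
    dead = sum(1 for row in grad if all(g == 0.0 for g in row))
    return active / total, dead / len(grad)


# Gradient of a 3x2 weight: the second neuron received no gradient at all.
grad = [
    [0.1, 0.0],
    [0.0, 0.0],
    [0.3, -0.2],
]
print(step_rates(grad))  # (0.5, 0.3333333333333333)
```

Three of the six entries are non-zero, and one of the three rows is entirely zero.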
- Parameters:
inplace (bool = True) – Flag indicating whether computation should be done in-place or in-memory.
parameters (str|Iterable[str] = ('weight', 'bias')) – Parameters whose gradients will be studied.
warning_plot (bool = True) – Flag indicating whether the big warning plot should be added.
activation_aggregation (str = 'mean') – Aggregation method used to collect activation rate.
death_aggregation (str = 'mean') – Aggregation method used to collect death rate.
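The aggregation arguments reduce the per-iteration rates of an epoch to a single number per layer. The sketch below assumes a simple name-to-function registry; only 'mean' is documented here, so 'median', 'min' and 'max' are hypothetical options, and `AGGREGATIONS` and `aggregate` are illustrative names rather than monitorch's API.

```python
import statistics
from typing import Callable, Dict, List

# Hypothetical registry: only 'mean' is confirmed by the documentation;
# the other names are illustrative assumptions.
AGGREGATIONS: Dict[str, Callable[[List[float]], float]] = {
    "mean": statistics.fmean,
    "median": statistics.median,
    "min": min,
    "max": max,
}


def aggregate(values: List[float], method: str = "mean") -> float:
    """Reduce the per-iteration rates of one epoch to a single number."""
    try:
        func = AGGREGATIONS[method]
    except KeyError:
        raise ValueError(f"unknown aggregation method: {method!r}") from None
    return func(values)


rates = [0.5, 0.25, 0.75, 0.5]
print(aggregate(rates))         # 0.5
print(aggregate(rates, "min"))  # 0.25
```

Looking the name up once and failing with a ValueError keeps a typo in the aggregation name from silently falling back to some default.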
Examples
Default usage is as simple as passing ParameterGradientActivation to the inspector:
>>> inspector = PyTorchInspector(
...     lenses = [
...         ParameterGradientActivation(),
...     ],
...     module = mynet,
...     visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
...     for data, label in train_dataloader:
...         optimizer.zero_grad()
...         prediction = mynet(data)
...         loss = loss_fn(prediction, label)
...         loss.backward()
...         optimizer.step()
...     inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
- finalize_epoch()
Finalizes computations done through the epoch.
Aggregates activation and death rates according to activation_aggregation and death_aggregation, and computes the worst activation rate (minimal) and the worst death rate (maximal) for every parameter.
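The two steps described for finalize_epoch can be sketched for a single parameter as follows. This assumes the lens keeps a hypothetical per-layer history of (activation_rate, death_rate) pairs, one per iteration, and uses the default 'mean' aggregation via `statistics.fmean`; `finalize` and the history layout are illustrative, not monitorch's internals.

```python
import statistics
from collections import OrderedDict
from typing import Dict, List, Tuple

# Hypothetical per-iteration history for one parameter (e.g. 'weight'):
# layer name -> [(activation_rate, death_rate), ...], one pair per iteration.
History = Dict[str, List[Tuple[float, float]]]
Rates = Dict[str, float]


def finalize(history: History) -> Tuple["OrderedDict[str, Rates]", Rates]:
    """Mean-aggregate each layer, then take the worst values across layers."""
    per_layer: "OrderedDict[str, Rates]" = OrderedDict()
    for name, steps in history.items():  # keeps the layers' insertion order
        per_layer[name] = {
            "activation_rate": statistics.fmean(a for a, _ in steps),
            "death_rate": statistics.fmean(d for _, d in steps),
        }
    # 'Worst' as documented: minimal activation rate, maximal death rate.
    worst = {
        "worst activation_rate": min(r["activation_rate"] for r in per_layer.values()),
        "worst death_rate": max(r["death_rate"] for r in per_layer.values()),
    }
    return per_layer, worst


history = {
    "lin1": [(1.0, 0.0), (0.5, 0.5)],
    "lin2": [(0.25, 0.5), (0.25, 0.5)],
}
per_layer, worst = finalize(history)
print(worst)  # {'worst activation_rate': 0.25, 'worst death_rate': 0.5}
```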
- introduce_tags(vizualizer: AbstractVisualizer)
Introduces the lens’s plots to the visualizer.
For every parameter, creates a small probability tag ‘#PARAMETER_NAME Gradient Activation’. If the warning plot is on, also adds a big warning probability plot for every parameter.
- Parameters:
vizualizer (AbstractVisualizer) – A visualizer object to pass tag attributes to.
- register_foreign_preprocessor(ext_ppr: AbstractPreprocessor, inspector_state)
Does not interact with foreign preprocessors.
- register_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) a module.
Registers any module that has all of the parameters listed during initialization.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
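The registration rule ("has all of the parameters listed during initialization") can be sketched as a set check. This is a sketch under assumptions: `normalize_parameters` and `should_register` are hypothetical helpers, and each module's own parameter names are given as a plain set (with PyTorch they could be collected from `module.named_parameters(recurse=False)`).

```python
from typing import Iterable, Set, Tuple, Union


def normalize_parameters(parameters: Union[str, Iterable[str]]) -> Tuple[str, ...]:
    """Turn the `parameters` argument (one name or many) into a tuple."""
    if isinstance(parameters, str):
        return (parameters,)
    return tuple(parameters)


def should_register(own_parameters: Set[str], parameters: Union[str, Iterable[str]]) -> bool:
    """True only if the module owns *every* requested parameter."""
    return all(name in own_parameters for name in normalize_parameters(parameters))


# With PyTorch, own_parameters could be built as
# {name for name, _ in module.named_parameters(recurse=False)}.
linear = {"weight", "bias"}   # e.g. nn.Linear(4, 2)
linear_no_bias = {"weight"}   # e.g. nn.Linear(4, 2, bias=False)
relu: Set[str] = set()        # e.g. nn.ReLU()

print(should_register(linear, ("weight", "bias")))          # True
print(should_register(linear_no_bias, ("weight", "bias")))  # False
print(should_register(linear_no_bias, "weight"))            # True
print(should_register(relu, "weight"))                      # False
```

Normalizing a bare string first matters because a `str` is itself iterable: without it, 'weight' would be checked character by character.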
- register_non_leaf_module(module: Module, module_name: str, inspector_state)
Registers (or ignores) a module.
Registers any module that has all of the parameters listed during initialization.
- Parameters:
module (torch.nn.Module) – The module object to hook gatherers onto.
module_name (str) – Name of the module; the module’s information will be passed to the visualizer under this name.
- reset_epoch()
Resets inner state.
Resets data computed during the last epoch and resets preprocessors.
- vizualize(vizualizer: AbstractVisualizer, epoch: int)
Passes computed data to the visualizer.
For every parameter listed during initialization, passes a dictionary of per-layer data to ‘#PARAMETER_NAME Gradient Activations’. The dictionary may look something like this:
OrderedDict([
    ('lin1', {'activation_rate' : 0.8, 'death_rate' : 0.3}),
    ('lin2', {'activation_rate' : 0.5, 'death_rate' : 0.3}),
])
If the warning plot is enabled, passes the dictionary below to ‘Warning #PARAMETER_NAME Gradient Activation’:
OrderedDict([
    ('Warning Weight Gradient Activation', {
        'worst activation_rate' : 0.2,
        'worst death_rate' : 0.3
    })
])
- Parameters:
vizualizer (AbstractVisualizer) – The visualizer object responsible for drawing plots.
epoch (int) – The epoch number of the computation.
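The data flow of vizualize can be sketched as a function that maps each tag to the dictionary it receives, in the shapes of the OrderedDict examples for this method. The tag names are an assumption: they follow the ‘#PARAMETER_NAME Gradient Activation’ template with a capitalized parameter name, as in the 'Warning Weight Gradient Activation' key; `build_payloads` is a hypothetical helper, not monitorch's API.

```python
from collections import OrderedDict
from typing import Dict

Rates = Dict[str, float]


def build_payloads(
    parameter: str,
    per_layer: "OrderedDict[str, Rates]",
    worst: Rates,
    warning_plot: bool = True,
) -> Dict[str, "OrderedDict[str, Rates]"]:
    """Map each (assumed) tag name to the data it would receive."""
    # Capitalization follows the documented 'Warning Weight ...' key.
    base = f"{parameter.capitalize()} Gradient Activation"
    payloads = {base: OrderedDict(per_layer)}
    if warning_plot:
        warning_tag = f"Warning {base}"
        # The warning payload is keyed by the warning tag name itself.
        payloads[warning_tag] = OrderedDict([(warning_tag, dict(worst))])
    return payloads


per_layer = OrderedDict([
    ("lin1", {"activation_rate": 0.8, "death_rate": 0.3}),
    ("lin2", {"activation_rate": 0.5, "death_rate": 0.3}),
])
worst = {"worst activation_rate": 0.5, "worst death_rate": 0.3}

for tag, data in build_payloads("weight", per_layer, worst).items():
    print(tag, dict(data))
```

Using an OrderedDict for the per-layer payload keeps the layers in the order they were registered, so plots list them consistently from epoch to epoch.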