from collections import OrderedDict
from collections.abc import Iterable
from torch.nn import Module
from monitorch.gatherer import ParameterGradientGatherer
from monitorch.numerical import extract_point
from monitorch.preprocessor import AbstractPreprocessor, GradientActivation
from monitorch.visualizer import AbstractVisualizer, TagAttributes, TagType
from .abstract_lens import AbstractLens
[docs]
class ParameterGradientActivation(AbstractLens):
    """
    A lens to inspect neuron activation through parameter gradients.

    A neuron is active if its gradient is non-zero; if a neuron is inactive for the whole batch iteration,
    we say that the neuron is dead. Activation rate in an epoch is a measure of a layer's entropy
    (high activation - high entropy), while death rate is a measure of overcapacity, because some neurons
    are not used. This lens lets you investigate those values.

    In addition, it allows plotting the worst activation and death rates across the whole model
    in one big warning plot.

    Parameters
    ----------
    inplace : bool = True
        Flag indicating if computation should be done in-place or in-memory.
    warning_plot : bool = True
        Flag indicating if the big warning plot should be added.
    parameters : str|Iterable[str] = ('weight', 'bias')
        Parameters whose gradients will be studied.
    activation_aggregation : str = 'mean'
        Aggregation method used to collect activation rate.
    death_aggregation : str = 'mean'
        Aggregation method used to collect death rate.

    Examples
    --------
    Default usage is as simple as just mentioning ``ParameterGradientActivation`` to the inspector.

    >>> inspector = PyTorchInspector(
    ...     lenses = [
    ...         ParameterGradientActivation(),
    ...     ],
    ...     module = mynet,
    ...     visualizer='matplotlib'
    ... )
    >>>
    >>> for epoch in range(N_EPOCHS):
    ...     for data, label in train_dataloader:
    ...         optimizer.zero_grad()
    ...         prediction = mynet(data)
    ...         loss = loss_fn(prediction, label)
    ...         loss.backward()
    ...         optimizer.step()
    ...     inspector.tick_epoch()
    >>>
    >>> inspector.visualizer.show_fig()
    """

    # Tag under which the combined warning data is handed to the visualizer.
    _BIG_TAG_NAME = 'Warning Gradient Activations'
def __init__(
    self,
    inplace: bool = True,
    warning_plot: bool = True,
    parameters: str | Iterable[str] = ('weight', 'bias'),
    activation_aggregation: str = 'mean',
    death_aggregation: str = 'mean',
):
    """
    Create one gradient-activation preprocessor per inspected parameter.

    Parameters
    ----------
    inplace : bool = True
        Forwarded to every ``GradientActivation`` preprocessor.
    warning_plot : bool = True
        If true, storage for the combined warning plot is allocated.
    parameters : str|Iterable[str] = ('weight', 'bias')
        Name, or names, of the parameters whose gradients are inspected.
        A single string is treated as one parameter name.
    activation_aggregation : str = 'mean'
        Aggregation method handed to ``extract_point`` for activation rates.
    death_aggregation : str = 'mean'
        Aggregation method handed to ``extract_point`` for death rates.
    """
    if isinstance(parameters, str):
        # A bare string is itself an iterable of characters; without this
        # guard 'weight' would be split into 'w', 'e', 'i', ...
        parameter_names = (parameters,)
    else:
        # Materialize once: the names are iterated twice below, and a
        # one-shot iterable (e.g. a generator) would be exhausted by then.
        parameter_names = tuple(parameters)

    self._warning_plot = warning_plot
    if warning_plot:
        self._warning_data = {}
    self._preprocessors = OrderedDict(
        (parameter, GradientActivation(inplace=inplace, death=True))
        for parameter in parameter_names
    )
    self._gatherers = []
    self._data: dict[str, OrderedDict[str, dict[str, float]]] = {
        parameter_name: OrderedDict() for parameter_name in parameter_names
    }
    self._activation_aggregation = activation_aggregation
    self._death_aggregation = death_aggregation
[docs]
def register_leaf_module(self, module: Module, module_name: str, inspector_state):
    """
    Offer a leaf module to the lens.

    The call is forwarded unchanged to :meth:`_register_module`, which
    decides whether the module is hooked or ignored.

    Parameters
    ----------
    module : torch.nn.Module
        The module object to hook gatherers onto.
    module_name : str
        Name under which the module's information is reported.
    inspector_state
        Inspector state object, forwarded as is.
    """
    register = self._register_module
    register(module, module_name, inspector_state)
[docs]
def register_non_leaf_module(self, module: Module, module_name: str, inspector_state):
    """
    Offer a non-leaf module to the lens.

    The call is forwarded unchanged to :meth:`_register_module`, which
    decides whether the module is hooked or ignored.

    Parameters
    ----------
    module : torch.nn.Module
        The module object to hook gatherers onto.
    module_name : str
        Name under which the module's information is reported.
    inspector_state
        Inspector state object, forwarded as is.
    """
    register = self._register_module
    register(module, module_name, inspector_state)
def _register_module(self, module: Module, module_name: str, inspector_state):
    """
    Generic function called from :meth:`register_non_leaf_module` and :meth:`register_leaf_module`.

    The module is hooked only if every inspected parameter exists on it,
    is not ``None`` and requires a gradient; otherwise the module is
    ignored. For a hooked module one ``ParameterGradientGatherer`` per
    inspected parameter is created and remembered for later detaching.

    Parameters
    ----------
    module : torch.nn.Module
        The module object to hook gatherers onto.
    module_name : str
        Name of the module, module's information will be passed to visualizer under this name.
    inspector_state
        Inspector state object, forwarded to every gatherer.
    """
    for parameter_name in self._preprocessors:
        parameter = getattr(module, parameter_name, None)
        # A missing, unset (e.g. ``bias=False``) or frozen parameter never
        # receives a gradient, so there is nothing to gather for this module.
        if parameter is None or not getattr(parameter, 'requires_grad', False):
            return

    for parameter_name, preprocessor in self._preprocessors.items():
        gatherer = ParameterGradientGatherer(
            parameter_name, module, [preprocessor], module_name, inspector_state=inspector_state
        )
        self._gatherers.append(gatherer)
[docs]
def detach_from_module(self):
    """
    Detach the lens from the module it was hooked onto.

    Every gatherer is detached, the gatherer list is replaced by an empty
    one, and the per-parameter data is replaced by empty dictionaries
    (keeping the same parameter names). Warning data is cleared as well
    when the warning plot is enabled.
    """
    for hooked in self._gatherers:
        hooked.detach()
    self._gatherers = []

    if self._warning_plot:
        self._warning_data = {}

    known_parameters = list(self._data)
    self._data = {name: OrderedDict() for name in known_parameters}
[docs]
def register_foreign_preprocessor(self, ext_ppr: AbstractPreprocessor, inspector_state):
    """
    Ignore a preprocessor owned by another lens.

    This lens does not interact with foreign preprocessors, so both
    arguments are left untouched and nothing is stored.
    """
    return None
[docs]
def finalize_epoch(self):
    """
    Finalizes computations done through the epoch.

    Aggregates activation and death rates according to ``activation_aggregation``
    and ``death_aggregation`` and stores them per parameter and per module.
    Also tracks the worst (minimal) activation rate and the worst (maximal)
    death rate over all parameters and modules; if the warning plot is
    enabled these two values are stored for it. When no rates were
    collected at all, the worst values stay at ``+inf`` and ``-inf``.
    """
    worst_act = float('+inf')
    worst_death = float('-inf')
    for parameter_name, preprocessor in self._preprocessors.items():
        # The stored dictionary is kept in reversed collection order (see the
        # last statement of this loop). Restore collection order before
        # merging; otherwise the second reversal below would flip the module
        # order on every other epoch.
        stored = self._data.get(parameter_name, OrderedDict())
        tag_dict = OrderedDict(reversed(stored.items()))
        for module_name, (act_rate, death) in preprocessor.value.items():
            val_dict = tag_dict.setdefault(module_name, {})
            val_dict['activation_rate'] = extract_point(act_rate, self._activation_aggregation)
            val_dict['death_rate'] = extract_point(death, self._death_aggregation)
            worst_act = min(worst_act, val_dict['activation_rate'])
            worst_death = max(worst_death, val_dict['death_rate'])
        # NOTE(review): presumably gradients are collected back-to-front, so
        # reversing yields forward module order -- confirm against gatherer.
        self._data[parameter_name] = OrderedDict(reversed(tag_dict.items()))
    if self._warning_plot:
        self._warning_data['worst activation_rate'] = worst_act
        self._warning_data['worst death_rate'] = worst_death
[docs]
def vizualize(self, vizualizer: AbstractVisualizer, epoch: int):
    """
    Hand the data stored for this epoch to the visualizer.

    For every parameter listed during initialization, the per-module
    dictionary is passed to ``plot_probabilities`` under the tag
    '#PARAMETER_NAME Gradient Activation' (title-cased). Such a dictionary
    may look like this.
    ::

        OrderedDict([
            ('lin1', {'activation_rate' : 0.8, 'death_rate' : 0.3}),
            ('lin2', {'activation_rate' : 0.5, 'death_rate' : 0.3}),
        ])

    If the warning plot is enabled, one more call is made under the tag
    'Warning Gradient Activations' with a dictionary of the following shape.
    ::

        OrderedDict([
            ('Warning Gradient Activations', {
                'worst activation_rate' : 0.2,
                'worst death_rate' : 0.3
            })
        ])

    Parameters
    ----------
    vizualizer : AbstractVisualizer
        The visualizer object responsible for drawing plots.
    epoch : int
        Computation's epoch number.
    """
    for parameter_name in self._preprocessors:
        tag = f'{parameter_name} Gradient Activation'.title()
        per_module = self._data[parameter_name]
        vizualizer.plot_probabilities(epoch, tag, per_module)

    if not self._warning_plot:
        return

    warning_payload = OrderedDict(
        [(ParameterGradientActivation._BIG_TAG_NAME, self._warning_data)]
    )
    vizualizer.plot_probabilities(
        epoch, ParameterGradientActivation._BIG_TAG_NAME, warning_payload
    )
[docs]
def reset_epoch(self):
    """
    Reset every preprocessor owned by the lens.

    Calls ``reset()`` on each preprocessor; the data already stored for
    visualization is not modified here.
    """
    for parameter_name in self._preprocessors:
        self._preprocessors[parameter_name].reset()