from collections import OrderedDict
from collections.abc import Iterable
from torch.nn import Module
from monitorch.gatherer import EpochModuleGatherer
from monitorch.numerical import extract_point
from monitorch.preprocessor import AbstractPreprocessor
from monitorch.preprocessor import ParameterNorm as ParameterNormPreprocessor
from monitorch.visualizer import AbstractVisualizer, TagAttributes, TagType
from .abstract_lens import AbstractLens
[docs]
class ParameterNorm(AbstractLens):
"""
A lens to collect parameter norm.
Computes L2-norm or root-mean-square on explicit lens call or epoch tick.
Lens draws a small plot for each layer selected during initialization, optionally draws comparison plot
between all layers.
Parameters
----------
inplace : bool = True
Flag indicating if computation should be done in-place or in-memory.
parameters : Iterable[str] = ('weight', 'bias')
Parameters which the norm or rms will be computed.
normalize_by_size : bool = False
Flag indicating if parameter norm should be divided by root of number of elements, thus obtaining RMS of parameter.
log_scale : bool = False
Flag indicating if logarithmic scale should be used.
comparison_plot : bool = True
Flag indicating if big comparison plot should be drawn.
aggregation_method : str = 'mean'
Aggregation method for lines in plots.
Examples
--------
Default usage is shown below.
>>> inspector = PyTorchInspector(
... lenses = [
... ParameterNorm(),
... ],
... module = mynet,
... visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
... for data, label in train_dataloader:
... optimizer.zero_grad()
... prediction = mynet(data)
... loss = loss_fn(prediction, label)
... loss.backward()
... optimizer.step()
...
... inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
To collect data more often use :meth:`collect_data`.
>>> pnorm_lens = ParameterNorm()
>>> inspector = PyTorchInspector(
... lenses = [
... pnorm_lens,
... ],
... module = mynet,
... visualizer='matplotlib'
... )
>>>
>>> for epoch in range(N_EPOCHS):
... for data, label in train_dataloader:
... optimizer.zero_grad()
... prediction = mynet(data)
... loss = loss_fn(prediction, label)
... loss.backward()
... optimizer.step()
... pnorm_lens.collect_data()
...
... inspector.tick_epoch()
>>>
>>> inspector.visualizer.show_fig()
"""
def __init__(
self,
inplace: bool = True,
parameters: Iterable[str] = ('weight', 'bias'),
normalize_by_size: bool = False,
log_scale: bool = False,
comparison_plot: bool = True,
aggregation_method: str = 'mean',
):
self._parameters = list(parameters)
self._log_scale = log_scale
self._preprocessor = ParameterNormPreprocessor(self._parameters, normalize=normalize_by_size, inplace=inplace)
self._gatherers: list[EpochModuleGatherer] = []
self._data: OrderedDict[str, OrderedDict[str, dict[str, float]]] = OrderedDict([(parameter_name, OrderedDict()) for parameter_name in self._parameters])
self._aggregation_method = aggregation_method
self._comparison_plot = comparison_plot
if self._comparison_plot:
self._comparison_data: OrderedDict[str, OrderedDict[str, float]] = OrderedDict([(parameter_name, OrderedDict()) for parameter_name in self._parameters])
[docs]
def register_leaf_module(self, module: Module, module_name: str, inspector_state):
"""
Registers (or ignores) module.
Registers any module that has all of the parameters listed during initialization.
Parameters
----------
module : torch.nn.Module
The module object to hook gatherers onto.
module_name : str
Name of the module, module's information will be passed to visaulizer under this name.
"""
self._register_module(module, module_name, inspector_state)
[docs]
def register_non_leaf_module(self, module: Module, module_name: str, inspector_state):
"""
Registers (or ignores) module.
Registers any module that has all of the parameters listed during initialization.
Parameters
----------
module : torch.nn.Module
The module object to hook gatherers onto.
module_name : str
Name of the module, module's information will be passed to visaulizer under this name.
"""
self._register_module(module, module_name, inspector_state)
def _register_module(self, module: Module, module_name: str, inspector_state):
"""
Generic function called from :meth:`register_non_leaf_module` and :meth:`register_leaf_module`
Parameters
----------
module : torch.nn.Module
The module object to hook gatherers onto.
module_name : str
Name of the module, module's information will be passed to visaulizer under this name.
"""
if not all(hasattr(module, parameter_name) and getattr(module, parameter_name) is not None for parameter_name in self._parameters):
return
gatherer = EpochModuleGatherer(module, [self._preprocessor], module_name, inspector_state=inspector_state)
self._gatherers.append(gatherer)
for parameter_name in self._parameters:
self._data[parameter_name][module_name] = {}
[docs]
def detach_from_module(self):
"""
Detaches lens from module.
Detaches gatherers and resets inner state.
"""
for gatherer in self._gatherers:
gatherer.detach()
self._gatherers = []
self._data: OrderedDict[str, OrderedDict[str, dict[str, float]]] = OrderedDict([(parameter_name, OrderedDict()) for parameter_name in self._parameters])
if self._comparison_plot:
self._comparison_data: OrderedDict[str, OrderedDict[str, float]] = OrderedDict([(parameter_name, OrderedDict()) for parameter_name in self._parameters])
[docs]
def register_foreign_preprocessor(self, ext_ppr: AbstractPreprocessor, inspector_state):
"""Does not interact with foreign preprocessor."""
pass
[docs]
def collect_data(self):
"""
Calls gatherers to collect data.
"""
for gatherer in self._gatherers:
gatherer()
[docs]
def finalize_epoch(self):
"""
Finaizes computations done through epoch.
Aggregates parameter norms according to ``aggregation_method``
and computes comparison values.
"""
self.collect_data()
for parameter_name in self._parameters:
comparison_dict: OrderedDict[str, float]
if self._comparison_plot:
comparison_dict = self._comparison_data[parameter_name]
tag_data_dict = self._data[parameter_name]
total_sum = 1e-7
for module_name, module_data in self._preprocessor.value.items():
pt_val = extract_point(module_data[parameter_name], self._aggregation_method)
tag_data_dict.setdefault(module_name, {})[self._aggregation_method] = pt_val
total_sum += pt_val
if self._comparison_plot:
comparison_dict[module_name] = pt_val
if self._comparison_plot:
for module_name in comparison_dict:
comparison_dict[module_name] /= total_sum
[docs]
def vizualize(self, vizualizer: AbstractVisualizer, epoch: int):
"""
Passes computed data to visualizer.
Passes dictionary of per layer data to '#PARAMETER_NAME Output Norm', the dictionary
may look something like this.
::
OrderedDict([
('lin1', {'mean' : 0.8}, {'min' : 0.2, 'max' : 0.9}),
('lin2', {'mean' : 0.6}, {'min' : 0.3, 'max' : 0.7}),
])
If comparison plot needs to be plotted passes a dictionary described below to '#PARAMETER [Log] Norm Comparison'
::
OrderedDict([
('Weight Norm Comparison', {
'lin1' : 0.7,
'lin2' : 0.3
})
])
Parameters
----------
visualizer : AbstractVisualizer
The visualizer object responsbile for drawing plots.
epoch : int
Computation's epoch number.
"""
for parameter_name in self._parameters:
vizualizer.plot_numerical_values(epoch, f'{parameter_name} Norm'.title(), self._data[parameter_name], None)
if self._comparison_plot:
for parameter_name in self._parameters:
tag_name = f'{parameter_name}{" Log" if self._log_scale else ""} Norm Comparison'.title()
vizualizer.plot_relations(epoch, tag_name, OrderedDict([(tag_name, self._comparison_data[parameter_name])]))
[docs]
def reset_epoch(self):
"""
Resets inner state.
Resets data computed during last epoch and resets preprocessors.
"""
self._preprocessor.reset()