Source code for monitorch.lens.output_gradient_geometry

from collections import OrderedDict
from collections.abc import Iterable

from torch.nn import Module

from monitorch.gatherer import BackwardGatherer
from monitorch.numerical import extract_point, extract_range, parse_range_name
from monitorch.preprocessor import AbstractPreprocessor
from monitorch.preprocessor import OutputGradientGeometry as OutputGradientGeometryPreprocessor
from monitorch.visualizer import AbstractVisualizer, TagAttributes, TagType

from .abstract_lens import AbstractLens
from .module_distinction import isactivation


class OutputGradientGeometry(AbstractLens):
    """
    Lens to examine geometry of gradients with respect to layer outputs.

    Computes L2-norm or root-mean-square of gradients on every backward pass
    through layer. Optionally computes correlation between gradients from two
    consecutive backward passes.

    Computing correlation requires gradients from both passes, hence the
    gradient will be saved after the computation is finished. It drives space
    consumption linearly by size of studied outputs.

    Parameters
    ----------
    inplace : bool = True
        Flag indicating if computation should be done in-place or in-memory.
    normalize_by_size : bool = False
        Flag indicating if output norm should be divided by root of number of
        elements, thus obtaining RMS of output.
    log_scale : bool = False
        Flag indicating if logarithmic scale should be used.
    compute_correlation : bool = True
        Flag indicating if correlation between gradients from consecutive
        backward passes should be computed.
    skip_activation : bool = True
        Flag indicating if lens should NOT register activation layers.
    line_aggregation : str|Iterable[str] = 'mean'
        Aggregation method for lines in plots.
    range_aggregation : str|Iterable[str]|None = ('std', 'min-max')
        Aggregation method for bands in plots.

    Examples
    --------
    Default usage is shown below.

    >>> inspector = PyTorchInspector(
    ...     lenses = [
    ...         OutputGradientGeometry(),
    ...     ],
    ...     module = mynet,
    ...     visualizer='matplotlib'
    ... )
    >>>
    >>> for epoch in range(N_EPOCHS):
    ...     for data, label in train_dataloader:
    ...         optimizer.zero_grad()
    ...         prediction = mynet(data)
    ...         loss = loss_fn(prediction, label)
    ...         loss.backward()
    ...         optimizer.step()
    ...
    ...     inspector.tick_epoch()
    >>>
    >>> inspector.visualizer.show_fig()
    """

    _SMALL_NORM_TAG_NAME = 'Output Gradient Norm'
    _SMALL_PROD_TAG_NAME = 'Output Gradient Correlation'

    def __init__(
        self,
        inplace: bool = True,
        normalize_by_size: bool = False,
        log_scale: bool = False,
        compute_correlation: bool = True,
        skip_activation: bool = True,
        line_aggregation: str | Iterable[str] = 'mean',
        range_aggregation: str | Iterable[str] | None = ('std', 'min-max'),
    ):
        self._compute_correlation = compute_correlation
        self._skip_activation = skip_activation
        self._log_scale = log_scale

        # A single preprocessor instance is shared by every gatherer of this lens.
        self._preprocessor = OutputGradientGeometryPreprocessor(
            inplace=inplace,
            normalize=normalize_by_size,
            correlation=compute_correlation,
        )
        self._gatherers = []
        self._reset_data()

        # The aggregation names are iterated once per module on every epoch, so
        # they are materialised into tuples; a one-shot iterable (e.g. a
        # generator) would otherwise be exhausted after the first module.
        self._line_aggregation: tuple[str, ...] = self._as_method_tuple(line_aggregation)
        self._range_aggregation: tuple[str, ...] = self._as_method_tuple(range_aggregation)

    @staticmethod
    def _as_method_tuple(methods: str | Iterable[str] | None) -> tuple[str, ...]:
        """Normalize an aggregation argument (name, iterable of names or None) to a tuple."""
        if methods is None:
            return ()
        if isinstance(methods, str):
            return (methods,)
        return tuple(methods)

    def _reset_data(self) -> None:
        """(Re)create the empty per-module containers that are handed to the visualizer."""
        self._line_data: OrderedDict[str, dict[str, float]] = OrderedDict()
        self._range_data: OrderedDict[str, dict[tuple[str, str], tuple[float, float]]] = OrderedDict()
        # Correlation containers exist only when correlation is requested.
        if self._compute_correlation:
            self._line_correlation_data: OrderedDict[str, dict[str, float]] = OrderedDict()
            self._range_correlation_data: OrderedDict[str, dict[tuple[str, str], tuple[float, float]]] = OrderedDict()

    def register_leaf_module(self, module: Module, module_name: str, inspector_state) -> None:
        """
        Registers (or ignores) module.

        If ``skip_activation`` is ``True``, then does not register activation
        modules, otherwise registers any module.

        Parameters
        ----------
        module : torch.nn.Module
            The module object to hook gatherers onto.
        module_name : str
            Name of the module, module's information will be passed to
            visualizer under this name.
        inspector_state
            Passed through unchanged to the created ``BackwardGatherer``.
        """
        if self._skip_activation and isactivation(module):
            return
        gatherer = BackwardGatherer(
            module,
            [self._preprocessor],
            module_name,
            inspector_state=inspector_state,
        )
        self._gatherers.append(gatherer)

    def detach_from_module(self) -> None:
        """
        Detaches lens from module.

        Detaches gatherers and resets inner state.
        """
        for gatherer in self._gatherers:
            gatherer.detach()
        self._gatherers = []
        self._reset_data()

    def register_foreign_preprocessor(self, ext_ppr: AbstractPreprocessor, inspector_state) -> None:
        """Does not interact with foreign preprocessor."""
        pass

    def introduce_tags(self, vizualizer: AbstractVisualizer) -> None:
        """
        Introduces lens's plots to visualizer.

        Registers a small numerical plot 'Output Gradient Norm' and optionally
        registers a small numerical plot 'Output Gradient Correlation'.

        Parameters
        ----------
        vizualizer : AbstractVisualizer
            A visualizer object to pass tag attributes to.
        """
        vizualizer.register_tags(
            OutputGradientGeometry._SMALL_NORM_TAG_NAME,
            TagAttributes(logy=self._log_scale, big_plot=False, annotate=True, type=TagType.NUMERICAL),
        )
        if self._compute_correlation:
            # The correlation plot never uses a logarithmic axis, regardless of ``log_scale``.
            vizualizer.register_tags(
                OutputGradientGeometry._SMALL_PROD_TAG_NAME,
                TagAttributes(logy=False, big_plot=False, annotate=True, type=TagType.NUMERICAL),
            )

    def finalize_epoch(self) -> None:
        """
        Finalizes computations done through epoch.

        Aggregates output gradient norms and optionally correlation according
        to ``line_aggregation`` and ``range_aggregation``.

        Modules are inserted into the result dictionaries in the reverse of the
        order in which the preprocessor reports them. Entries persist between
        epochs, so a module keeps the position it received when first seen and
        the ordering handed to the visualizer is the same on every epoch.
        """
        # NOTE(review): presumably the preprocessor records modules in the order
        # their backward hooks fire (last layer first), so walking it backwards
        # yields forward order -- confirm against the preprocessor.
        # Reversing the accumulated dictionaries after the loop instead would
        # flip their order on every call, because they are not cleared between
        # epochs.
        for module_name, value in reversed(list(self._preprocessor.value.items())):
            line_norm_dict = self._line_data.setdefault(module_name, {})
            range_norm_dict = self._range_data.setdefault(module_name, {})

            if self._compute_correlation:
                # With correlation enabled the preprocessor value is a (norm, product) pair.
                norm, prod = value
                line_prod_dict = self._line_correlation_data.setdefault(module_name, {})
                range_prod_dict = self._range_correlation_data.setdefault(module_name, {})
                for method in self._line_aggregation:
                    line_norm_dict[method] = extract_point(norm, method)
                    line_prod_dict[method] = extract_point(prod, method)
                for method in self._range_aggregation:
                    range_key = parse_range_name(method)
                    range_norm_dict[range_key] = extract_range(norm, method)
                    range_prod_dict[range_key] = extract_range(prod, method)
            else:
                for method in self._line_aggregation:
                    line_norm_dict[method] = extract_point(value, method)
                for method in self._range_aggregation:
                    range_norm_dict[parse_range_name(method)] = extract_range(value, method)

    def vizualize(self, vizualizer: AbstractVisualizer, epoch: int) -> None:
        """
        Passes computed data to visualizer.

        Passes two per-layer ordered dictionaries to 'Output Gradient Norm':
        one with line values and one with range values. They may look
        something like this (the exact range keys are whatever
        ``parse_range_name`` returns for the configured methods).

        ::

            line_data = OrderedDict([
                ('lin1', {'mean' : 0.8}),
                ('relu1', {'mean' : 0.6}),
            ])
            range_data = OrderedDict([
                ('lin1', {('min', 'max') : (0.2, 0.9)}),
                ('relu1', {('min', 'max') : (0.3, 0.7)}),
            ])

        Gradient correlation dictionaries look the same.

        Parameters
        ----------
        vizualizer : AbstractVisualizer
            The visualizer object responsible for drawing plots.
        epoch : int
            Computation's epoch number.
        """
        vizualizer.plot_numerical_values(
            epoch,
            OutputGradientGeometry._SMALL_NORM_TAG_NAME,
            self._line_data,
            self._range_data,
        )
        if self._compute_correlation:
            vizualizer.plot_numerical_values(
                epoch,
                OutputGradientGeometry._SMALL_PROD_TAG_NAME,
                self._line_correlation_data,
                self._range_correlation_data,
            )

    def reset_epoch(self) -> None:
        """
        Resets inner state.

        Resets the preprocessor. The aggregated per-module dictionaries are
        kept and overwritten by the next ``finalize_epoch`` call.
        """
        self._preprocessor.reset()