Source code for monitorch.preprocessor.parameter.parameter_norm

from collections import OrderedDict
from typing import Any

from monitorch.numerical import GeometryComputation
from monitorch.preprocessor.abstract.abstract_module_preprocessor import AbstractModulePreprocessor


[docs] class ParameterNorm(AbstractModulePreprocessor): """ Preprocessor computing norms of parameters. Computes norm of parameters listed in :attr:`attrs_` for every module that is being passed to process module. Parameters ---------- attrs : list[str] List of attributes for which norm will be computed. normalize : bool Flag indicating whether norm should be normalized by tensor size. If true computes RMS of tensor values, L2-norm otherwise. inplace : bool Flag indicating if :class:`RunningMeanVar` or ``list`` will be used. Attributes ---------- attrs_ : list[str] List of attributes to compute norm for. """ def __init__(self, attrs: list[str], normalize: bool, inplace: bool): self._gc_kwargs = dict(normalize=normalize, inplace=inplace, correlation=False, eps=0.0) self.attrs_ = attrs self._value: OrderedDict[str, dict[str, GeometryComputation]] = OrderedDict()
[docs] def process_module(self, name: str, module): """ Computes norms of all :attr:`attrs_`. Uses ``torch.linalg.vector_norm`` to compute L2-norm of module's attributes. If ``normalize`` is true, divides norm by a square root of number of elements in attributes. """ d = self._value.setdefault(name, {}) for attr in self.attrs_: param = getattr(module, attr) gc = d.setdefault(attr, GeometryComputation(**self._gc_kwargs)) gc.update(param)
@property def value(self) -> OrderedDict[str, Any]: """ See base class """ return OrderedDict([(name, {attr: d[attr].value for attr in self.attrs_}) for name, d in self._value.items()])
[docs] def reset(self) -> None: """ See base class """ self._value = OrderedDict()