Source code for monitorch.preprocessor.parameter.parameter_difference_geometry
from collections import OrderedDict
from copy import deepcopy
from typing import Any
from torch import Tensor, no_grad
from monitorch.numerical import GeometryComputation
from monitorch.preprocessor.abstract.abstract_tensor_preprocessor import AbstractTensorPreprocessor
[docs]
class ParameterDifferenceGeometry(AbstractTensorPreprocessor):
    """
    Preprocessor to keep track of parameter evolution between preprocessor calls by inspecting its updates.

    Main usage is to inspect optimizer update step behaviour.
    Computes (normalized) L2 norm of parameter updates.
    Optionally computes correlation between consecutive parameter differences for further investigation,
    normalized to fit into [-1, 1] range.

    Parameters
    ----------
    correlation : bool
        Indicator if correlation must be computed.
    normalize : bool
        Indicator if the norm of the parameter difference should be divided by square root of number of elements.
    inplace : bool
        Flag indicating whether to collect data inplace using :class:`RunningMeanVar` or to stack them into a list.
    eps : float
        Forwarded unchanged to every :class:`GeometryComputation` created by this preprocessor
        (presumably a numerical-stability constant -- confirm against ``GeometryComputation``).
    """

    def __init__(self, correlation: bool, normalize: bool, inplace: bool, eps: float = 1e-8):
        # Keyword arguments shared by every per-name GeometryComputation.
        self._gc_kwargs: dict[str, bool] = dict(
            normalize=normalize,
            correlation=correlation,
            inplace=inplace,
        )
        self._eps = eps
        # Per-name aggregators, kept in first-seen order.
        self._value: OrderedDict[str, GeometryComputation] = OrderedDict()
        # Snapshot of each parameter as it was on the previous call.
        self._prev_param: dict[str, Tensor] = {}

    @no_grad()
    def process_tensor(self, name: str, param: Tensor) -> None:
        """
        Computes (normalized) L2 norm and optionally correlation with previous difference.

        The first call for a given ``name`` only stores a snapshot of ``param``;
        differences are fed to the aggregator from the second call onwards.

        Parameters
        ----------
        name : str
            Name of source of parameter.
        param : torch.Tensor
            Parameter tensor to be processed.
        """
        previous = self._prev_param.get(name)
        if previous is not None:
            diff = param - previous
            # Create the aggregator lazily; ``setdefault`` would construct a
            # throwaway GeometryComputation on every single call.
            geometry_computation = self._value.get(name)
            if geometry_computation is None:
                geometry_computation = GeometryComputation(**self._gc_kwargs, eps=self._eps)
                self._value[name] = geometry_computation
            geometry_computation.update(diff)
        # ``detach().clone()`` snapshots only the values: unlike ``deepcopy`` it
        # works for non-leaf tensors and never duplicates ``.grad``.
        self._prev_param[name] = param.detach().clone()

    @property
    def value(self) -> dict[str, Any]:
        """See base class. Maps each processed name to its aggregator's ``value``."""
        return {k: gc.value for k, gc in self._value.items()}

    def reset(self) -> None:
        """
        See base class.

        Only the aggregated values are dropped; the stored parameter snapshots are
        kept, so the next :meth:`process_tensor` call still produces a difference.
        """
        self._value = OrderedDict()