Source code for monitorch.gatherer.parameter_gradient_gatherer

from monitorch.preprocessor import AbstractTensorPreprocessor

from .abstract_gatherer import AbstractGatherer


class ParameterGradientGatherer(AbstractGatherer):
    """
    Gather gradients of a single learnable parameter of a module.

    An instance is a stateful callback attached to the parameter tensor via
    ``torch.Tensor.register_post_accumulate_grad_hook``. Whenever the hook
    fires, the parameter's ``.grad`` is handed to every preprocessor.

    Parameters
    ----------
    parameter : str
        Attribute name of the learnable parameter inside ``module``
        (resolved with ``getattr``).
    module : torch.nn.Module
        Module owning the learnable parameter; gradients are collected from
        that parameter.
    preprocessors : list[:class:`AbstractTensorPreprocessor`]
        Preprocessors that receive and aggregate the gradient.
    name : str
        Name of the module; passed to each preprocessor together with the
        gradient.
    inspector_state
        Inspector state forwarded unchanged to :class:`AbstractGatherer`.
    """

    def __init__(
        self,
        parameter: str,
        module,
        preprocessors: list[AbstractTensorPreprocessor],
        name: str,
        inspector_state,
    ):
        super().__init__(inspector_state)
        self._preprocessors = preprocessors
        self._name = name
        # The gatherer object itself is the hook; the returned handle is kept
        # so that ``detach`` can unregister it later.
        target_tensor = getattr(module, parameter)
        self._handle = target_tensor.register_post_accumulate_grad_hook(self)

    @AbstractGatherer.requires_active_inspector_state
    def __call__(self, parameter):
        """
        Hook body: pass ``parameter.grad`` to each preprocessor.

        Execution is wrapped by
        ``AbstractGatherer.requires_active_inspector_state``; its exact gating
        semantics live in the base class (presumably the body is skipped while
        the inspector is inactive; confirm there).

        Parameters
        ----------
        parameter : torch.Tensor
            The tensor the hook was registered on, as supplied by PyTorch.
        """
        for prep in self._preprocessors:
            # ``.grad`` is read per preprocessor, exactly as before.
            prep.process_tensor(self._name, parameter.grad)

    def detach(self) -> None:
        """
        See base class.

        After the base-class detach, also removes the gradient hook that was
        registered in ``__init__``.
        """
        super().detach()
        self._handle.remove()