Source code for monitorch.gatherer.backward_gatherer
from monitorch.preprocessor import AbstractBackwardPreprocessor
from .abstract_gatherer import AbstractGatherer
[docs]
class BackwardGatherer(AbstractGatherer):
    """
    Gatherer that forwards backward-pass data of a module to preprocessors.

    On construction the instance registers itself on the given module through
    ``register_full_backward_hook``. Each time the hook is invoked, the module's
    name, the module and both gradient arguments are handed to every
    preprocessor, in list order.

    Parameters
    ----------
    module : torch.nn.Module
        Module to hook onto.
    preprocessors : list[:class:`AbstractBackwardPreprocessor`]
        Preprocessors that receive the data whenever the hook is called.
    name : str
        Name of the module, passed to the preprocessors with every call.
    inspector_state
        Forwarded unchanged to the :class:`AbstractGatherer` constructor.
    """

    def __init__(self, module, preprocessors: list[AbstractBackwardPreprocessor], name: str, inspector_state):
        super().__init__(inspector_state)
        self._name = name
        self._preprocessors = preprocessors
        # Keep the handle returned by the registration so that ``detach``
        # can unregister the hook later.
        self._handle = module.register_full_backward_hook(self)

    def detach(self) -> None:
        """
        Run the base-class detach logic, then remove the backward hook
        that was registered at construction.
        """
        super().detach()
        hook_handle = self._handle
        hook_handle.remove()

    @AbstractGatherer.requires_active_inspector_state
    def __call__(self, module, grad_inp, grad_out) -> None:
        """
        Hook entry point: pass the stored name, the module and both gradient
        arguments to ``process_bw`` of every preprocessor.

        The method body itself returns nothing.
        """
        # NOTE(review): the decorator presumably gates this call on the
        # inspector state being active -- confirm in AbstractGatherer.
        for backward_preprocessor in self._preprocessors:
            backward_preprocessor.process_bw(self._name, module, grad_inp, grad_out)