Source code for monitorch.preprocessor.abstract.abstract_backward_preprocessor

"""
Base class for all backward pass preprocessors
"""

from abc import abstractmethod

from .abstract_preprocessor import AbstractPreprocessor


class AbstractBackwardPreprocessor(AbstractPreprocessor):
    """
    Common parent of every preprocessor that aggregates backward pass data.

    Concrete subclasses implement ``process_bw`` to handle gradients taken
    with respect to the inputs or the outputs of a module.
    """

    # NOTE(review): the parameter types below are carried over from the
    # existing documentation. If these arguments are forwarded unchanged from
    # a PyTorch backward hook they would be tuples of tensors rather than
    # single tensors -- confirm against the caller.
    @abstractmethod
    def process_bw(self, name: str, module, grad_input, grad_output):
        """
        Process the data gathered from one module during the backward pass.

        Parameters
        ----------
        name : str
            Name of the module whose data is being processed.
        module : torch.nn.Module
            Module object the data originates from.
        grad_input : torch.Tensor
            Gradients with respect to the input of the module.
        grad_output : torch.Tensor
            Gradients with respect to the output of the module.
        """