Source code for monitorch.inspector.PyTorchInspector

from collections.abc import Callable, Iterable

from torch.nn import Module
from typing_extensions import Self

from monitorch.lens import AbstractLens
from monitorch.preprocessor import ExplicitCall
from monitorch.visualizer import AbstractVisualizer, _vizualizer_dict

from .inspector_state import InspectorState


class PyTorchInspector:
    """
    One class to rule them all.

    ``PyTorchInspector`` is a class that manages interactions between lenses,
    visualizers and user defined module. To use inspector one needs to
    initialize the inspector and provide it the module to monitor. During
    initialization lenses from :mod:`monitorch.lens` must be provided. The
    only thing that is required during training is to call :meth:`tick_epoch`
    on the end of each epoch. Optionally one could push additional metrics
    using :meth:`push_metric` and :meth:`push_loss`.

    If visualizer is ``'matplotlib'``, then ``show_fig()`` must be called on
    :attr:`visualizer`, otherwise the plot will be drawn during training.

    Parameters
    ----------
    lenses : list[AbstractLens]
        List of objects from :mod:`monitorch.lens`, used to collect and plot
        data.
    visualizer : str|AbstractVisualizer = 'print'
        Visualizer to draw plots, must be either a visualizer object from
        :mod:`monitorch.visualizer` or a string ``'matplotlib'``,
        ``'tensorboard'`` or ``'print'``.
    module : None|torch.nn.Module = None
        Optional neural network to examine, can be added later using
        :meth:`attach`.
    depth : int = -1
        Depth to unfold neural net injection tree. For example ``depth=0``
        returns the model itself, ``depth=1`` returns modules directly
        contained in ``module`` object. Default is ``depth=-1``, that is to
        unfold until leaf modules are reached.
    module_name_prefix : str = '.'
        Delimiter to separate names of parent and child modules.
    train_loss_str = 'train_loss'
        String to be used for training loss.
    non_train_loss_str = 'val_loss'
        String to be used for validation/testing/development loss.
    is_active_fn : int | Callable[[int], bool] = 1
        Function deciding if inspector is active (collects and visualizes
        data) for given epoch. Passed directly to `InspectorState`. Integer
        values correspond to function ``epoch % n == 0``, where `n` is passed
        value.

    Attributes
    ----------
    lenses : list[AbstractLens]
        List of objects from :mod:`monitorch.lens`, used to collect and plot
        data. Exactly the same object as the one provided during
        initialization.
    visualizer : AbstractVisualizer
        Visualizer object that draws all plots. Can be hot-swapped.
    state : InspectorState
        State object representing inspectors inner state. Weak-referenced by
        gatherers.
    depth : int
        Depth to unfold module inclusion tree.
    module_name_prefix : str
        Delimiter to separate names of parent and child modules.

    Raises
    ------
    AttributeError
        If ``visualizer`` is a string that is not a key of the visualizer
        registry.

    Examples
    --------
    Basic usage with ``'LossMetrics'``, ``'OutputActivation'`` and
    ``'ParameterGradientGeometry'`` may look something like this.

    >>> from monitorch.inspector import PyTorchInspector
    >>> from monitorch.lens import LossMetrics, OutputActivation, ParameterGradientGeometry
    >>>
    >>> loss_fn = nn.NLLLoss()
    >>>
    >>> inspector = PyTorchInspector(
    ...     lenses = [
    ...         LossMetrics(loss_fn = loss_fn),
    ...         OutputActivation(),
    ...         ParameterGradientGeometry()
    ...     ],
    ...     module = mynet,
    ...     visualizer='matplotlib'
    ... )
    >>>
    >>> for epoch in range(N_EPOCHS):
    ...     for data, label in train_dataloader:
    ...         optimizer.zero_grad()
    ...         prediction = mynet(data)
    ...         loss = loss_fn(prediction, label)
    ...         loss.backward()
    ...         optimizer.step()
    ...
    ...     with torch.no_grad(): # outputs inside this block are not recorded
    ...         for data, label in val_dataloader:
    ...             prediction = mynet(data)
    ...             loss = loss_fn(prediction, label)
    ...
    ...     inspector.tick_epoch() # ticking the epoch
    >>>
    >>> inspector.visualizer.show_fig()
    """

    def __init__(
        self,
        lenses: list[AbstractLens],
        *,
        visualizer: str | AbstractVisualizer = 'print',
        module: None | Module = None,
        depth: int = -1,
        module_name_prefix: str = '.',
        train_loss_str='train_loss',
        non_train_loss_str='val_loss',
        is_active_fn: int | Callable[[int], bool] = 1,
    ):
        self.lenses = lenses
        self._call_preprocessor = ExplicitCall(train_loss_str, non_train_loss_str)
        self.depth = depth
        self.module_name_prefix = module_name_prefix
        self.state: InspectorState = InspectorState(is_active_fn=is_active_fn)

        # A string selects a visualizer class from the registry and
        # instantiates it with no arguments; anything else is used as-is.
        if isinstance(visualizer, str):
            if visualizer not in _vizualizer_dict:
                raise AttributeError(f'Unknown vizualizer, string defined vizualizer must be one of {list(_vizualizer_dict.keys())} ')
            self.visualizer: AbstractVisualizer = _vizualizer_dict[visualizer]()
        else:
            self.visualizer = visualizer

        # Every lens shares the inspector's preprocessor and state, and gets
        # a chance to announce its tags to the visualizer.
        for lens in self.lenses:
            lens.register_foreign_preprocessor(self._call_preprocessor, self.state)
            lens.introduce_tags(self.visualizer)

        if module is not None:
            self.attach(module)

    def attach(self, module: Module) -> Self:
        """
        Attaches inspector to a module.

        Unfolds inclusion module tree guided by ``depth`` set during
        initialization. Registers submodules onto every lens. If the
        inspector is already attached, it is detached first.

        Parameters
        ----------
        module : torch.nn.Module
            Neural net to attach to.

        Returns
        -------
        Self
            Builder pattern.
        """
        if self.state.attached:
            self.detach()

        leaf_modules, non_leaf_modules = PyTorchInspector._traverse_module_inclusion_tree(
            module, self.depth, self.module_name_prefix
        )

        # Distinct loop names so the ``module`` argument is not shadowed.
        for submodule, name in leaf_modules:
            for lens in self.lenses:
                lens.register_leaf_module(submodule, name, self.state)
        for submodule, name in non_leaf_modules:
            for lens in self.lenses:
                lens.register_non_leaf_module(submodule, name, self.state)

        self.state.attached = True
        return self

    def detach(self) -> Self:
        """
        Detaches all lenses from modules.

        Resets the epoch counter and the call preprocessor, detaches every
        lens, and calls ``reset_fig()`` on the visualizer when it provides
        such a method.

        Returns
        -------
        Self
            Builder pattern.

        Raises
        ------
        AssertionError
            If the inspector is not currently attached.
        """
        assert self.state.attached, 'Inspector must be attached to module before detaching'

        self.state.counter = 0
        self._call_preprocessor.reset()
        for lens in self.lenses:
            lens.detach_from_module()
        # Only some visualizers expose ``reset_fig``; others are left as-is.
        if hasattr(self.visualizer, 'reset_fig'):
            self.visualizer.reset_fig()

        self.state.attached = False
        return self

    def push_metric(self, name: str, value: float, *, running: bool = True):
        """
        Pushes metric, that can be accessed by :class:`monitorch.lens.LossMetrics`.

        Parameters
        ----------
        name : str
            Name of the metric to save.
        value : float
            Metric's value.
        running : bool = True
            Flag indicating if metric should be saved in-place (True) or
            in-memory (False).
        """
        if running:
            self._call_preprocessor.push_running(name, value)
        else:
            self._call_preprocessor.push_memory(name, value)

    def push_loss(self, value: float, *, train: bool, running: bool = True):
        """
        Pushes loss, that can be accessed by :class:`monitorch.lens.LossMetrics`.

        Parameters
        ----------
        value : float
            Loss value.
        train : bool
            Flag indicating if it is training loss.
        running : bool = True
            Flag indicating if metric should be saved in-place (True) or
            in-memory (False).
        """
        self._call_preprocessor.push_loss(value, train=train, running=running)

    def tick_epoch(self, epoch: int | None = None):
        """
        Ticks to postprocess data and draw plots.

        When :attr:`state` reports the inspector as active, every lens is
        finalized, visualized with the current counter and reset, and the
        call preprocessor is reset. The counter is advanced in either case.

        Parameters
        ----------
        epoch : int|None = None
            Optional epoch counter, default ticks :attr:`state`, thus
            incrementing `counter`.
        """
        if self.state.is_active:
            for lens in self.lenses:
                lens.finalize_epoch()
                lens.vizualize(self.visualizer, self.state.counter)
                lens.reset_epoch()
            self._call_preprocessor.reset()

        # The counter moves forward whether or not this epoch was active.
        if epoch is not None:
            self.state.counter = epoch
        else:
            self.state.tick()

    # Short alias; ``iter`` and ``range`` go through it.
    tick = tick_epoch

    def iter(self, iterable: Iterable) -> Iterable:
        """
        Wraps an iterable so that the inspector ticks once per item.

        The tick happens when the generator is resumed after an item was
        yielded, i.e. after the consumer's loop body for that item has run.
        An empty iterable produces no ticks, and an item whose loop body
        exits early (``break``) is not ticked.

        Parameters
        ----------
        iterable : Iterable
            Items to iterate over, typically epoch indices.

        Yields
        ------
        object
            Items of ``iterable``, unchanged.
        """
        for item in iterable:
            yield item
            self.tick()

    def range(self, *args, **kwargs) -> Iterable:
        """
        Shortcut for ``self.iter(range(*args, **kwargs))``.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to the builtin :class:`range`.

        Returns
        -------
        Iterable
            Generator over the range that ticks the inspector per item.
        """
        return self.iter(range(*args, **kwargs))

    @staticmethod
    def _decide_prefix(prefix: str, grand_name: str):
        """Utility function for depth=0 name composition."""
        # An empty descendant name means the child itself is the target, so
        # no delimiter must be appended after the child's name.
        return prefix if grand_name else ''

    @staticmethod
    def _traverse_module_inclusion_tree(
        module: Module, depth: int = -1, prefix: str = '.'
    ) -> tuple[list[tuple[Module, str]], list[tuple[Module, str]]]:
        """
        A function to extract nodes at defined depth from module inclusion tree.

        If ``depth=-1`` calls :meth:`_module_deep_leaves`, otherwise
        recursively goes down the tree decreasing depth. A module without
        children that is reached before ``depth`` hits zero contributes no
        entries.

        Parameters
        ----------
        module : torch.nn.Module
            Module which inclusion tree must be unfolded.
        depth : int = -1
            Depth to which the module must be unfolded, default is -1, i.e.,
            until leaf nodes.
        prefix : str = '.'
            Delimiter to separate names of parent and child modules.

        Returns
        -------
        tuple[list[tuple[Module, str]], list[tuple[Module, str]]]
            Lists of leaf (1st value) and non-leaf (2nd value) module object
            and their path name.

        Raises
        ------
        AssertionError
            If ``depth`` is smaller than -1.
        """
        assert depth >= -1, 'Depth of leaves must be non-negative or -1 (maximal depth)'
        if depth == -1:
            return PyTorchInspector._module_deep_leaves(module, prefix=prefix)
        if depth == 0:
            return [(module, '')], []

        leaves = []
        non_leaves = []
        for name, child in module.named_children():
            # The delimiter has to travel down the recursion; otherwise
            # deeper levels silently fall back to the default '.'.
            child_leaves, child_non_leaves = PyTorchInspector._traverse_module_inclusion_tree(
                child, depth - 1, prefix
            )
            leaves += [
                (descendant, name + PyTorchInspector._decide_prefix(prefix, grand_name) + grand_name)
                for descendant, grand_name in child_leaves
            ]
            non_leaves += [
                (descendant, name + PyTorchInspector._decide_prefix(prefix, grand_name) + grand_name)
                for descendant, grand_name in child_non_leaves
            ]

        # The module itself is a non-leaf only if something was found below it.
        if leaves:
            non_leaves.append((module, ''))
        return leaves, non_leaves

    @staticmethod
    def _module_deep_leaves(
        module: Module, prefix: str
    ) -> tuple[list[tuple[Module, str]], list[tuple[Module, str]]]:
        """
        A function to extract leaves from module inclusion tree.

        The function is recursive.

        Parameters
        ----------
        module : torch.nn.Module
            Module which inclusion tree must be unfolded.
        prefix : str
            Delimiter to separate names of parent and child modules.

        Returns
        -------
        tuple[list[tuple[Module, str]], list[tuple[Module, str]]]
            Lists of leaf (1st value) and non-leaf (2nd value) module object
            and their path name.
        """
        leaves = []
        non_leaves = []
        for name, child in module.named_children():
            child_leaves, child_non_leaves = PyTorchInspector._module_deep_leaves(child, prefix=prefix)
            leaves += [
                (descendant, name + PyTorchInspector._decide_prefix(prefix, grand_name) + grand_name)
                for descendant, grand_name in child_leaves
            ]
            non_leaves += [
                (descendant, name + PyTorchInspector._decide_prefix(prefix, grand_name) + grand_name)
                for descendant, grand_name in child_non_leaves
            ]

        if leaves:
            non_leaves.append((module, ''))
        else:
            # Nothing below this module, so it is a leaf itself.
            leaves = [(module, '')]
        return leaves, non_leaves