"""Lens for examining the geometry of parameter updates (``monitorch.lens.parameter_update_geometry``)."""

from collections import OrderedDict
from collections.abc import Iterable

from torch.nn import Module
from torch.optim import Optimizer

from monitorch.gatherer import AbstractGatherer, CallParameterGatherer, OptimizerStepParameterGatherer
from monitorch.numerical import extract_point, extract_range, parse_range_name
from monitorch.preprocessor import AbstractPreprocessor, ParameterDifferenceGeometry
from monitorch.visualizer import AbstractVisualizer, TagAttributes, TagType

from .abstract_lens import AbstractLens


class ParameterUpdateGeometry(AbstractLens):
    """
    Lens to examine geometry of parameter updates.

    Computes L2-norm or root-mean-square of parameter updates on every optimizer step call.
    Optionally computes correlation between parameter updates from two consecutive steps.
    Computing correlation requires the parameter updates from both steps, hence the latest update
    is kept after the computation is finished. This drives memory consumption up linearly with
    the size of the studied parameters.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer | None
        Optimizer whose step calls will be used for data collection.
        If ``None`` is provided the lens must be called explicitly via :meth:`inspect_update`.
    inplace : bool = True
        Flag indicating if computation should be done in-place or in-memory.
    normalize_by_size : bool = False
        Flag indicating if output norm should be divided by root of number of elements,
        thus obtaining RMS of output.
    log_scale : bool = False
        Flag indicating if logarithmic scale should be used.
    compute_correlation : bool = True
        Flag indicating if correlation between updates from consecutive optimizer steps should be computed.
    parameters : str|Iterable[str] = ('weight', 'bias')
        Parameters whose updates will be studied. A single string is treated as one parameter name.
    line_aggregation : str|Iterable[str] = 'mean'
        Aggregation method for lines in plots.
    range_aggregation : str|Iterable[str]|None = ('std', 'min-max')
        Aggregation method for bands in plots. ``None`` disables bands.

    Examples
    --------
    Default usage is shown below.

    >>> optimizer = torch.optim.AdamW(mynet.parameters())
    >>> inspector = pytorchinspector(
    ...     lenses = [
    ...         ParameterUpdateGeometry(optimizer),
    ...     ],
    ...     module = mynet,
    ...     visualizer='matplotlib'
    ... )
    >>>
    >>> for epoch in range(N_EPOCHS):
    ...     for data, label in train_dataloader:
    ...         optimizer.zero_grad()
    ...         prediction = mynet(data)
    ...         loss = loss_fn(prediction, label)
    ...         loss.backward()
    ...         optimizer.step()
    ...
    ...     inspector.tick_epoch()
    >>>
    >>> inspector.visualizer.show_fig()
    """

    def __init__(
            self,
            optimizer: Optimizer | None,
            inplace: bool = True,
            normalize_by_size: bool = False,
            log_scale: bool = False,
            compute_correlation: bool = True,
            parameters: str | Iterable[str] = ('weight', 'bias'),
            line_aggregation: str | Iterable[str] = 'mean',
            range_aggregation: str | Iterable[str] | None = ('std', 'min-max'),
    ):
        self.optimizer: Optimizer | None = optimizer
        self._compute_correlation = compute_correlation

        # A bare string is a single parameter name; iterating it directly would yield characters.
        parameter_names = (parameters,) if isinstance(parameters, str) else parameters
        # One preprocessor per parameter name, shared by every module registered for that parameter.
        self._preprocessors = OrderedDict(
            (
                parameter,
                ParameterDifferenceGeometry(
                    inplace=inplace,
                    normalize=normalize_by_size,
                    correlation=compute_correlation,
                ),
            )
            for parameter in parameter_names
        )
        self._gatherers: list[AbstractGatherer] = []
        self._reset_aggregated_data()

        self._log_scale = log_scale
        # Materialize aggregation methods: they are iterated once per epoch, so a one-shot
        # iterator (e.g. a generator) would silently be empty from the second epoch on.
        self._line_aggregation: list[str] = (
            [line_aggregation] if isinstance(line_aggregation, str) else list(line_aggregation)
        )
        self._range_aggregation: list[str]
        if isinstance(range_aggregation, str):
            self._range_aggregation = [range_aggregation]
        elif range_aggregation is None:
            self._range_aggregation = []
        else:
            self._range_aggregation = list(range_aggregation)

    def _reset_aggregated_data(self):
        """Clear aggregated per-epoch data for norms and correlations."""
        self._line_data: dict[str, OrderedDict[str, dict[str, float]]] = {}
        self._range_data: dict[str, OrderedDict[str, dict[tuple[str, str], tuple[float, float]]]] = {}
        # Only filled when correlation is computed; kept defined so the attributes always exist.
        self._line_correlation_data: dict[str, OrderedDict[str, dict[str, float]]] = {}
        self._range_correlation_data: dict[str, OrderedDict[str, dict[tuple[str, str], tuple[float, float]]]] = {}

    @staticmethod
    def _norm_tag(parameter_name: str) -> str:
        """Visualizer tag of the update-norm plot for ``parameter_name``."""
        return f'{parameter_name} Update Norm'.title()

    @staticmethod
    def _correlation_tag(parameter_name: str) -> str:
        """Visualizer tag of the update-correlation plot for ``parameter_name``."""
        return f'{parameter_name} Update Correlation'.title()

    def register_leaf_module(self, module: Module, module_name: str, inspector_state):
        """
        Registers (or ignores) module.

        Registers any module that has all of the parameters listed during initialization.

        Parameters
        ----------
        module : torch.nn.Module
            The module object to hook gatherers onto.
        module_name : str
            Name of the module, module's information will be passed to visualizer under this name.
        inspector_state
            Inspector state, forwarded to the created gatherers.
        """
        self._register_module(module, module_name, inspector_state)

    def register_non_leaf_module(self, module: Module, module_name: str, inspector_state):
        """
        Registers (or ignores) module.

        Registers any module that has all of the parameters listed during initialization.

        Parameters
        ----------
        module : torch.nn.Module
            The module object to hook gatherers onto.
        module_name : str
            Name of the module, module's information will be passed to visualizer under this name.
        inspector_state
            Inspector state, forwarded to the created gatherers.
        """
        self._register_module(module, module_name, inspector_state)

    def _register_module(self, module: Module, module_name: str, inspector_state):
        """
        Generic function called from :meth:`register_non_leaf_module` and :meth:`register_leaf_module`.

        The module is skipped unless every studied parameter exists on it and is not ``None``.
        For each parameter a gatherer is created: an optimizer-step gatherer when an optimizer
        was provided, otherwise a call-driven gatherer triggered by :meth:`inspect_update`.

        Parameters
        ----------
        module : torch.nn.Module
            The module object to hook gatherers onto.
        module_name : str
            Name of the module, module's information will be passed to visualizer under this name.
        inspector_state
            Inspector state, forwarded to the created gatherers.
        """
        if not all(getattr(module, parameter_name, None) is not None for parameter_name in self._preprocessors):
            return
        for parameter, preprocessor in self._preprocessors.items():
            if self.optimizer is not None:
                gatherer = OptimizerStepParameterGatherer(
                    optimizer=self.optimizer,
                    parameter=parameter,
                    module=module,
                    preprocessors=[preprocessor],
                    name=module_name,
                    inspector_state=inspector_state,
                )
            else:
                gatherer = CallParameterGatherer(
                    parameter=parameter,
                    module=module,
                    preprocessors=[preprocessor],
                    name=module_name,
                    inspector_state=inspector_state,
                )
            self._gatherers.append(gatherer)

    def inspect_update(self):
        """
        Utility function to collect updates.

        Used primarily to collect data if no optimizer is provided.
        Calls every gatherer without any arguments, regardless of its underlying type.
        """
        for gatherer in self._gatherers:
            gatherer()

    def detach_from_module(self):
        """
        Detaches lens from module.

        Detaches every gatherer, sets the optimizer to ``None`` and clears aggregated data.
        """
        for gatherer in self._gatherers:
            gatherer.detach()
        self._gatherers = []
        self.optimizer = None
        self._reset_aggregated_data()

    def register_foreign_preprocessor(self, ext_ppr: AbstractPreprocessor, inspector_state):
        """Does not interact with foreign preprocessor."""

    def introduce_tags(self, vizualizer: AbstractVisualizer):
        """
        Introduces lens's plots to visualizer.

        For every parameter listed during initialization creates a small numerical plot
        '#PARAMETER_NAME Update Norm' and, if correlation is computed, a small numerical plot
        '#PARAMETER_NAME Update Correlation' with y-limits fixed to [-1, 1].

        Parameters
        ----------
        vizualizer : AbstractVisualizer
            A visualizer object to pass tag attributes to.
        """
        for parameter_name in self._preprocessors:
            vizualizer.register_tags(
                self._norm_tag(parameter_name),
                TagAttributes(logy=self._log_scale, big_plot=False, annotate=True, type=TagType.NUMERICAL),
            )
            if self._compute_correlation:
                vizualizer.register_tags(
                    self._correlation_tag(parameter_name),
                    # Correlation is bounded by [-1, 1]; limits are given in ascending order
                    # (the previous (1, -1) would invert a matplotlib-style y-axis).
                    TagAttributes(logy=False, big_plot=False, annotate=True, type=TagType.NUMERICAL, ylim=(-1, 1)),
                )

    def _aggregate(self, values) -> tuple[dict[str, float], dict[tuple[str, str], tuple[float, float]]]:
        """Aggregate one module's collected values into line points and range bands."""
        line = {method: extract_point(values, method) for method in self._line_aggregation}
        band = {parse_range_name(method): extract_range(values, method) for method in self._range_aggregation}
        return line, band

    def finalize_epoch(self):
        """
        Finalizes computations done through epoch.

        Aggregates parameter updates' norms and optionally correlation according to
        ``line_aggregation`` and ``range_aggregation``. The results are rebuilt from scratch
        every epoch and stored per module, in reverse of the order in which the preprocessor
        reports modules.
        """
        for parameter_name, preprocessor in self._preprocessors.items():
            line_norm: OrderedDict[str, dict[str, float]] = OrderedDict()
            range_norm: OrderedDict[str, dict[tuple[str, str], tuple[float, float]]] = OrderedDict()
            line_corr: OrderedDict[str, dict[str, float]] = OrderedDict()
            range_corr: OrderedDict[str, dict[tuple[str, str], tuple[float, float]]] = OrderedDict()

            for module_name, value in preprocessor.value.items():
                if self._compute_correlation:
                    # With correlation enabled, each module's value unpacks into (norms, correlations).
                    norm, corr = value
                    line_corr[module_name], range_corr[module_name] = self._aggregate(corr)
                else:
                    norm = value
                line_norm[module_name], range_norm[module_name] = self._aggregate(norm)

            # Reverse exactly once per epoch. Building fresh containers keeps the module
            # order stable between epochs, and dicts already handed to the visualizer are never mutated.
            self._line_data[parameter_name] = OrderedDict(reversed(line_norm.items()))
            self._range_data[parameter_name] = OrderedDict(reversed(range_norm.items()))
            if self._compute_correlation:
                self._line_correlation_data[parameter_name] = OrderedDict(reversed(line_corr.items()))
                self._range_correlation_data[parameter_name] = OrderedDict(reversed(range_corr.items()))

    def vizualize(self, vizualizer: AbstractVisualizer, epoch: int):
        """
        Passes computed data to visualizer.

        For '#PARAMETER_NAME Update Norm' passes two per-module dictionaries: line data and
        range data. They may look something like this (range keys come from
        ``parse_range_name``; ``('min', 'max')`` for ``'min-max'`` is presumed here). ::

            line data:  OrderedDict([('lin2', {'mean': 0.6}), ('lin1', {'mean': 0.8})])
            range data: OrderedDict([('lin2', {('min', 'max'): (0.3, 0.7)}),
                                     ('lin1', {('min', 'max'): (0.2, 0.9)})])

        Update correlation dictionaries look the same.

        Parameters
        ----------
        vizualizer : AbstractVisualizer
            The visualizer object responsible for drawing plots.
        epoch : int
            Computation's epoch number.
        """
        for parameter_name in self._preprocessors:
            vizualizer.plot_numerical_values(
                epoch,
                self._norm_tag(parameter_name),
                self._line_data[parameter_name],
                self._range_data[parameter_name],
            )
            if self._compute_correlation:
                vizualizer.plot_numerical_values(
                    epoch,
                    self._correlation_tag(parameter_name),
                    self._line_correlation_data[parameter_name],
                    self._range_correlation_data[parameter_name],
                )

    def reset_epoch(self):
        """
        Resets inner state.

        Resets every preprocessor, discarding the values collected during the last epoch.
        Aggregated data is replaced on the next :meth:`finalize_epoch`.
        """
        for preprocessor in self._preprocessors.values():
            preprocessor.reset()