# Source code for monitorch.numerical.geometry_computation
import copy
import numpy as np
import torch
from .running_value import Accumulator, RunningMeanVar
# [docs]  (documentation-link marker left over from the rendered source page)
class GeometryComputation:
    """
    An object used for geometry calculation.

    Keeps track of the norm (optionally RMS) of every tensor passed to
    :meth:`update` and, optionally, of the correlation between consecutive
    tensors.

    Let :math:`X_1, ..., X_n` be the sequence of tensors passed to :meth:`update`.
    Records :math:`n_k = ||X_k||_2` or
    :math:`n_k' = \\frac{1}{\\sqrt{\\dim(X_k)}}||X_k||_2` (if ``normalize=True``).
    Optionally records
    :math:`r_k = \\frac{X_{k-1} \\cdot X_k}{||X_{k-1}||_2 ||X_k||_2 + \\epsilon}`;
    the denominator always uses the unnormalized norms. For the first update
    the predecessor is taken to be 0 with norm 1.

    Parameters
    ----------
    inplace : bool
        If true, values are accumulated in :class:`RunningMeanVar` objects,
        otherwise in plain lists.
    normalize : bool
        Flag indicating if the norm of a tensor should be computed as RMS.
    correlation : bool
        Flag indicating computation of correlation between consecutive
        tensors; increases memory consumption by storing a copy of the
        previous tensor.
    eps : float
        Constant added to the correlation denominator for numerical stability.
    """

    def __init__(self, inplace: bool, normalize: bool, correlation: bool, eps: float):
        self.norm = RunningMeanVar() if inplace else list()
        self.eps = eps
        self.normalize = normalize
        self.correlation = correlation
        if correlation:
            # Neutral predecessor for the very first update.
            self.prev_norm = 1.0
            self.prev_value = 0
            self.product = RunningMeanVar() if inplace else list()

    def update(self, X: torch.Tensor):
        """
        Performs an update step on norm and optionally correlation.

        The tensor is detached before use, so no autograd graph is built or
        retained by this object and ``X`` itself is left untouched.

        Parameters
        ----------
        X : torch.Tensor
            Tensor to use for update
        """
        # Detach first: the statistics must not extend the autograd graph, and
        # copy.deepcopy() refuses non-leaf tensors that require grad.
        X = X.detach()

        new_norm = torch.linalg.vector_norm(X)
        norm_value = new_norm.item()
        if self.normalize:
            # max(..., 1) keeps the divisor non-zero for an empty tensor.
            norm_value = norm_value / np.sqrt(max(X.numel(), 1))
        self.norm.append(norm_value)

        if not self.correlation:
            return

        new_product = torch.sum(self.prev_value * X) / (self.prev_norm * new_norm + self.eps)
        self.product.append(new_product.item())

        self.prev_norm = new_norm
        # clone() gives an independent copy, so later in-place changes to the
        # caller's tensor cannot alter the stored predecessor.
        self.prev_value = X.clone()

    # The annotation is quoted so that it is not evaluated at class creation.
    @property
    def value(self) -> "Accumulator | tuple[Accumulator, Accumulator]":
        """
        Accumulated values (either list or :class:`RunningMeanVar`)

        Returns
        -------
        tuple[`norms`, `products`]
            if ``correlation=True``
        `norms`
            if ``correlation=False``
        """
        if self.correlation:
            return self.norm, self.product
        return self.norm