Source code for monitorch.numerical.geometry_computation
import copy
import numpy as np
import torch
from .running_value import Accumulator, RunningMeanVar, start_sync_rmv_or_error, finish_sync_rmv_or_error
[docs]
class GeometryComputation:
    """
    An object used for geometry calculation.

    Keeps track of norm (RMS) and optionally correlation between consecutive tensors.
    Let :math:`X_1, ..., X_n` be sequence of tensors passed to :meth:`update`.
    Keeps track of :math:`n_k = ||X_k||_2` or
    :math:`n_k' = \\frac{1}{\\sqrt{\\dim(X_k)}}||X_k||_2` (if ``normalize=True``).
    Optionally :math:`r_k = \\frac{X_{k-1} \\cdot X_k}{||X_{k-1}||_2 ||X_k||_2 + \\epsilon}`;
    the denominator always uses the raw (non-RMS) norms, regardless of ``normalize``.
    The first correlation is computed against a zero "previous" tensor with norm 1.0,
    so it is recorded as 0 whenever the denominator is non-zero.

    Tensors are detached from the autograd graph before any computation, so
    monitoring never builds or retains a graph and accepts non-leaf tensors.

    Parameters
    ----------
    inplace : bool
        Flag indicating use of :class:`RunningMeanVar` (instead of a plain list) to accumulate values
    normalize : bool
        Flag indicating if norm of tensor should be computed as RMS
    correlation : bool
        Flag indicating computation of correlation between consecutive tensors, increases memory consumption by storing copy of previous tensor
    eps : float
        Constant used for numerical stability when computing correlation
    """

    def __init__(self, inplace: bool, normalize: bool, correlation: bool, eps: float):
        # Accumulator for per-update norms: running statistics or a full history list.
        self.norm = RunningMeanVar() if inplace else list()
        self.eps = eps
        self.normalize = normalize
        self.correlation = correlation
        if correlation:
            # Placeholders for the "previous" tensor before the first update:
            # a zero numerator and unit norm make the first correlation 0.
            self.prev_norm = 1.0
            self.prev_value = 0
            self.product = RunningMeanVar() if inplace else list()

    def update(self, X: torch.Tensor) -> None:
        """
        Performs an update step on norm and optionally correlation.

        Parameters
        ----------
        X : torch.Tensor
            Tensor to use for update. It is detached internally; the caller's
            tensor and its autograd graph are left untouched.
        """
        # Statistics are pure observations: detaching avoids building a graph for
        # the norm / dot product and stops ``prev_norm`` / ``prev_value`` from keeping
        # X's graph alive until the next update. It also makes copying possible for
        # non-leaf tensors, which ``copy.deepcopy`` rejects.
        X = X.detach()
        new_norm = torch.linalg.vector_norm(X)
        raw_norm = new_norm.item()
        self.norm.append((raw_norm / np.sqrt(X.numel())) if self.normalize else raw_norm)
        if self.correlation:
            new_product = torch.sum(self.prev_value * X) / (self.prev_norm * new_norm + self.eps)
            self.product.append(new_product.item())
            self.prev_norm = new_norm
            # clone(): an independent copy, so later in-place edits of the caller's
            # tensor (which shares storage with the detached view) do not leak in.
            self.prev_value = X.clone()

    @property
    def value(self) -> "Accumulator | tuple[Accumulator, Accumulator]":
        """
        Accumulated values (either list or :class:`RunningMeanVar`)

        Returns
        -------
        tuple[`norms`, `products`]
            if ``correlation=True``
        `norms`
            if ``correlation=False``
        """
        if self.correlation:
            return self.norm, self.product
        return self.norm

    def start_sync(self, dst_rank: int) -> None:
        """
        Starts synchronizing the world to gather data at dst_rank.

        Delegates to ``start_sync_rmv_or_error`` for each accumulator, which may
        raise an error depending on the accumulator type.

        Parameters
        ----------
        dst_rank : int
            Rank of gathering destination.
        """
        # NOTE(review): the previous docstring said this raises "if inplace", while the
        # helper's name suggests only RunningMeanVar (inplace=True) is synced — confirm.
        start_sync_rmv_or_error(self.norm, dst_rank)
        if self.correlation:
            start_sync_rmv_or_error(self.product, dst_rank)

    def finish_sync(self) -> None:
        """
        Finishes the synchronization started by :meth:`start_sync`.

        Delegates to ``finish_sync_rmv_or_error`` for each accumulator, which may
        raise an error depending on the accumulator type. Takes no arguments; the
        destination rank was given to :meth:`start_sync`.
        """
        finish_sync_rmv_or_error(self.norm)
        if self.correlation:
            finish_sync_rmv_or_error(self.product)