"""
A submodule implementing functions to get class of abstract ``torch.nn.Module``.
Examples
--------
>>> import torch.nn as nn
>>> from monitorch.lens.module_distinction import isactivation, isconv
>>> isactivation(nn.ReLU())
True
>>> isactivation(nn.Dropout())
False
>>> isconv(nn.BatchNorm1d(10))
False
>>> isconv(nn.Conv2d(1, 1, 1))
True
"""
from torch.nn.modules import (
CELU,
ELU,
GELU,
GLU,
SELU,
AlphaDropout,
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
Bilinear,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
Dropout,
Dropout1d,
Dropout2d,
Dropout3d,
FeatureAlphaDropout,
Hardshrink,
Hardsigmoid,
Hardswish,
Hardtanh,
Identity,
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LazyBatchNorm1d,
LazyBatchNorm2d,
LazyBatchNorm3d,
LazyConv1d,
LazyConv2d,
LazyConv3d,
LazyConvTranspose1d,
LazyConvTranspose2d,
LazyConvTranspose3d,
LazyInstanceNorm1d,
LazyInstanceNorm2d,
LazyInstanceNorm3d,
LazyLinear,
LeakyReLU,
Linear,
LogSigmoid,
LogSoftmax,
Mish,
Module,
# MultiheadAttention,
PReLU,
ReLU,
ReLU6,
RReLU,
Sigmoid,
SiLU,
Softmax,
Softmax2d,
Softmin,
Softplus,
Softshrink,
Softsign,
SyncBatchNorm,
Tanh,
Tanhshrink,
Threshold,
)
_DROPOUT = {
AlphaDropout,
Dropout,
Dropout1d,
Dropout2d,
Dropout3d,
FeatureAlphaDropout,
}
_ACTIVATION = {
CELU,
ELU,
GELU,
GLU,
Hardshrink,
Hardsigmoid,
Hardswish,
Hardtanh,
LeakyReLU,
LogSigmoid,
LogSoftmax,
Mish,
# MultiheadAttention,
PReLU,
ReLU,
ReLU6,
RReLU,
SELU,
Sigmoid,
SiLU,
Softmax,
Softmax2d,
Softmin,
Softplus,
Softshrink,
Softsign,
Tanh,
Tanhshrink,
Threshold,
}
_LINEAR = {
Bilinear,
Identity,
LazyLinear,
Linear,
}
_CONV = {
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
LazyConv1d,
LazyConv2d,
LazyConv3d,
LazyConvTranspose1d,
LazyConvTranspose2d,
LazyConvTranspose3d,
}
_NORMALIZATION = {
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
LazyBatchNorm1d,
LazyBatchNorm2d,
LazyBatchNorm3d,
SyncBatchNorm,
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LazyInstanceNorm1d,
LazyInstanceNorm2d,
LazyInstanceNorm3d,
}
[docs]
def isactivation(module: Module) -> bool:
"""
Checks if provided module is an activation function module.
Returns ``False`` for ``torch.nn.MultiheadAttention``.
Parameters
----------
module : torch.nn.Module
Module to be checked
"""
return module.__class__ in _ACTIVATION
[docs]
def isdropout(module: Module) -> bool:
"""
Checks if provided module is a dropout module.
Parameters
----------
module : torch.nn.Module
Module to be checked
"""
return module.__class__ in _DROPOUT
[docs]
def islinear(module: Module) -> bool:
"""
Checks if provided module is a linear non-convolution module.
Parameters
----------
module : torch.nn.Module
Module to be checked
"""
return module.__class__ in _LINEAR
[docs]
def isconv(module: Module) -> bool:
"""
Checks if provided module is a convolution module.
Parameters
----------
module : torch.nn.Module
Module to be checked
"""
return module.__class__ in _CONV
[docs]
def isnormalization(module: Module) -> bool:
"""
Checks if provided module is a normalization module.
Parameters
----------
module : torch.nn.Module
Module to be checked
"""
return module.__class__ in _NORMALIZATION