Source code for analogvnn.nn.normalize.LPNorm

import torch
from torch import nn, Tensor

from analogvnn.backward.BackwardIdentity import BackwardIdentity
from analogvnn.nn.normalize.Normalize import Normalize

__all__ = ['LPNorm', 'LPNormW', 'L1Norm', 'L2Norm', 'L1NormW', 'L2NormW', 'L1NormM', 'L2NormM', 'L1NormWM', 'L2NormWM']


class LPNorm(Normalize, BackwardIdentity):
    """Implements the row-wise Lp normalization function.

    Attributes:
        p (int): the pth power of the Lp norm.
        make_max_1 (bool): if True, the maximum absolute value of the output tensor will be 1.
    """

    __constants__ = ['p', 'make_max_1']
    p: nn.Parameter
    make_max_1: nn.Parameter

    def __init__(self, p: int, make_max_1=False):
        """Initializes the row-wise Lp normalization function.

        Args:
            p (int): the pth power of the Lp norm.
            make_max_1 (bool): if True, the maximum absolute value of the output tensor will be 1.
        """
        super().__init__()
        self.p = nn.Parameter(torch.tensor(p), requires_grad=False)
        self.make_max_1 = nn.Parameter(torch.tensor(make_max_1), requires_grad=False)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of row-wise Lp normalization function.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the output tensor.
        """
        norm = x
        if len(x.shape) > 1:
            norm = torch.flatten(norm, start_dim=1)

        norm = torch.norm(norm, self.p, -1)
        norm = torch.clamp(norm, min=1e-4)
        x = torch.div(x.T, norm).T

        if self.make_max_1:
            x = torch.div(x, torch.max(torch.abs(x)))

        return x
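

# A minimal usage sketch (not part of the original module; the input values are
# illustrative). Row-wise normalization divides each row of a 2D input by that
# row's Lp norm, e.g. for p=2:
#
#     >>> x = torch.tensor([[3.0, 4.0], [6.0, 8.0]])
#     >>> LPNorm(p=2)(x)          # row L2 norms are 5 and 10
#     # roughly [[0.6, 0.8], [0.6, 0.8]]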


class LPNormW(LPNorm):
    """Implements the whole matrix Lp normalization function."""

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of whole matrix Lp normalization function.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the output tensor.
        """
        norm = torch.norm(x, self.p)
        norm = torch.clamp(norm, min=1e-4)
        x = torch.div(x, norm)

        if self.make_max_1:
            x = torch.div(x, torch.max(torch.abs(x)))

        return x
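

# A minimal usage sketch (not part of the original module; the input values are
# illustrative). LPNormW scales the whole tensor by a single norm, e.g. for p=2
# the Frobenius norm:
#
#     >>> x = torch.tensor([[3.0, 4.0]])
#     >>> LPNormW(p=2)(x)         # Frobenius norm is 5
#     # roughly [[0.6, 0.8]]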


class L1Norm(LPNorm):
    """Implements the row-wise L1 normalization function."""

    def __init__(self):
        """Initializes the row-wise L1 normalization function."""
        super().__init__(p=1, make_max_1=False)


class L2Norm(LPNorm):
    """Implements the row-wise L2 normalization function."""

    def __init__(self):
        """Initializes the row-wise L2 normalization function."""
        super().__init__(p=2, make_max_1=False)


class L1NormW(LPNormW):
    """Implements the whole matrix L1 normalization function."""

    def __init__(self):
        """Initializes the whole matrix L1 normalization function."""
        super().__init__(p=1, make_max_1=False)


class L2NormW(LPNormW):
    """Implements the whole matrix L2 normalization function."""

    def __init__(self):
        """Initializes the whole matrix L2 normalization function."""
        super().__init__(p=2, make_max_1=False)


class L1NormM(LPNorm):
    """Implements the row-wise L1 normalization function with maximum absolute value of 1."""

    def __init__(self):
        """Initializes the row-wise L1 normalization function with maximum absolute value of 1."""
        super().__init__(p=1, make_max_1=True)


class L2NormM(LPNorm):
    """Implements the row-wise L2 normalization function with maximum absolute value of 1."""

    def __init__(self):
        """Initializes the row-wise L2 normalization function with maximum absolute value of 1."""
        super().__init__(p=2, make_max_1=True)


class L1NormWM(LPNormW):
    """Implements the whole matrix L1 normalization function with maximum absolute value of 1."""

    def __init__(self):
        """Initializes the whole matrix L1 normalization function with maximum absolute value of 1."""
        super().__init__(p=1, make_max_1=True)


class L2NormWM(LPNormW):
    """Implements the whole matrix L2 normalization function with maximum absolute value of 1."""

    def __init__(self):
        """Initializes the whole matrix L2 normalization function with maximum absolute value of 1."""
        super().__init__(p=2, make_max_1=True)
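

if __name__ == '__main__':
    # A minimal, hedged demonstration of the convenience classes (not part of
    # the original module; the tensor below is illustrative only).
    x = torch.tensor([[1.0, 3.0]])

    # Row-wise L1: the row L1 norm is 4, so the result is roughly [[0.25, 0.75]].
    print(L1Norm()(x))

    # Whole-matrix L2: the Frobenius norm is sqrt(10), so the result is
    # roughly [[0.3162, 0.9487]].
    print(L2NormW()(x))

    # Row-wise L1 with make_max_1=True: [[0.25, 0.75]] rescaled by its largest
    # absolute entry (0.75) gives roughly [[0.3333, 1.0]].
    print(L1NormM()(x))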