Source code for analogvnn.nn.normalize.Clamp

from typing import Optional

import torch
from torch import Tensor

from analogvnn.backward.BackwardIdentity import BackwardIdentity
from analogvnn.nn.normalize.Normalize import Normalize

__all__ = ['Clamp', 'Clamp01']


[docs]class Clamp(Normalize, BackwardIdentity): """Implements the clamp normalization function with range [-1, 1].""" @staticmethod
[docs] def forward(x: Tensor): """Forward pass of the clamp normalization function with range [-1, 1]. Args: x (Tensor): the input tensor. Returns: Tensor: the output tensor. """ return torch.clamp(x, min=-1, max=1)
[docs] def backward(self, grad_output: Optional[Tensor]) -> Optional[Tensor]: """Backward pass of the clamp normalization function with range [-1, 1]. Args: grad_output (Optional[Tensor]): the gradient of the output tensor. Returns: Optional[Tensor]: the gradient of the input tensor. """ x = self.inputs grad = ((-1 <= x) * (x <= 1.)).type(torch.float) return grad_output * grad
[docs]class Clamp01(Normalize, BackwardIdentity): """Implements the clamp normalization function with range [0, 1].""" @staticmethod
[docs] def forward(x: Tensor): """Forward pass of the clamp normalization function with range [0, 1]. Args: x (Tensor): the input tensor. Returns: Tensor: the output tensor. """ return torch.clamp(x, min=0, max=1)
[docs] def backward(self, grad_output: Optional[Tensor]) -> Optional[Tensor]: """Backward pass of the clamp normalization function with range [0, 1]. Args: grad_output (Optional[Tensor]): the gradient of the output tensor. Returns: Optional[Tensor]: the gradient of the input tensor. """ x = self.inputs grad = ((0 <= x) * (x <= 1.)).type(torch.float) return grad_output * grad