# Source code for analogvnn.nn.activation.Identity

from typing import Optional

from torch import Tensor

from analogvnn.nn.activation.Activation import Activation

__all__ = ['Identity']


class Identity(Activation):
    """Implements the identity activation function.

    The forward pass returns its input unchanged and the backward pass
    returns the incoming gradient unchanged.

    Attributes:
        name (Optional[str]): optional label for this activation; when it is
            not None it is reported by :meth:`extra_repr`.
    """

    name: Optional[str]

    def __init__(self, name: Optional[str] = None):
        """Initialize the identity activation function.

        Args:
            name (Optional[str]): optional label for this activation.
                Defaults to None.
        """
        super().__init__()
        self.name = name

    def extra_repr(self) -> str:
        """Extra __repr__ of the identity activation function.

        Returns:
            str: ``name=<name>`` when a name was set, otherwise an empty
            string.
        """
        # Guard clause: an unnamed layer contributes nothing to the repr.
        if self.name is None:
            return ''
        return f'name={self.name}'

    @staticmethod
    def forward(x: Tensor) -> Tensor:
        """Forward pass of the identity activation function.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the very same tensor object that was passed in.
        """
        return x

    def backward(self, grad_output: Optional[Tensor]) -> Optional[Tensor]:
        """Backward pass of the identity activation function.

        Args:
            grad_output (Optional[Tensor]): the gradient of the output tensor.

        Returns:
            Optional[Tensor]: ``grad_output`` itself, returned unchanged
            (including None when None is given).
        """
        return grad_output