Source code for analogvnn.nn.activation.Gaussian
import math
from typing import Optional

import torch
from torch import Tensor

from analogvnn.nn.activation.Activation import Activation

__all__ = ['Gaussian', 'GeLU']
class Gaussian(Activation):
    """Implements the Gaussian activation function."""

    @staticmethod
    def forward(x: Tensor) -> Tensor:
        """Forward pass of the Gaussian activation function.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the output tensor.
        """
        return torch.exp(-torch.pow(x, 2))
    def backward(self, grad_output: Optional[Tensor]) -> Optional[Tensor]:
        """Backward pass of the Gaussian activation function.

        Args:
            grad_output (Optional[Tensor]): the gradient of the output tensor.

        Returns:
            Optional[Tensor]: the gradient of the input tensor.
        """
        x = self.inputs
        # d/dx exp(-x^2) = -2x * exp(-x^2)
        grad = -2 * x * torch.exp(-torch.pow(x, 2))
        return grad_output * grad
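
A quick sanity check (illustrative, not part of the library source): the hand-written Gaussian gradient above can be verified against PyTorch autograd. The tensor name x below is a hypothetical example input.

import torch

# Hypothetical example input; any differentiable tensor works.
x = torch.linspace(-3, 3, steps=7, requires_grad=True)
y = torch.exp(-torch.pow(x, 2))  # same expression as Gaussian.forward
y.sum().backward()               # autograd computes d/dx exp(-x^2)

xd = x.detach()
manual_grad = -2 * xd * torch.exp(-torch.pow(xd, 2))
assert torch.allclose(x.grad, manual_grad)  # -2x * exp(-x^2) matches autograd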
class GeLU(Activation):
    """Implements the Gaussian error linear unit (GeLU) activation function."""

    @staticmethod
    def forward(x: Tensor) -> Tensor:
        """Forward pass of the Gaussian error linear unit (GeLU) activation function.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the output tensor.
        """
        return (1 / 2) * x * (1 + torch.erf(x / math.sqrt(2)))
    def backward(self, grad_output: Optional[Tensor]) -> Optional[Tensor]:
        """Backward pass of the Gaussian error linear unit (GeLU) activation function.

        Args:
            grad_output (Optional[Tensor]): the gradient of the output tensor.

        Returns:
            Optional[Tensor]: the gradient of the input tensor.
        """
        x = self.inputs
        # d/dx GeLU(x) = Phi(x) + x * phi(x), where Phi is the standard normal
        # CDF and phi the standard normal PDF. Inside the (1/2)(...) factor the
        # PDF term carries sqrt(2 / pi), not sqrt(2 * pi).
        grad = (1 / 2) * (
            (1 + torch.erf(x / math.sqrt(2)))
            + x * (math.sqrt(2 / math.pi) * torch.exp(-torch.pow(x, 2) / 2))
        )
        return grad_output * grad
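
A minimal sketch (assumed usage, not from the library) checking the GeLU gradient above against autograd, and the forward pass against torch.nn.functional.gelu, which implements the same exact erf formulation by default. The tensor name x is again a hypothetical example input.

import math
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7, requires_grad=True)
y = (1 / 2) * x * (1 + torch.erf(x / math.sqrt(2)))  # same expression as GeLU.forward
y.sum().backward()                                   # autograd gradient of GeLU

xd = x.detach()
manual_grad = (1 / 2) * (
    (1 + torch.erf(xd / math.sqrt(2)))
    + xd * (math.sqrt(2 / math.pi) * torch.exp(-torch.pow(xd, 2) / 2))
)
assert torch.allclose(x.grad, manual_grad)    # Phi(x) + x*phi(x) matches autograd
assert torch.allclose(y.detach(), F.gelu(xd)) # matches PyTorch's exact GELU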