Source code for analogvnn.backward.BackwardIdentity

from abc import ABC

from torch import Tensor

from analogvnn.backward.BackwardModule import BackwardModule
from analogvnn.utils.common_types import TENSORS

__all__ = ['BackwardIdentity']


[docs]class BackwardIdentity(BackwardModule, ABC): """The backward module that returns the output gradients as the input gradients."""
[docs] def backward(self, *grad_output: Tensor, **grad_output_kwarg: Tensor) -> TENSORS: """Returns the output gradients as the input gradients. Args: *grad_output (Tensor): The gradients of the output of the layer. **grad_output_kwarg (Tensor): The gradients of the output of the layer. Returns: TENSORS: The gradients of the input of the layer. """ if len(grad_output) == 0: return None if len(grad_output) == 1: return grad_output[0] return grad_output