Source code for analogvnn.backward.BackwardUsingForward
from abc import ABC
from torch import Tensor
from analogvnn.backward.BackwardModule import BackwardModule
from analogvnn.utils.common_types import TENSORS
__all__ = ['BackwardUsingForward']
class BackwardUsingForward(BackwardModule, ABC):
    """Backward module that reuses the layer's forward function as its backward pass.

    The output gradients handed to ``backward`` are forwarded, unchanged, to
    ``self._layer.forward``; whatever that call returns is used as the input
    gradients.
    """

    def backward(self, *grad_output: Tensor, **grad_output_kwarg: Tensor) -> TENSORS:
        """Compute the input gradients by passing the output gradients through the forward function.

        Args:
            *grad_output (Tensor): The gradients of the output of the layer.
            **grad_output_kwarg (Tensor): The gradients of the output of the layer,
                supplied by keyword.

        Returns:
            TENSORS: The gradients of the input of the layer, i.e. the value
            returned by ``self._layer.forward`` for the given gradients.
        """
        # NOTE(review): relies on ``self._layer`` having been assigned elsewhere
        # (presumably by BackwardModule) before backward runs -- confirm.
        return self._layer.forward(*grad_output, **grad_output_kwarg)