analogvnn.backward.BackwardFunction

Module Contents

Classes

BackwardFunction

The backward module that uses a function to compute the backward gradient.

class analogvnn.backward.BackwardFunction.BackwardFunction(backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE, layer: torch.nn.Module = None)

Bases: analogvnn.backward.BackwardModule.BackwardModule, abc.ABC

The backward module that uses a function to compute the backward gradient.

Variables:

_backward_function (TENSOR_CALLABLE) – The function used to compute the backward gradient.
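The class stores a callable and defers to it during the backward pass. The sketch below illustrates that delegation under stated assumptions: `BackwardFunctionSketch` is a hypothetical, simplified stand-in rather than analogvnn's implementation, plain Python floats take the place of `torch.Tensor` so it runs without PyTorch, and the `layer` argument is only stored, since how the real base class uses it is not described here.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in for TENSOR_CALLABLE: any callable mapping gradients to gradients.
GradFn = Callable[..., Any]


class BackwardFunctionSketch:
    """Hypothetical, simplified mirror of the documented BackwardFunction interface."""

    def __init__(self, backward_function: Optional[GradFn], layer: Any = None):
        # In this sketch `layer` is only stored; the real base class is not modelled.
        self.layer = layer
        self._backward_function = backward_function

    @property
    def backward_function(self) -> Optional[GradFn]:
        # Read-only view of the stored function.
        return self._backward_function

    def backward(self, *grad_output: Any, **grad_output_kwarg: Any) -> Any:
        # Delegate to the stored function, or fail as the documentation describes.
        if self._backward_function is None:
            raise NotImplementedError("backward_function is not set")
        return self._backward_function(*grad_output, **grad_output_kwarg)


# Example: a backward pass that halves every incoming gradient.
halve = BackwardFunctionSketch(lambda g: g * 0.5)
print(halve.backward(4.0))  # 2.0
```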

property backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE

The function used to compute the backward gradient.

Returns:

The function used to compute the backward gradient.

Return type:

TENSOR_CALLABLE

_backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE

The function used to compute the backward gradient.

set_backward_function(backward_function: analogvnn.utils.common_types.TENSOR_CALLABLE) → BackwardFunction

Sets the function used to compute the backward gradient.

Parameters:

backward_function (TENSOR_CALLABLE) – The function used to compute the backward gradient.

Returns:

self.

Return type:

BackwardFunction
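Because set_backward_function returns self, setting the function and using the module can be chained. A minimal sketch, again using a hypothetical `BackwardFunctionSketch` class (not analogvnn's code) with plain floats instead of tensors and keyword gradients omitted for brevity:

```python
from typing import Any, Callable, Optional

GradFn = Callable[..., Any]


class BackwardFunctionSketch:
    """Hypothetical minimal class showing the setter's fluent return value."""

    def __init__(self, backward_function: Optional[GradFn] = None):
        self._backward_function = backward_function

    def set_backward_function(self, backward_function: GradFn) -> "BackwardFunctionSketch":
        # Replace the stored function and return self so calls can be chained.
        self._backward_function = backward_function
        return self

    def backward(self, *grad_output: Any) -> Any:
        if self._backward_function is None:
            raise NotImplementedError("backward_function is not set")
        return self._backward_function(*grad_output)


# Because the setter returns self, configuration and use fit in one expression.
module = BackwardFunctionSketch()
result = module.set_backward_function(lambda g: -g).backward(3.0)
print(result)  # -3.0
```

Returning self is a common fluent-interface choice: it lets the module be configured inline, for example directly after construction.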

backward(*grad_output: torch.Tensor, **grad_output_kwarg: torch.Tensor) → analogvnn.utils.common_types.TENSORS

Computes the gradients with respect to the layer's inputs from the gradients with respect to its outputs, using the backward function.

Parameters:
  • *grad_output (Tensor) – The positional gradients with respect to the outputs of the layer.

  • **grad_output_kwarg (Tensor) – The keyword gradients with respect to the outputs of the layer.

Returns:

The gradients with respect to the inputs of the layer.

Return type:

TENSORS

Raises:

NotImplementedError – If the backward function is not set.
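The sketch below shows how positional and keyword gradients might be forwarded unchanged to the backward function, and the documented NotImplementedError when no function is set. `BackwardFunctionSketch` and `clip_grads` are hypothetical names, clamping gradients to [-1, 1] is only an illustrative choice of backward function, and floats stand in for tensors; this is not analogvnn's implementation.

```python
from typing import Any, Callable, Optional, Tuple


class BackwardFunctionSketch:
    """Hypothetical stand-in; only backward() is modelled here."""

    def __init__(self, backward_function: Optional[Callable[..., Any]] = None):
        self._backward_function = backward_function

    def backward(self, *grad_output: Any, **grad_output_kwarg: Any) -> Any:
        # Positional and keyword gradients are forwarded unchanged.
        if self._backward_function is None:
            raise NotImplementedError("backward_function is not set")
        return self._backward_function(*grad_output, **grad_output_kwarg)


def clip_grads(*grads: float, **named_grads: float) -> Tuple[float, ...]:
    # Hypothetical backward function: clamp every gradient to [-1, 1],
    # positional gradients first, then keyword gradients in call order.
    values = list(grads) + list(named_grads.values())
    return tuple(max(-1.0, min(1.0, g)) for g in values)


clipper = BackwardFunctionSketch(clip_grads)
print(clipper.backward(0.5, -3.0, aux=2.0))  # (0.5, -1.0, 1.0)

try:
    BackwardFunctionSketch().backward(1.0)
except NotImplementedError as err:
    print("raised:", err)  # raised: backward_function is not set
```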