# Source code for analogvnn.graph.ForwardGraph

from __future__ import annotations

from typing import Dict

import torch
from torch import Tensor

from analogvnn.graph.AcyclicDirectedGraph import AcyclicDirectedGraph
from analogvnn.graph.ArgsKwargs import ArgsKwargs, InputOutput, ArgsKwargsOutput
from analogvnn.graph.GraphEnum import GraphEnum
from analogvnn.utils.common_types import TENSORS

__all__ = ['ForwardGraph']


class ForwardGraph(AcyclicDirectedGraph):
    """The forward graph.

    Runs the modules registered in the graph, starting from ``INPUT``, and
    records the inputs and outputs of every visited node. When training, the
    recorded per-node inputs/outputs are stored on
    ``graph_state.forward_input_output_graph``.
    """

    def __call__(self, inputs: TENSORS, is_training: bool) -> ArgsKwargsOutput:
        """Forward pass through the forward graph.

        Args:
            inputs (TENSORS): Input to the graph.
            is_training (bool): Is training or not.

        Returns:
            ArgsKwargsOutput: Output of the graph.

        Raises:
            Whatever ``graph_state.ready_for_forward(exception=True)`` raises
            when the graph is not ready for a forward pass.
        """
        self.graph_state.ready_for_forward(exception=True)
        outputs = self.calculate(inputs, is_training)
        return outputs

    def compile(self, is_static: bool = True):
        """Compile the graph.

        Args:
            is_static (bool): If True, the graph is not changing during runtime and will be cached.
                Forwarded unchanged to the parent ``compile``.

        Returns:
            ForwardGraph: self (as returned by the parent ``compile``).

        Raises:
            ValueError: If the ``INPUT`` or ``OUTPUT`` node is missing, i.e. no
                forward pass has been performed yet.
        """
        if not self.graph.has_node(self.INPUT):
            raise ValueError("INPUT doesn't exist in the forward graph. Please perform a forward pass first.")

        if not self.graph.has_node(self.OUTPUT):
            raise ValueError("OUTPUT doesn't exist in the forward graph. Please perform a forward pass first.")

        return super().compile(is_static=is_static)

    def calculate(
            self,
            inputs: TENSORS,
            is_training: bool = True,
            **kwargs
    ) -> ArgsKwargsOutput:
        """Calculate the output of the graph.

        Args:
            inputs (TENSORS): Input to the graph. A single (non tuple/list)
                value is wrapped into a 1-tuple.
            is_training (bool): Is training or not.
            **kwargs: Additional arguments (accepted for interface
                compatibility; not used by this method).

        Returns:
            ArgsKwargsOutput: Output of the graph (the outputs recorded for the
            ``OUTPUT`` node).

        Raises:
            ValueError: If training without the autograd graph and none of the
                inputs is a ``torch.Tensor``.
        """
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs,)

        if not self.graph_state.use_autograd_graph and is_training:
            # Without the autograd graph, ``_pass`` detaches every node's inputs,
            # so the graph inputs themselves are marked as requiring grad here.
            # NOTE: this mutates the caller's tensors in place (as before).
            has_tensor = False
            for value in inputs:
                if not isinstance(value, torch.Tensor):
                    continue
                has_tensor = True
                # Non-leaf tensors already require grad and PyTorch rejects any
                # assignment to their flag; integer/bool tensors cannot require
                # grad at all. Only flip the flag when it is both needed and allowed.
                if not value.requires_grad and self._can_require_grad(value):
                    value.requires_grad = True

            if not has_tensor:
                raise ValueError('At least one input must be a tensor.')

        input_output_graph = self._pass(self.INPUT, *inputs)
        if is_training:
            self.graph_state.forward_input_output_graph = input_output_graph

        outputs = input_output_graph[self.OUTPUT].outputs
        return ArgsKwargs.from_args_kwargs_object(outputs)

    def _pass(self, from_node: GraphEnum, *inputs: Tensor) -> Dict[GraphEnum, InputOutput]:
        """Perform the forward pass through the graph.

        Args:
            from_node (GraphEnum): The node to start the forward pass from.
            *inputs (Tensor): Input to the graph (becomes the inputs of ``from_node``).

        Returns:
            Dict[GraphEnum, InputOutput]: The input and output of each visited node.
        """
        static_graph = self._create_static_sub_graph(from_node)
        input_output_graph = {
            from_node: InputOutput(inputs=ArgsKwargs(args=[*inputs]))
        }

        for module, predecessors in static_graph:
            if module != from_node:
                node_inputs = self.parse_args_kwargs(input_output_graph, module, predecessors)
                if not self.graph_state.use_autograd_graph:
                    # Cut the autograd link between nodes: each node receives
                    # fresh detached tensors instead of its predecessors' outputs.
                    node_inputs.args = [
                        self._detach_tensor(arg) if isinstance(arg, torch.Tensor) else arg
                        for arg in node_inputs.args
                    ]
                    node_inputs.kwargs = {
                        key: self._detach_tensor(val) if isinstance(val, torch.Tensor) else val
                        for key, val in node_inputs.kwargs.items()
                    }
                input_output_graph[module] = InputOutput(inputs=node_inputs)

            # Marker nodes (e.g. INPUT/OUTPUT) pass their inputs straight through.
            if isinstance(module, GraphEnum):
                input_output_graph[module].outputs = input_output_graph[module].inputs
                continue

            outputs = module(
                *input_output_graph[module].inputs.args,
                **input_output_graph[module].inputs.kwargs
            )
            input_output_graph[module].outputs = ArgsKwargs.to_args_kwargs_object(outputs)

        return input_output_graph

    @staticmethod
    def _can_require_grad(tensor: torch.Tensor) -> bool:
        """Return True if ``tensor``'s dtype supports gradients (floating point or complex)."""
        return tensor.is_floating_point() or tensor.is_complex()

    @staticmethod
    def _detach_tensor(tensor: torch.Tensor) -> torch.Tensor:
        """Detach the tensor from the autograd graph.

        The result is a new leaf tensor sharing storage with ``tensor``. It is
        marked as requiring grad when its dtype supports gradients (floating
        point or complex); other dtypes (e.g. integer index tensors) are
        returned detached without requiring grad, because PyTorch raises an
        error when ``requires_grad=True`` is set on them.

        Args:
            tensor (torch.Tensor): Tensor to detach.

        Returns:
            torch.Tensor: Detached tensor.
        """
        tensor: torch.Tensor = tensor.detach()
        if ForwardGraph._can_require_grad(tensor):
            tensor.requires_grad = True
        return tensor