from __future__ import annotations
from typing import Dict
import torch
from torch import Tensor
from analogvnn.graph.AcyclicDirectedGraph import AcyclicDirectedGraph
from analogvnn.graph.ArgsKwargs import ArgsKwargs, InputOutput, ArgsKwargsOutput
from analogvnn.graph.GraphEnum import GraphEnum
from analogvnn.utils.common_types import TENSORS
__all__ = ['ForwardGraph']


class ForwardGraph(AcyclicDirectedGraph):
    """The forward graph."""

    def __call__(self, inputs: TENSORS, is_training: bool) -> ArgsKwargsOutput:
        """Forward pass through the forward graph.

        Args:
            inputs (TENSORS): Input to the graph.
            is_training (bool): Is training or not.

        Returns:
            ArgsKwargsOutput: Output of the graph.
        """
        self.graph_state.ready_for_forward(exception=True)
        outputs = self.calculate(inputs, is_training)
        return outputs
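
    # Usage sketch (illustrative, not part of the class): assuming ``forward_graph``
    # is a compiled ``ForwardGraph`` and ``x`` is an input tensor, a forward pass is::
    #
    #     output = forward_graph(x, is_training=True)
    #
    # which validates the graph state and then delegates to ``calculate``.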

    def compile(self, is_static: bool = True):
        """Compile the graph.

        Args:
            is_static (bool): If True, the graph does not change during runtime and will be cached.

        Returns:
            ForwardGraph: self.

        Raises:
            ValueError: If no forward pass has been performed yet.
        """
        if not self.graph.has_node(self.INPUT):
            raise ValueError("INPUT doesn't exist in the forward graph. Please perform a forward pass first.")

        if not self.graph.has_node(self.OUTPUT):
            raise ValueError("OUTPUT doesn't exist in the forward graph. Please perform a forward pass first.")

        return super().compile(is_static=is_static)
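
    # Intended call order (sketch; ``forward_graph`` is a hypothetical instance):
    # run a forward pass so the INPUT and OUTPUT nodes exist, then freeze the
    # traversal order for repeated use::
    #
    #     forward_graph.compile(is_static=True)  # caches the static sub-graph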

    def calculate(
            self,
            inputs: TENSORS,
            is_training: bool = True,
            **kwargs
    ) -> ArgsKwargsOutput:
        """Calculate the output of the graph.

        Args:
            inputs (TENSORS): Input to the graph.
            is_training (bool): Is training or not.
            **kwargs: Additional arguments.

        Returns:
            ArgsKwargsOutput: Output of the graph.
        """
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs,)

        if not self.graph_state.use_autograd_graph and is_training:
            # Without the autograd graph, gradients are tracked manually, so at
            # least one tensor input must be marked as requiring gradients.
            value_tensor = False
            for i in inputs:
                if not isinstance(i, torch.Tensor):
                    continue

                i.requires_grad = True
                value_tensor = True

            if not value_tensor:
                raise ValueError('At least one input must be a tensor.')

        input_output_graph = self._pass(self.INPUT, *inputs)
        if is_training:
            # Record the per-node inputs and outputs on the shared graph state.
            self.graph_state.forward_input_output_graph = input_output_graph

        outputs = input_output_graph[self.OUTPUT].outputs
        return ArgsKwargs.from_args_kwargs_object(outputs)
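
    # Illustrative sketch (``forward_graph`` is a hypothetical compiled instance):
    # after a training-mode call, the recorded per-node values are available on the
    # graph state::
    #
    #     out = forward_graph.calculate(torch.rand(2, 4), is_training=True)
    #     node_io = forward_graph.graph_state.forward_input_output_graph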

    def _pass(self, from_node: GraphEnum, *inputs: Tensor) -> Dict[GraphEnum, InputOutput]:
        """Perform the forward pass through the graph.

        Args:
            from_node (GraphEnum): The node to start the forward pass from.
            *inputs (Tensor): Input to the graph.

        Returns:
            Dict[GraphEnum, InputOutput]: The input and output of each node.
        """
        static_graph = self._create_static_sub_graph(from_node)
        input_output_graph = {
            from_node: InputOutput(inputs=ArgsKwargs(args=[*inputs]))
        }
        for module, predecessors in static_graph:
            if module != from_node:
                inputs = self.parse_args_kwargs(input_output_graph, module, predecessors)
                if not self.graph_state.use_autograd_graph:
                    # Break the autograd graph between nodes by detaching every
                    # tensor argument before it is handed to the next module.
                    inputs.args = [
                        self._detach_tensor(i) if isinstance(i, torch.Tensor) else i
                        for i in inputs.args
                    ]
                    inputs.kwargs = {
                        k: self._detach_tensor(v) if isinstance(v, torch.Tensor) else v
                        for k, v in inputs.kwargs.items()
                    }
                input_output_graph[module] = InputOutput(inputs=inputs)

            if isinstance(module, GraphEnum):
                # Graph markers (e.g. INPUT and OUTPUT) are pass-through nodes:
                # their outputs are simply their inputs.
                input_output_graph[module].outputs = input_output_graph[module].inputs
                continue

            outputs = module(
                *input_output_graph[module].inputs.args,
                **input_output_graph[module].inputs.kwargs
            )
            input_output_graph[module].outputs = ArgsKwargs.to_args_kwargs_object(outputs)

        return input_output_graph
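
    # Shape of the returned mapping (sketch; names are illustrative): every visited
    # node, including the INPUT and OUTPUT markers, maps to its recorded pair::
    #
    #     {
    #         self.INPUT: InputOutput(inputs=ArgsKwargs(args=[x]), outputs=...),
    #         some_module: InputOutput(inputs=..., outputs=...),
    #         self.OUTPUT: InputOutput(inputs=..., outputs=...),
    #     }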

    @staticmethod
    def _detach_tensor(tensor: torch.Tensor) -> torch.Tensor:
        """Detach the tensor from the autograd graph.

        Args:
            tensor (torch.Tensor): Tensor to detach.

        Returns:
            torch.Tensor: Detached tensor.
        """
        tensor: torch.Tensor = tensor.detach()
        tensor.requires_grad = True
        return tensor
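
    # Behaviour sketch: the detached tensor is a new leaf that shares storage with
    # the original but carries no autograd history, so gradients accumulate at this
    # node instead of flowing back through the preceding operations::
    #
    #     x = torch.ones(3, requires_grad=True)
    #     y = ForwardGraph._detach_tensor(x * 2)
    #     assert y.is_leaf and y.requires_grad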