Source code for analogvnn.graph.BackwardGraph

from __future__ import annotations

import uuid
from typing import Dict, Union, Callable, List, Tuple

import networkx as nx
import torch
from torch import nn

from analogvnn.backward.BackwardModule import BackwardModule
from analogvnn.graph.AccumulateGrad import AccumulateGrad
from analogvnn.graph.AcyclicDirectedGraph import AcyclicDirectedGraph
from analogvnn.graph.ArgsKwargs import ArgsKwargs, InputOutput, ArgsKwargsOutput
from analogvnn.graph.GraphEnum import GraphEnum, GRAPH_NODE_TYPE
from analogvnn.nn.module.Layer import Layer
from analogvnn.utils.common_types import TENSORS

__all__ = ['BackwardGraph']


class BackwardGraph(AcyclicDirectedGraph):
    """The backward graph.

    It is built by inverting a forward graph (``from_forward``) and executed by
    ``__call__`` / ``calculate`` to propagate gradients from ``OUTPUT`` back
    towards ``INPUT``.
    """

    def __call__(self, gradient: TENSORS = None) -> ArgsKwargsOutput:
        """Backward pass through the backward graph.

        Args:
            gradient (TENSORS): gradient of the loss function w.r.t. the output of the forward graph,
                one entry per loss tensor. ``None`` or an empty sequence means no explicit gradient is
                given for any loss tensor; a single tensor is treated as a one-element sequence.

        Returns:
            ArgsKwargsOutput: gradient of the inputs function w.r.t. loss. When
            ``graph_state.use_autograd_graph`` is set, this is instead the tuple of
            ``Tensor.backward`` return values (one ``None`` per loss tensor).
        """
        self.graph_state.ready_for_backward(exception=True)

        # The registered loss is consumed by this pass so a later call does not reuse it.
        loss = self.graph_state.loss
        self.graph_state.set_loss(None)

        if loss is None:
            # No explicit loss was registered: differentiate the forward outputs themselves.
            loss = self.graph_state.outputs.args

        if not isinstance(loss, (tuple, list)):
            loss = [loss]

        # Normalise ``gradient`` to a sequence: ``len(None)`` would raise, and a bare tensor
        # must be used as a single gradient rather than indexed along its first dimension.
        if gradient is None:
            gradient = ()
        elif isinstance(gradient, torch.Tensor):
            gradient = (gradient,)

        if len(gradient) == 0:
            gradient = (None,) * len(loss)

        if self.graph_state.use_autograd_graph:
            result = tuple(v.backward(gradient=gradient[i]) for i, v in enumerate(loss))
        else:
            # d(loss)/d(forward outputs) seeds this graph at OUTPUT; the forward autograd graph
            # is retained because per-module gradients are computed from it afterwards.
            grad_outputs = torch.autograd.grad(
                outputs=loss,
                inputs=self.graph_state.outputs.args,
                grad_outputs=gradient,
                retain_graph=True,
            )
            result = self.calculate(*grad_outputs)

        return result

    def compile(self, is_static: bool = True):
        """Compile the graph.

        Args:
            is_static (bool): If True, the graph is not changing during runtime and will be cached.

        Returns:
            The result of ``AcyclicDirectedGraph.compile`` (documented upstream as ``self``).

        Raises:
            ValueError: If no forward pass has been performed yet (``OUTPUT`` is not a node of the graph).
        """
        if not self.graph.has_node(self.OUTPUT):
            raise ValueError("OUTPUT doesn't exist in the forward graph. Please preform a forward pass first.")

        return super().compile(is_static=is_static)

    def from_forward(self, forward_graph: Union[AcyclicDirectedGraph, nx.DiGraph]) -> BackwardGraph:  # noqa: C901
        """Create a backward graph from inverting forward graph.

        Every edge is reversed and its ``in_*`` / ``out_*`` selectors are swapped. For each node
        other than ``OUTPUT``:

        * if it has exactly one incoming edge and that edge is a pure pass-through (``in_arg``,
          ``in_kwarg``, ``out_arg`` and ``out_kwarg`` all ``True``), the edge is copied as-is;
        * otherwise an ``AccumulateGrad`` node is inserted in front of it. Each incoming edge is
          re-routed into that node under a fresh unique kwarg name (an edge carrying both args
          and kwargs is split into one args edge and one kwargs edge), and the node feeds the
          original node with a pass-through edge.

        Args:
            forward_graph (Union[AcyclicDirectedGraph, nx.DiGraph]): The forward graph.

        Returns:
            BackwardGraph: self.
        """
        if isinstance(forward_graph, AcyclicDirectedGraph):
            forward_graph = forward_graph.graph

        graph = forward_graph.reverse(copy=True)
        for _, _, attr in graph.edges(data=True):
            attr['in_arg'], attr['out_arg'] = attr['out_arg'], attr['in_arg']
            attr['in_kwarg'], attr['out_kwarg'] = attr['out_kwarg'], attr['in_kwarg']
            attr['label'] = AcyclicDirectedGraph._create_edge_label(**attr)

        pass_through_keys = ('in_arg', 'in_kwarg', 'out_arg', 'out_kwarg')
        new_graph = nx.MultiDiGraph()
        for v in graph.nodes():
            if v == self.OUTPUT:
                continue

            all_predecessors = list(graph.predecessors(v))

            if len(all_predecessors) == 1:
                only_edges = graph.get_edge_data(all_predecessors[0], v)
                if len(only_edges) == 1:
                    attr = next(iter(only_edges.values()))
                    # ``is True`` (not ``==``) so an integer index such as 1 is not mistaken for True.
                    if all(attr[key] is True for key in pass_through_keys):
                        new_graph.add_edge(all_predecessors[0], v, **attr)
                        continue

            akc = AccumulateGrad(v)
            for u in all_predecessors:
                for attr in graph.get_edge_data(u, v).values():
                    if attr['in_arg'] is None or attr['in_kwarg'] is None:
                        # The edge carries only args or only kwargs: route it through one connection.
                        uuid_str = uuid.uuid4().hex
                        new_graph.add_edge(u, akc, **{
                            'in_arg': attr['in_arg'],
                            'in_kwarg': attr['in_kwarg'],
                            'out_arg': None,
                            'out_kwarg': uuid_str,
                            'real_label': ' '.join(attr['label'].split(' ')[:-1] + ['{' + uuid_str + '}']),
                            'label': attr['label'],
                        })
                        akc.input_output_connections[uuid_str] = {
                            **attr,
                            'from': u,
                        }
                        continue

                    # The edge carries both args and kwargs: split it into an args-only connection ...
                    uuid_str = uuid.uuid4().hex
                    new_graph.add_edge(u, akc, **{
                        'in_arg': True,
                        'in_kwarg': None,
                        'out_arg': None,
                        'out_kwarg': uuid_str,
                        'real_label': '[] -> {' + uuid_str + '}',
                        'label': '[] -> []',
                    })
                    akc.input_output_connections[uuid_str] = {
                        **attr,
                        'in_kwarg': None,
                        'out_kwarg': None,
                        'from': u,
                    }

                    # ... and a kwargs-only connection.
                    uuid_str = uuid.uuid4().hex
                    new_graph.add_edge(u, akc, **{
                        'in_arg': None,
                        'in_kwarg': True,
                        'out_arg': None,
                        'out_kwarg': uuid_str,
                        'real_label': '{} -> {' + uuid_str + '}',
                        'label': '{} -> {}',
                    })
                    akc.input_output_connections[uuid_str] = {
                        **attr,
                        'in_arg': None,
                        'out_arg': None,
                        'from': u,
                    }

            new_graph.add_edge(akc, v, **{
                'in_arg': True,
                'in_kwarg': True,
                'out_arg': True,
                'out_kwarg': True,
                'len': 0,
            })

        for v in graph.nodes():
            # OUTPUT is only present in ``new_graph`` if it had at least one edge.
            if v in new_graph:
                new_graph.nodes[v]['fillcolor'] = 'lightblue'

        self.graph = new_graph
        return self

    @torch.no_grad()
    def calculate(self, *args, **kwargs) -> ArgsKwargsOutput:
        """Calculate the gradient of the whole graph w.r.t. loss.

        The recorded forward inputs/outputs are cleared after a successful pass.

        Args:
            *args: The gradients args of outputs.
            **kwargs: The gradients kwargs of outputs.

        Returns:
            ArgsKwargsOutput: The gradient of the inputs function w.r.t. loss, or ``None`` when the
            backward pass never reaches ``INPUT``.

        Raises:
            ValueError: If no forward pass has been performed yet.
        """
        if self.graph_state.forward_input_output_graph is None:
            raise ValueError('No forward pass has been performed yet. Please preform a forward pass first.')

        input_output_graph = self._pass(self.OUTPUT, *args, **kwargs)
        self.graph_state.forward_input_output_graph = None

        if self.INPUT not in input_output_graph:
            return None
        return ArgsKwargs.from_args_kwargs_object(input_output_graph[self.INPUT].outputs)

    def _pass(self, from_node: GRAPH_NODE_TYPE, *args, **kwargs) -> Dict[GRAPH_NODE_TYPE, InputOutput]:
        """Perform the backward pass through the graph.

        Args:
            from_node (GRAPH_NODE_TYPE): The node to start the backward pass from.
            *args: The gradients args of outputs.
            **kwargs: The gradients kwargs of outputs.

        Returns:
            Dict[GRAPH_NODE_TYPE, InputOutput]: The input and output gradients of each node.
        """
        # Presumably an execution-ordered list of (node, predecessors) pairs; see _create_static_sub_graph.
        static_graph: List[Tuple[GRAPH_NODE_TYPE, List[GRAPH_NODE_TYPE]]] = self._create_static_sub_graph(from_node)
        input_output_graph: Dict[GRAPH_NODE_TYPE, InputOutput] = {
            from_node: InputOutput(inputs=ArgsKwargs(args=args, kwargs=kwargs)),
        }

        for module, predecessors in static_graph:
            if module != from_node:
                inputs = self.parse_args_kwargs(input_output_graph, module, predecessors)
                input_output_graph[module] = InputOutput(inputs=inputs)

            if isinstance(module, GraphEnum):
                # Marker nodes (GraphEnum members) forward their incoming gradients unchanged.
                input_output_graph[module].outputs = input_output_graph[module].inputs
                continue

            outputs = self._calculate_gradients(module, input_output_graph[module])
            input_output_graph[module].outputs = ArgsKwargs.to_args_kwargs_object(outputs)

        return input_output_graph

    def _calculate_gradients(  # noqa: C901
            self,
            module: Union[AccumulateGrad, Layer, BackwardModule, Callable],
            grad_outputs: InputOutput
    ) -> ArgsKwargs:
        """Calculate the gradient of a module w.r.t. outputs of the module using the output's gradients.

        Dispatch order:

        1. no incoming gradients -> empty ``ArgsKwargs``;
        2. ``AccumulateGrad`` -> ``module.grad(...)``;
        3. a ``Layer`` whose ``backward_function`` is a ``BackwardModule``, or a ``BackwardModule``
           itself -> ``_call_impl_backward``;
        4. anything else -> ``torch.autograd.grad`` over the inputs/outputs recorded for the module
           during the forward pass. Gradients for the module's parameters (if it is an
           ``nn.Module``) are accumulated into their ``.grad``.

        Inputs that are not tensors, or that do not require grad, get a ``None`` gradient; outputs
        that are not differentiable tensors contribute nothing.

        Args:
            module (Union[AccumulateGrad, Layer, BackwardModule, Callable]): The module to calculate the gradient of.
            grad_outputs (InputOutput): The gradients of the output of the module.

        Returns:
            ArgsKwargs: The input gradients of the module.

        Raises:
            ValueError: If case 4 applies but no forward inputs/outputs were recorded for ``module``.
        """
        forward_graph = self.graph_state.forward_input_output_graph
        if module in forward_graph:
            module_inputs_outputs = forward_graph[module]
        else:
            module_inputs_outputs = None

        if grad_outputs.inputs.is_empty():
            return ArgsKwargs()

        if isinstance(module, AccumulateGrad):
            return module.grad(
                grad_outputs_args_kwargs=grad_outputs.inputs,
                forward_input_output_graph=forward_graph,
            )

        if isinstance(module, Layer) and isinstance(module.backward_function, BackwardModule):
            module = module.backward_function

        if isinstance(module, BackwardModule):
            grad_inputs = module._call_impl_backward(*grad_outputs.inputs.args, **grad_outputs.inputs.kwargs)
            return ArgsKwargs.to_args_kwargs_object(grad_inputs)

        if module_inputs_outputs is None:
            raise ValueError(
                f'No forward inputs/outputs were recorded for {module!r}; cannot compute its gradients.'
            )

        forward_args = list(module_inputs_outputs.inputs.args)
        forward_kwargs = module_inputs_outputs.inputs.kwargs
        inputs = forward_args + list(forward_kwargs.values())

        if len(inputs) == 0:
            # The module received nothing in the forward pass, so there is nothing to differentiate.
            return ArgsKwargs(args=[], kwargs={})

        # Pair every recorded forward output with the gradient flowing into it.
        outputs = []
        outputs_grads = []
        for i, output in enumerate(module_inputs_outputs.outputs.args):
            outputs.append(output)
            outputs_grads.append(grad_outputs.inputs.args[i])
        for key, output in module_inputs_outputs.outputs.kwargs.items():
            outputs.append(output)
            outputs_grads.append(grad_outputs.inputs.kwargs[key])

        # Outputs outside the autograd graph cannot carry gradient back to any input, and
        # torch.autograd.grad raises if given one, so they are dropped here.
        differentiable = [
            (output, grad)
            for output, grad in zip(outputs, outputs_grads)
            if isinstance(output, torch.Tensor) and output.requires_grad
        ]

        module_parameters = list(module.parameters()) if isinstance(module, nn.Module) else []
        inputs += module_parameters

        # Keyed by id() because tensors are not reliably hashable/comparable by value.
        grad_dict = {id(i): None for i in inputs}
        filtered_inputs = [i for i in inputs if isinstance(i, torch.Tensor) and i.requires_grad]

        # torch.autograd.grad rejects empty ``inputs``; with nothing to differentiate every grad stays None.
        if filtered_inputs and differentiable:
            out_grads = torch.autograd.grad(
                outputs=[output for output, _ in differentiable],
                inputs=filtered_inputs,
                grad_outputs=[grad for _, grad in differentiable],
                retain_graph=True,
                allow_unused=True,
            )
            for tensor, grad in zip(filtered_inputs, out_grads):
                grad_dict[id(tensor)] = grad

        # Accumulate parameter gradients the way autograd would (set on first use, add afterwards).
        for parameter in module_parameters:
            grad = grad_dict[id(parameter)]
            if grad is None:
                continue
            if parameter.grad is None:
                parameter.grad = grad
            else:
                parameter.grad += grad

        return ArgsKwargs(
            args=[grad_dict[id(i)] for i in forward_args],
            kwargs={key: grad_dict[id(value)] for key, value in forward_kwargs.items()},
        )