Source code for analogvnn.nn.module.FullSequential
from __future__ import annotations
from typing import Optional
import torch
from analogvnn.nn.module.Sequential import Sequential
__all__ = ['FullSequential']
[docs]class FullSequential(Sequential):
"""A sequential model where backward graph is the reverse of forward graph."""
[docs] def compile(self, device: Optional[torch.device] = None, layer_data: bool = True):
"""Compile the model and add forward and backward graph.
Args:
device (torch.device): The device to run the model on.
layer_data (bool): True if the data of the layers should be compiled.
Returns:
FullSequential: self
"""
arr = [self.graphs.INPUT, *list(self.registered_children()), self.graphs.OUTPUT]
self.graphs.backward_graph.add_connection(*reversed(arr))
return super().compile(device, layer_data)