Source code for analogvnn.nn.module.Sequential

from __future__ import annotations

from typing import TypeVar, Optional

import torch
from torch import nn

from analogvnn.nn.module.Model import Model

# Type variable restricted to ``nn.Module`` subclasses (not referenced in the visible part of this module).
T = TypeVar('T', bound=nn.Module)

# Public names exported by ``from <this module> import *``.
__all__ = ['Sequential']


class Sequential(Model, nn.Sequential):
    """Sequential model that compiles itself on first use.

    Combines ``Model`` with ``torch.nn.Sequential``. When the model is
    compiled, the registered child modules are handed, in order and framed by
    the graph's ``INPUT`` and ``OUTPUT`` markers, to the forward graph's
    ``add_connection``.
    """

    def __call__(self, *args, **kwargs):
        """Run the model, compiling it first if that has not happened yet.

        Args:
            *args: Positional inputs forwarded to the parent ``__call__``.
            **kwargs: Keyword inputs forwarded to the parent ``__call__``.

        Returns:
            The value returned by the parent ``__call__``.
        """
        # Lazy compilation: the forward graph is only built once the model is
        # actually invoked without an explicit ``compile()`` beforehand.
        if not self._compiled:
            self.compile()

        result = super().__call__(*args, **kwargs)
        return result

    def compile(self, device: Optional[torch.device] = None, layer_data: bool = True):
        """Register the forward-graph connection and compile the model.

        Args:
            device (Optional[torch.device]): Device forwarded to the parent
                ``compile``; defaults to ``None``.
            layer_data (bool): Flag forwarded unchanged to the parent
                ``compile``.

        Returns:
            The value returned by the parent ``compile``.
        """
        # Assemble the chain INPUT -> child_0 -> ... -> child_n -> OUTPUT.
        nodes = [self.graphs.INPUT]
        nodes.extend(self.registered_children())
        nodes.append(self.graphs.OUTPUT)

        # NOTE(review): every call adds the connection again; confirm that
        # ``add_connection`` tolerates repeated ``compile()`` calls.
        self.graphs.forward_graph.add_connection(*nodes)

        return super().compile(device, layer_data)

    def add_sequence(self, *args):
        """Append the given modules to this container.

        The modules are passed as a single tuple to ``self.extend``; they
        become part of the forward graph when ``compile`` is next run.

        Args:
            *args (nn.Module): The modules to append, in order.

        Returns:
            The value returned by ``self.extend``.
        """
        modules = args
        return self.extend(modules)