Source code for analogvnn.nn.module.Model

from __future__ import annotations

import typing
from typing import Optional, Tuple, Set, Iterator

import torch
from torch import optim, Tensor, nn
from torch.utils.data import DataLoader

from analogvnn.backward.BackwardModule import BackwardModule
from analogvnn.fn.test import test
from analogvnn.fn.train import train
from analogvnn.graph.BackwardGraph import BackwardGraph
from analogvnn.graph.ForwardGraph import ForwardGraph
from analogvnn.graph.ModelGraph import ModelGraph
from analogvnn.nn.module.Layer import Layer
from analogvnn.utils.common_types import TENSORS, TENSOR_CALLABLE
from analogvnn.utils.is_cpu_cuda import is_cpu_cuda

if typing.TYPE_CHECKING:
    from analogvnn.utils.TensorboardModelLog import TensorboardModelLog

__all__ = ['Model']


class Model(Layer, BackwardModule):
    """Base class for analog neural network models.

    Attributes:
        _compiled (bool): True if the model is compiled.
        tensorboard (TensorboardModelLog): The tensorboard logger of the model.
        graphs (ModelGraph): The graph of the model.
        forward_graph (ForwardGraph): The forward graph of the model.
        backward_graph (BackwardGraph): The backward graph of the model.
        optimizer (optim.Optimizer): The optimizer of the model.
        loss_function (Optional[TENSOR_CALLABLE]): The loss function of the model.
        accuracy_function (Optional[TENSOR_CALLABLE]): The accuracy function of the model.
        device (torch.device): The device of the model.
    """
    __constants__ = ['device']

    _compiled: bool
    tensorboard: Optional[TensorboardModelLog]
    graphs: ModelGraph
    forward_graph: ForwardGraph
    backward_graph: BackwardGraph
    optimizer: Optional[optim.Optimizer]
    loss_function: Optional[TENSOR_CALLABLE]
    accuracy_function: Optional[TENSOR_CALLABLE]
    device: torch.device
    def __init__(self, tensorboard_log_dir=None, device=is_cpu_cuda.device):
        """Create a new model.

        Args:
            tensorboard_log_dir (str): The log directory of the tensorboard logger.
            device (torch.device): The device to run the model on.
        """
        super().__init__()

        self._compiled = False

        self.tensorboard = None
        if tensorboard_log_dir is not None:
            self.create_tensorboard(tensorboard_log_dir)

        self.graphs = ModelGraph()
        self.forward_graph = self.graphs.forward_graph
        self.backward_graph = self.graphs.backward_graph

        self.optimizer = None
        self.loss_function = None
        self.accuracy_function = None
        self.device = device
    def __call__(self, *args, **kwargs):
        """Call the model.

        Args:
            *args: The arguments of the model.
            **kwargs: The keyword arguments of the model.

        Returns:
            TENSORS: The output of the model.

        Raises:
            RuntimeError: if the model is not compiled.
        """
        if not self._compiled:
            raise RuntimeError('Model is not compiled yet.')

        return super().__call__(*args, **kwargs)
    @property
    def use_autograd_graph(self):
        """Is the autograd graph used for the model.

        Returns:
            bool: If True, the autograd graph is used to calculate the gradients.
        """
        return self.graphs.use_autograd_graph

    @use_autograd_graph.setter
    def use_autograd_graph(self, use_autograd_graph: bool):
        """Set if the autograd graph is used for the model.

        Args:
            use_autograd_graph (bool): If True, the autograd graph is used to calculate the gradients.
        """
        self.graphs.use_autograd_graph = use_autograd_graph
    def named_registered_children(
            self,
            memo: Optional[Set[nn.Module]] = None,
    ) -> Iterator[Tuple[str, nn.Module]]:
        """Returns an iterator over registered modules under self.

        Args:
            memo: a memo to store the set of modules already added to the result

        Yields:
            (str, nn.Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once.
        """
        if memo is None:
            memo = set()

        memo.add(self.optimizer)
        memo.add(self.loss_function)
        memo.add(self.accuracy_function)
        return super().named_registered_children(memo=memo)
    def compile(self, device: Optional[torch.device] = None, layer_data: bool = True):
        """Compile the model.

        Args:
            device (torch.device): The device to run the model on.
            layer_data (bool): If True, the layer data is logged.

        Returns:
            Model: The compiled model.
        """
        if device is not None:
            self.device = device

        self.graphs.compile()

        for i in self.modules():
            if isinstance(i, Layer) and i != self:
                i.use_autograd_graph = self.use_autograd_graph

        self.to(device=self.device)
        self._compiled = True

        if self.tensorboard is not None:
            self.tensorboard.on_compile(layer_data=layer_data)

        return self
    def forward(self, *inputs: Tensor) -> TENSORS:
        """Forward pass of the model.

        Args:
            *inputs (Tensor): The inputs of the model.

        Returns:
            TENSORS: The output of the model.
        """
        return self.graphs.forward_graph(inputs, self.training)
    @torch.no_grad()
    def backward(self, *inputs: Tensor) -> TENSORS:
        """Backward pass of the model.

        Args:
            *inputs (Tensor): The inputs of the model.

        Returns:
            TENSORS: The output of the backward pass of the model.
        """
        return self.graphs.backward_graph(inputs)
    def loss(self, output: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """Calculate the loss of the model.

        Args:
            output (Tensor): The output of the model.
            target (Tensor): The target of the model.

        Returns:
            Tuple[Tensor, Tensor]: The loss and the accuracy of the model.

        Raises:
            ValueError: if loss_function is None.
        """
        if self.loss_function is None:
            raise ValueError('loss_function is None')

        loss_result = self.loss_function(output, target)
        if self.training:
            self.graphs.set_loss(loss_result)

        accuracy_result = None
        if self.accuracy_function is not None:
            accuracy_result = self.accuracy_function(output, target)

        return loss_result, accuracy_result
    def train_on(self, train_loader: DataLoader, epoch: int = None, *args, **kwargs) -> Tuple[float, float]:
        """Train the model on the train_loader.

        Args:
            train_loader (DataLoader): The train loader of the model.
            epoch (int): The epoch of the model.
            *args: The arguments of the train function.
            **kwargs: The keyword arguments of the train function.

        Returns:
            Tuple[float, float]: The loss and the accuracy of the model.

        Raises:
            RuntimeError: if the model is not compiled.
        """
        if self._compiled is False:
            raise RuntimeError('Model is not compiled')

        train_loss, train_accuracy = train(self, train_loader, epoch, *args, **kwargs)

        if self.tensorboard is not None:
            self.tensorboard.add_graph(train_loader)
            self.tensorboard.register_training(epoch, train_loss, train_accuracy)

        return train_loss, train_accuracy
    def test_on(self, test_loader: DataLoader, epoch: int = None, *args, **kwargs) -> Tuple[float, float]:
        """Test the model on the test_loader.

        Args:
            test_loader (DataLoader): The test loader of the model.
            epoch (int): The epoch of the model.
            *args: The arguments of the test function.
            **kwargs: The keyword arguments of the test function.

        Returns:
            Tuple[float, float]: The loss and the accuracy of the model.

        Raises:
            RuntimeError: if the model is not compiled.
        """
        if self._compiled is False:
            raise RuntimeError('Model is not compiled')

        test_loss, test_accuracy = test(self, test_loader, *args, **kwargs)

        if self.tensorboard is not None:
            self.tensorboard.add_graph(test_loader)
            self.tensorboard.register_testing(epoch, test_loss, test_accuracy)

        return test_loss, test_accuracy
    def fit(
            self,
            train_loader: DataLoader,
            test_loader: DataLoader,
            epoch: int = None
    ) -> Tuple[float, float, float, float]:
        """Fit the model on the train_loader and test the model on the test_loader.

        Args:
            train_loader (DataLoader): The train loader of the model.
            test_loader (DataLoader): The test loader of the model.
            epoch (int): The epoch of the model.

        Returns:
            Tuple[float, float, float, float]: The train loss, the train accuracy, the test loss
            and the test accuracy of the model.
        """
        train_loss, train_accuracy = self.train_on(train_loader=train_loader, epoch=epoch)
        test_loss, test_accuracy = self.test_on(test_loader=test_loader, epoch=epoch)
        return train_loss, train_accuracy, test_loss, test_accuracy
    def create_tensorboard(self, log_dir: str) -> TensorboardModelLog:
        """Create a tensorboard logger for the model.

        Args:
            log_dir (str): The log directory of the tensorboard.

        Returns:
            TensorboardModelLog: The tensorboard logger of the model.

        Raises:
            ImportError: if tensorboard (https://www.tensorflow.org/) is not installed.
        """
        try:
            from analogvnn.utils.TensorboardModelLog import TensorboardModelLog
        except ImportError as e:
            raise ImportError('requires tensorboard https://www.tensorflow.org/') from e

        self.tensorboard = TensorboardModelLog(self, log_dir=log_dir)
        self.subscribe_tensorboard(self.tensorboard)
        return self.tensorboard
    def subscribe_tensorboard(self, tensorboard: TensorboardModelLog):
        """Subscribe the model to the tensorboard logger.

        Args:
            tensorboard (TensorboardModelLog): The tensorboard logger of the model.

        Returns:
            Model: self.
        """
        self.tensorboard = tensorboard
        if self._compiled is True:
            self.tensorboard.on_compile()

        return self
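
Usage sketch (illustrative, not part of the module source): a minimal example of how a Model subclass might be compiled and trained with the methods defined above. The subclass name, its single linear layer, the plain forward override, the accuracy lambda, and the synthetic dataset are hypothetical; how a real analogvnn model defines its layers and forward/backward graphs depends on other parts of the library not shown in this file.

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    from analogvnn.nn.module.Model import Model


    class MyModel(Model):
        """Hypothetical subclass, for illustration only."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 2)  # illustrative layer

        def forward(self, x):
            # Assumption: a plain forward override for this sketch; real models
            # may instead rely on the model's forward graph.
            return self.fc(x)


    # Synthetic tensors, purely for demonstration.
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8)

    model = MyModel()
    model.loss_function = nn.CrossEntropyLoss()  # used by Model.loss()
    model.accuracy_function = lambda out, tgt: (out.argmax(dim=1) == tgt).float().mean()
    model.optimizer = optim.Adam(model.parameters())
    # model.create_tensorboard('./runs')  # optional; requires tensorboard to be installed

    model.compile()  # calling the model before compile() raises RuntimeError
    train_loss, train_accuracy = model.train_on(loader, epoch=0)
    test_loss, test_accuracy = model.test_on(loader, epoch=0)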