analogvnn.fn.train

Module Contents

Functions

train(model, train_loader, epoch=None, test_run=False) → Tuple[float, float]

Train the model on the train set.

analogvnn.fn.train.train(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, epoch: Optional[int] = None, test_run: bool = False) → Tuple[float, float]

Train the model on the train set.

Parameters:
  • model (torch.nn.Module) – the model to train.

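  • train_loader (DataLoader) – the data loader that yields batches from the train set.

  • epoch (Optional[int]) – the current epoch number; defaults to None.

  • test_run (bool) – whether this is a test run; defaults to False.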

Returns:

the loss and accuracy of the model on the train set, as a (loss, accuracy) pair.

Return type:

Tuple[float, float]
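The documented contract (a model and a DataLoader in, a (loss, accuracy) pair out) can be illustrated with a minimal plain-PyTorch sketch. This is not AnalogVNN's implementation. The function name `train_sketch`, its explicit `optimizer` and `loss_fn` parameters, the assumption that `test_run=True` processes only a single batch, and the use of `epoch` purely for logging are all hypothetical choices made for illustration. The sketch also assumes a classification model that outputs class logits.

```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,  # hypothetical: passed explicitly here
    loss_fn: nn.Module,                # hypothetical: passed explicitly here
    epoch: Optional[int] = None,
    test_run: bool = False,
) -> Tuple[float, float]:
    """Hypothetical single pass over train_loader returning (loss, accuracy)."""
    model.train()  # switch layers such as dropout/batch norm to training mode
    total_loss, correct, seen = 0.0, 0, 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size so the final mean is per-sample.
        total_loss += loss.item() * data.size(0)
        correct += (output.argmax(dim=1) == target).sum().item()
        seen += data.size(0)
        if test_run:
            break  # assumption: a test run stops after the first batch
    mean_loss, accuracy = total_loss / seen, correct / seen
    if epoch is not None:
        # assumption: epoch is only used for progress logging in this sketch
        print(f"epoch {epoch}: loss={mean_loss:.4f} accuracy={accuracy:.4f}")
    return mean_loss, accuracy


# Usage on a small synthetic binary-classification problem.
torch.manual_seed(0)
x = torch.randn(64, 4)
y = (x[:, 0] > 0).long()
loader = DataLoader(TensorDataset(x, y), batch_size=16)
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss, acc = train_sketch(model, loader, opt, nn.CrossEntropyLoss(), epoch=1)
```

Returning the sample-weighted mean loss, rather than the last batch's loss, keeps the reported value comparable across loaders whose final batch is smaller than the rest.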