analogvnn.fn.train
Module Contents
Functions
| train(model, train_loader[, epoch, test_run]) | Train the model on the train set. |
- analogvnn.fn.train.train(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, epoch: Optional[int] = None, test_run: bool = False) → Tuple[float, float]
Train the model on the train set.
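The following is a minimal plain-PyTorch sketch of what a one-epoch training loop with this contract (a model plus a `DataLoader` in, a `(loss, accuracy)` pair out) can look like. It is not analogvnn's implementation. The function name `train_sketch` and its explicit `optimizer` and `loss_fn` arguments are hypothetical, added because this page does not say where `train` obtains its optimizer and loss. Using `epoch` only for a log line and stopping after one batch when `test_run` is set are also assumptions.

```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,  # hypothetical: passed explicitly in this sketch
    loss_fn: nn.Module,                # hypothetical: passed explicitly in this sketch
    epoch: Optional[int] = None,
    test_run: bool = False,
) -> Tuple[float, float]:
    """Hypothetical one-epoch loop returning (mean per-sample loss, accuracy)."""
    model.train()  # training-mode behaviour (dropout, batch-norm statistics)
    total_loss, correct, seen = 0.0, 0, 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        batch = target.size(0)
        total_loss += loss.item() * batch  # weight by batch size for a per-sample mean
        correct += (output.argmax(dim=1) == target).sum().item()
        seen += batch
        if test_run:
            break  # assumption: a test run stops after the first batch
    if epoch is not None:
        # assumption: epoch is only used for logging here
        print(f"epoch {epoch}: loss={total_loss / seen:.4f} acc={correct / seen:.4f}")
    return total_loss / seen, correct / seen


# Usage on random data: 64 samples, 10 features, 3 classes.
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))
loader = DataLoader(dataset, batch_size=16)
net = nn.Linear(10, 3)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
loss, acc = train_sketch(net, loader, opt, nn.CrossEntropyLoss(), epoch=1)
```

Accumulating `loss.item() * batch` and dividing by the number of samples seen keeps the reported loss a true per-sample mean even when the last batch is smaller than the others.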
- Parameters:
model (torch.nn.Module) – the model to train.
train_loader (DataLoader) – the DataLoader over the training set.
epoch (Optional[int]) – the current epoch (defaults to None).
test_run (bool) – whether this is a test run (defaults to False).
- Returns:
the loss and accuracy of the model on the train set, as a (loss, accuracy) tuple.
- Return type: