Source code for analogvnn.fn.train
from typing import Optional, Tuple

from torch import nn
from torch.utils.data import DataLoader

__all__ = ['train']
def train(
    model: nn.Module,
    train_loader: DataLoader,
    epoch: Optional[int] = None,
    test_run: bool = False
) -> Tuple[float, float]:
"""Train the model on the train set.
Args:
model (torch.nn.Module): the model to train.
train_loader (DataLoader): the train set.
epoch (int): the current epoch.
test_run (bool): is it a test run.
Returns:
tuple: the loss and accuracy of the model on the train set.
"""
    model.train()
    total_loss = 0.0
    total_accuracy = 0
    total_size = 0
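
    # size in samples when a DataLoader is given; otherwise the iterable's own length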
    if isinstance(train_loader, DataLoader):
        # noinspection PyTypeChecker
        dataset_size = len(train_loader.dataset)
    else:
        dataset_size = len(train_loader)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(model.device), target.to(model.device)

        # zero the parameter gradients
        model.zero_grad()
        model.optimizer.zero_grad()

        # forward + backward + optimize
        output = model(data)
        loss, accuracy = model.loss(output, target)
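        # the backward pass is driven by the model itself; loss.backward() is never called directly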
        model.backward()
        model.optimizer.step()

        # print statistics
        total_loss += loss.item() * len(data)
        total_accuracy += accuracy * len(data)
        total_size += len(data)
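
        # log progress roughly five times per epoch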
        print_mod = int(dataset_size / (len(data) * 5))
        if print_mod > 0 and batch_idx % print_mod == 0 and batch_idx > 0:
            print(
                f'Train Epoch:'
                f' {((epoch + 1) if epoch is not None else "")}'
                f' [{batch_idx * len(data)}/{dataset_size} ({100. * batch_idx / len(train_loader):.0f}%)]'
                f'\tLoss: {total_loss / total_size:.6f}'
                f'\tAccuracy: {total_accuracy / total_size * 100:.2f}%'
            )

        if test_run:
            break

    total_loss /= total_size
    total_accuracy /= total_size

    return total_loss, total_accuracy
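

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library source). `train` only
# assumes that the model exposes `.device`, `.optimizer`, a
# `.loss(output, target)` method returning a (loss, accuracy) pair, and a
# `.backward()` method, as used above. `SimpleClassifier` below is a
# hypothetical stand-in that implements this interface on top of a plain
# nn.Module, just to make the expected contract concrete.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import TensorDataset


class SimpleClassifier(nn.Module):
    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)
        self.device = torch.device('cpu')
        self.optimizer = optim.SGD(self.parameters(), lr=0.1)
        self._last_loss: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    def loss(self, output: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, float]:
        # return the (loss, accuracy) pair that train() unpacks
        self._last_loss = F.cross_entropy(output, target)
        accuracy = (output.argmax(dim=1) == target).float().mean().item()
        return self._last_loss, accuracy

    def backward(self) -> None:
        # train() calls this instead of calling loss.backward() itself
        if self._last_loss is not None:
            self._last_loss.backward()


if __name__ == '__main__':
    inputs = torch.randn(64, 10)
    targets = torch.randint(0, 2, (64,))
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)
    loss, accuracy = train(SimpleClassifier(10, 2), loader, epoch=0, test_run=True)
    print(f'loss={loss:.4f}, accuracy={accuracy:.2%}')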