Source code for analogvnn.fn.test

from typing import Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader

# Public API of this module: only ``test`` is exported by ``from ... import *``.
__all__ = ['test']


[docs]def test(model: nn.Module, test_loader: DataLoader, test_run: bool = False) -> Tuple[float, float]: """Test the model on the test set. Args: model (torch.nn.Module): the model to test. test_loader (DataLoader): the test set. test_run (bool): is it a test run. Returns: tuple: the loss and accuracy of the model on the test set. """ model.eval() total_loss = 0 total_accuracy = 0 total_size = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(model.device), target.to(model.device) output = model(data) loss, accuracy = model.loss(output, target) total_loss += loss.item() * len(data) total_accuracy += accuracy * len(data) total_size += len(data) if test_run: break total_loss /= total_size total_accuracy /= total_size return total_loss, total_accuracy