# Source code for analogvnn.utils.get_model_summaries

from typing import Optional, Sequence, Tuple

from torch import nn
from torch.utils.data import DataLoader

from analogvnn.nn.module.Layer import Layer

def get_model_summaries(  # noqa: C901
        model: Optional[nn.Module],
        input_size: Optional[Sequence[int]] = None,
        train_loader: Optional[DataLoader] = None,
        *args,
        **kwargs
) -> Tuple[str, str]:
    """Creates the model summaries.

    The summary is produced by ``torchinfo.summary``. When no input size is
    given (neither as ``input_size`` nor in ``kwargs``), it is taken from the
    shape of the first element of the first batch yielded by ``train_loader``.

    Args:
        model (nn.Module): the model to log.
        input_size (Optional[Sequence[int]]): the input size.
        train_loader (DataLoader): the train loader.
        *args: the arguments to torchinfo.summary.
        **kwargs: the keyword arguments to torchinfo.summary. The keys
            ``depth``, ``col_names`` and ``verbose`` are filled with defaults
            when absent; an ``input_size`` key takes precedence over the
            ``input_size`` argument.

    Returns:
        Tuple[str, str]: the model __repr__ and the model summary.

    Raises:
        ImportError: if torchinfo (https://github.com/tyleryep/torchinfo) is not installed.
        ValueError: if the input_size and train_loader are None, or if the
            input size has to be inferred from an empty train_loader.
    """
    try:
        import torchinfo
    except ImportError as e:
        raise ImportError('requires torchinfo: https://github.com/tyleryep/torchinfo') from e

    if input_size is None and train_loader is None and 'input_size' not in kwargs:
        raise ValueError('input_size or train_loader must be provided')

    if 'input_size' not in kwargs:
        if input_size is None:
            first_batch = next(iter(train_loader), None)
            if first_batch is None:
                # An empty loader would otherwise leak a bare StopIteration.
                raise ValueError('train_loader is empty, cannot infer input_size')

            data_shape = list(first_batch[0].shape)
            # batch_size is None when the loader uses a custom batch_sampler or
            # has automatic batching disabled; comparing None with 0 would raise.
            batch_size = train_loader.batch_size
            if batch_size is not None and batch_size > 0:
                # Summarise with a single sample instead of a whole batch.
                data_shape[0] = 1
            input_size = data_shape
        kwargs['input_size'] = input_size

    kwargs.setdefault('depth', 10)
    if 'col_names' not in kwargs:
        kwargs['col_names'] = tuple(column.value for column in torchinfo.ColumnSettings)
    kwargs.setdefault('verbose', torchinfo.Verbosity.QUIET)

    # Layer models are summarised with use_autograd_graph switched on; the
    # previous value is remembered so it can be put back afterwards.
    is_layer = isinstance(model, Layer)
    use_autograd_graph = False
    if is_layer:
        use_autograd_graph = model.use_autograd_graph
        model.use_autograd_graph = True

    try:
        model_summary = torchinfo.summary(
            model,
            *args,
            **kwargs,
        )
    finally:
        # Restore the flag even when torchinfo.summary raises, so a failed
        # summary does not leave the model in a modified state.
        if is_layer:
            model.use_autograd_graph = use_autograd_graph

    # The summary is computed quietly by default but rendered verbosely.
    model_summary.formatting.verbose = torchinfo.Verbosity.VERBOSE
    model_str = str(model)
    model_summary = f'{model_summary}'
    return model_str, model_summary