from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Optional, Sequence, Tuple, Dict
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from analogvnn.nn.module.Layer import Layer
from analogvnn.nn.module.Model import Model
__all__ = ['TensorboardModelLog']
from analogvnn.utils.get_model_summaries import get_model_summaries
class TensorboardModelLog:
    """Tensorboard model log.

    Wraps a ``SummaryWriter`` and records scalars, parameter histograms, the
    model graph and textual model summaries for a single model. The instance
    can be used as a context manager; leaving the ``with`` block closes the
    writer.

    Attributes:
        model (nn.Module): the model to log.
        tensorboard (SummaryWriter): the tensorboard writer, ``None`` once closed.
        layer_data (bool): whether to log the layer data.
        _log_record (Dict[str, bool]): ids of one-shot logs (graph, summary)
            that have already been written.
    """

    model: nn.Module
    tensorboard: Optional[SummaryWriter]
    layer_data: bool
    _log_record: Dict[str, bool]

    def __init__(self, model: Model, log_dir: str):
        """Log the model to Tensorboard.

        Args:
            model (nn.Module): the model to log.
            log_dir (str): the directory to log to; created (including missing
                parent directories) when it does not exist.

        Raises:
            ValueError: if ``log_dir`` exists but is not a directory.
        """
        super().__init__()
        self.model = model
        self.tensorboard = None
        self.layer_data = True
        self._log_record = {}

        # makedirs (instead of mkdir) so nested paths such as "runs/exp1" work;
        # exist_ok tolerates another process creating the directory meanwhile.
        # The existence guard is kept so that a path pointing at a regular file
        # still ends up as the ValueError raised by set_log_dir.
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)

        self.set_log_dir(log_dir)

        # Models that know about tensorboard logging get a back-reference.
        if hasattr(model, 'subscribe_tensorboard'):
            model.subscribe_tensorboard(tensorboard=self)

    def set_log_dir(self, log_dir: str) -> TensorboardModelLog:
        """Set the log directory.

        Any writer that is currently open is closed before the new one is
        created, so switching directories does not leak file handles.

        Args:
            log_dir (str): the log directory.

        Returns:
            TensorboardModelLog: self.

        Raises:
            ValueError: if the log directory is invalid.
        """
        # https://github.com/tensorflow/tensorboard/pull/6135
        # When tensorboard's compat ``tf`` object has no ``io`` attribute, fill
        # it from the bundled tensorflow stub.
        from tensorboard.compat import tf
        if getattr(tf, 'io', None) is None:
            import tensorboard.compat.tensorflow_stub as new_tf
            tf.__dict__.update(new_tf.__dict__)

        if not os.path.isdir(log_dir):
            raise ValueError(f'Log directory {log_dir} does not exist.')

        # getattr: stay usable on instances that never ran __init__.
        previous_writer = getattr(self, 'tensorboard', None)
        if previous_writer is not None:
            previous_writer.close()

        self.tensorboard = SummaryWriter(log_dir=log_dir)
        return self

    def _add_layer_data(self, epoch: Optional[int] = None):
        """Add the layer data to the tensorboard.

        Writes one histogram per trainable parameter (``requires_grad`` is
        true), tagged with the parameter name.

        Args:
            epoch (int): the epoch to add the data for.
        """
        for name, parameter in self.model.named_parameters():
            if parameter.requires_grad:
                self.tensorboard.add_histogram(name, parameter.data, epoch)

    def on_compile(self, layer_data: bool = True) -> TensorboardModelLog:
        """Called when the model is compiled.

        ``layer_data`` can only switch layer logging off: once
        ``self.layer_data`` is false it is left untouched. When layer logging
        is still enabled, the parameter histograms are written at epoch ``-1``.

        Args:
            layer_data (bool): whether to log the layer data.

        Returns:
            TensorboardModelLog: self.
        """
        if self.layer_data:
            self.layer_data = layer_data

        if self.layer_data:
            self._add_layer_data(epoch=-1)

        return self

    @staticmethod
    def _infer_device(model: nn.Module) -> torch.device:
        """Return the device on which a dummy input for ``model`` should live."""
        device = getattr(model, 'device', None)
        if device is not None:
            return device

        # Plain nn.Module has no ``device`` attribute: use its first parameter,
        # or the CPU for parameter-less modules.
        first_parameter = next(iter(model.parameters()), None)
        if first_parameter is None:
            return torch.device('cpu')
        return first_parameter.device

    def add_graph(
            self,
            train_loader: DataLoader,
            model: Optional[nn.Module] = None,
            input_size: Optional[Sequence[int]] = None,
    ) -> TensorboardModelLog:
        """Add the model graph to the tensorboard.

        The graph is written at most once per (log directory, model object)
        pair, into a ``graph_<ClassName>_<id>`` sub-directory of the log
        directory, using an all-zeros input.

        Args:
            train_loader (DataLoader): the train loader; only read when
                ``input_size`` is None.
            model (Optional[nn.Module]): the model to log, defaults to ``self.model``.
            input_size (Optional[Sequence[int]]): the input size; when None it is
                taken from the first element of the first batch with the
                leading dimension replaced by 1.

        Returns:
            TensorboardModelLog: self.
        """
        if model is None:
            model = self.model

        log_id = f'{self.tensorboard.log_dir}_{TensorboardModelLog.add_graph.__name__}_{id(model)}'
        if log_id in self._log_record:
            return self

        if input_size is None:
            data_shape = next(iter(train_loader))[0].shape
            input_size = [1] + list(data_shape)[1:]

        # NOTE(review): autograd-graph mode is switched off while the graph is
        # traced, presumably because the two are incompatible - confirm.
        is_layer = isinstance(model, Layer)
        saved_autograd_flag = False
        if is_layer:
            saved_autograd_flag = model.use_autograd_graph
            model.use_autograd_graph = False

        try:
            graph_path = Path(self.tensorboard.log_dir).joinpath(f'graph_{model.__class__.__name__}_{id(model)}')
            dummy_input = torch.zeros(input_size).to(self._infer_device(model))
            with SummaryWriter(log_dir=str(graph_path)) as graph_writer:
                graph_writer.add_graph(model, dummy_input)
            self._log_record[log_id] = True
        finally:
            # Restore the caller's setting even when tracing fails.
            if is_layer:
                model.use_autograd_graph = saved_autograd_flag

        return self

    def add_summary(
            self,
            input_size: Optional[Sequence[int]] = None,
            train_loader: Optional[DataLoader] = None,
            model: Optional[nn.Module] = None,
            *args,
            **kwargs
    ) -> Tuple[str, str]:
        """Add the model summary to the tensorboard.

        The summaries are computed on every call, but they are written to
        tensorboard only the first time for a given (log directory, model
        object) pair.

        Args:
            input_size (Optional[Sequence[int]]): the input size.
            train_loader (Optional[DataLoader]): the train loader.
            model (nn.Module): the model to log, defaults to ``self.model``.
            *args: the arguments to torchinfo.summary.
            **kwargs: the keyword arguments to torchinfo.summary.

        Returns:
            Tuple[str, str]: the model __repr__ and the model summary.
        """
        if model is None:
            model = self.model

        log_id = f'{self.tensorboard.log_dir}_{TensorboardModelLog.add_summary.__name__}_{id(model)}'

        # NOTE(review): ``*args`` is unpacked after keyword arguments; whether
        # extra positionals can clash with ``model=`` depends on the signature
        # of get_model_summaries, which is defined elsewhere - confirm.
        model_str, nn_model_summary = get_model_summaries(
            model=model,
            input_size=input_size,
            train_loader=train_loader,
            *args,  # noqa: B026
            **kwargs
        )

        if log_id in self._log_record:
            return model_str, nn_model_summary

        model_name = model.__class__.__name__
        # Every line (including the first) is prefixed with the same indent.
        self.tensorboard.add_text(
            f'str ({model_name})',
            f' {model_str}'.replace('\n', '\n '),
        )
        self.tensorboard.add_text(
            f'summary ({model_name})',
            f' {nn_model_summary}'.replace('\n', '\n '),
        )
        self._log_record[log_id] = True
        return model_str, nn_model_summary

    def register_training(self, epoch: int, train_loss: float, train_accuracy: float) -> TensorboardModelLog:
        """Register the training data.

        Also writes the parameter histograms for this epoch when layer logging
        is enabled.

        Args:
            epoch (int): the epoch.
            train_loss (float): the training loss.
            train_accuracy (float): the training accuracy.

        Returns:
            TensorboardModelLog: self.
        """
        self.tensorboard.add_scalar('Loss/train', train_loss, epoch)
        self.tensorboard.add_scalar('Accuracy/train', train_accuracy, epoch)
        if self.layer_data:
            self._add_layer_data(epoch=epoch)
        return self

    def register_testing(self, epoch: int, test_loss: float, test_accuracy: float) -> TensorboardModelLog:
        """Register the testing data.

        Args:
            epoch (int): the epoch.
            test_loss (float): the test loss.
            test_accuracy (float): the test accuracy.

        Returns:
            TensorboardModelLog: self.
        """
        self.tensorboard.add_scalar('Loss/test', test_loss, epoch)
        self.tensorboard.add_scalar('Accuracy/test', test_accuracy, epoch)
        return self

    # noinspection PyUnusedLocal
    def close(self, *args, **kwargs) -> None:
        """Close the tensorboard.

        Safe to call more than once: after the first call ``self.tensorboard``
        is None and later calls do nothing.

        Args:
            *args: ignored.
            **kwargs: ignored.
        """
        if self.tensorboard is not None:
            self.tensorboard.close()
            self.tensorboard = None

    def __enter__(self) -> TensorboardModelLog:
        """Enter the TensorboardModelLog context.

        Returns:
            TensorboardModelLog: self.
        """
        return self

    # Leaving the ``with`` block closes the tensorboard. ``close`` ignores the
    # exception triple and returns None, so exceptions are not suppressed.
    __exit__ = close