Source code for analogvnn.fn.to_matrix

from torch import Tensor

__all__ = ['to_matrix']


def to_matrix(tensor: Tensor) -> Tensor:
    """Promote a 1-D tensor to a single-row matrix.

    Only one-dimensional input is reshaped; a tensor of any other rank
    (including a 0-d scalar tensor) is handed back untouched.

    Args:
        tensor (Tensor): The tensor to promote.

    Returns:
        Tensor: A view of ``tensor`` with shape ``(1, n)`` when ``tensor`` is
        1-D with ``n`` elements; otherwise ``tensor`` itself, unchanged.
    """
    if tensor.dim() == 1:
        # `view` shares storage with the input, so no data is copied.
        return tensor.view(1, -1)
    return tensor