Source code for analogvnn.utils.to_tensor_parameter

from typing import Tuple, Union

import torch
from torch import nn

__all__ = ['to_float_tensor', 'to_nongrad_parameter']


[docs]def to_float_tensor(*args) -> Tuple[Union[torch.Tensor, None], ...]: """Converts the given arguments to `torch.Tensor` of type `torch.float32`. The returned tensors are not trainable. Args: *args: the arguments to convert. Returns: tuple: the converted arguments. """ return tuple((None if i is None else torch.tensor(i, requires_grad=False, dtype=torch.float)) for i in args)
[docs]def to_nongrad_parameter(*args) -> Tuple[Union[nn.Parameter, None], ...]: """Converts the given arguments to `nn.Parameter` of type `torch.float32`. The returned parameters are not trainable. Args: *args: the arguments to convert. Returns: tuple: the converted arguments. """ return tuple((None if i is None else nn.Parameter(i, requires_grad=False)) for i in to_float_tensor(*args))