"""Source code for analogvnn.nn.precision.ReducePrecision."""
import torch
from torch import nn, Tensor
from analogvnn.backward.BackwardIdentity import BackwardIdentity
from analogvnn.fn.reduce_precision import reduce_precision
from analogvnn.nn.precision.Precision import Precision
from analogvnn.utils.common_types import TENSOR_OPERABLE
__all__ = ['ReducePrecision']
class ReducePrecision(Precision, BackwardIdentity):
    """Implements the reduce precision function.

    Attributes:
        precision (nn.Parameter): the precision of the output tensor.
        divide (nn.Parameter): the rounding value that is if divide is 0.5,
            then 0.6 will be rounded to 1.0 and 0.4 will be rounded to 0.0.
    """

    def __init__(self, precision: int = None, divide: float = 0.5):
        """Initialize the reduce precision function.

        Args:
            precision (int): the precision of the output tensor.
            divide (float): the rounding value that is if divide is 0.5,
                then 0.6 will be rounded to 1.0 and 0.4 will be rounded to 0.0.

        Raises:
            ValueError: if precision is None, less than 1, or not an integer
                value, or if divide is outside the range [0, 1].
        """
        super().__init__()
        # Guard None explicitly: the default precision=None would otherwise
        # hit `precision < 1` and raise an opaque TypeError instead of a
        # clear ValueError telling the caller the argument is required.
        if precision is None:
            raise ValueError('precision has to be more than 0, but got None')
        if precision < 1:
            raise ValueError(f'precision has to be more than 0, but got {precision}')
        if precision != int(precision):
            raise ValueError(f'precision must be int, but got {precision}')
        if not (0 <= divide <= 1):
            raise ValueError(f'divide must be between 0 and 1, but got {divide}')

        # Stored as non-trainable nn.Parameters so they follow the module
        # across devices (.to(...)) and appear in state_dict, but are
        # excluded from gradient updates.
        self.precision = nn.Parameter(torch.tensor(precision), requires_grad=False)
        self.divide = nn.Parameter(torch.tensor(divide), requires_grad=False)

    @property
    def precision_width(self) -> Tensor:
        """The precision width (step size between representable values).

        Returns:
            Tensor: the precision width, 1 / precision.
        """
        return 1 / self.precision

    @property
    def bit_precision(self) -> Tensor:
        """The bit precision of the ReducePrecision module.

        Returns:
            Tensor: the bit precision, log2(precision + 1).
        """
        return torch.log2(self.precision + 1)

    @staticmethod
    def convert_to_precision(bit_precision: TENSOR_OPERABLE) -> TENSOR_OPERABLE:
        """Convert the bit precision to the precision.

        Inverse of :attr:`bit_precision`: precision = 2 ** bits - 1.

        Args:
            bit_precision (TENSOR_OPERABLE): the bit precision.

        Returns:
            TENSOR_OPERABLE: the precision.
        """
        return 2 ** bit_precision - 1

    def extra_repr(self) -> str:
        """The extra __repr__ string of the ReducePrecision module.

        Returns:
            str: string of the form 'precision=P, divide=D.DD'.
        """
        return f'precision={int(self.precision)}, divide={float(self.divide):0.2f}'

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of the ReducePrecision module.

        Args:
            x (Tensor): the input tensor.

        Returns:
            Tensor: the output tensor with reduced precision.
        """
        return reduce_precision(x, self.precision, self.divide)