# Source code for analogvnn.fn.reduce_precision
import torch
from torch import Tensor
from analogvnn.utils.common_types import TENSOR_OPERABLE
__all__ = ['reduce_precision', 'stochastic_reduce_precision']
def reduce_precision(x: TENSOR_OPERABLE, precision: TENSOR_OPERABLE, divide: TENSOR_OPERABLE) -> TENSOR_OPERABLE:
    """Quantize `x` onto the grid of multiples of ``1 / precision``.

    The input is scaled by `precision`, its magnitude is snapped to an integer
    (floor or ceil, decided by `divide`), the sign is restored, and the result
    is scaled back down.

    Args:
        x (TENSOR_OPERABLE): input values; non-Tensor inputs are converted to a
            Tensor with ``requires_grad=False``.
        precision (TENSOR_OPERABLE): quantization precision; output values are
            integer multiples of ``1 / precision``.
        divide (TENSOR_OPERABLE): threshold on the fractional part of ``|x * precision|``
            deciding floor vs. ceil (e.g. ``0.5`` gives round-to-nearest behavior).

    Returns:
        TENSOR_OPERABLE: Tensor with the same shape as `x`, with each value
        snapped to a multiple of ``1 / precision``.
    """
    x = x if isinstance(x, Tensor) else torch.tensor(x, requires_grad=False)
    g: Tensor = x * precision
    # Floor |g| when its fractional part is below `divide`, otherwise ceil it;
    # the sign is applied afterwards so quantization is symmetric around zero.
    f = torch.sign(g) * torch.maximum(
        torch.floor(torch.abs(g)),
        torch.ceil(torch.abs(g) - divide)
    ) * (1 / precision)
    return f
def stochastic_reduce_precision(x: TENSOR_OPERABLE, precision: TENSOR_OPERABLE) -> TENSOR_OPERABLE:
    """Quantize `x` onto multiples of ``1 / precision`` with stochastic rounding.

    Each scaled magnitude ``|x * precision|`` is floored with probability equal
    to its closeness to the floor value and ceiled otherwise, so the rounding
    is unbiased in expectation.

    Args:
        x (TENSOR_OPERABLE): input values; non-Tensor inputs are converted to a
            Tensor with ``requires_grad=False``.
        precision (TENSOR_OPERABLE): quantization precision; output values are
            integer multiples of ``1 / precision``.

    Returns:
        TENSOR_OPERABLE: Tensor with the same shape as `x`, with each value
        stochastically snapped to a multiple of ``1 / precision``.
    """
    # Mirror reduce_precision's input handling so torch.rand_like below always
    # receives a Tensor (plain floats/lists previously raised a TypeError).
    x = x if isinstance(x, Tensor) else torch.tensor(x, requires_grad=False)
    g: Tensor = x * precision
    g_abs = torch.abs(g)
    g_floor = torch.floor(g_abs)
    g_ceil = torch.ceil(g_abs)

    # P(floor) = 1 - distance from the floor value; e.g. |g| = 1.25 is floored
    # with probability 0.75 and ceiled with probability 0.25.
    prob_floor = 1 - torch.abs(g_floor - g_abs)
    bool_floor = torch.rand_like(g, requires_grad=False) <= prob_floor
    do_floor = bool_floor.type(torch.float)
    do_ceil = torch.logical_not(bool_floor).type(torch.float)

    f = torch.sign(g) * (do_floor * g_floor + do_ceil * g_ceil) * (1 / precision)
    return f