Example #1
def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        # use the statistics of channel i (row i) rather than the whole tensor
        min_val = X[i].min()
        max_val = X[i].max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])

    return scale, zero_point
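A minimal usage sketch for the helper above, assuming numpy is imported as np alongside the snippet; torch.uint8 is used purely as an illustrative dtype:

import torch

X = torch.randn(4, 16)  # 4 channels with 16 values each
scale, zero_point = _calculate_dynamic_per_channel_qparams(X, torch.uint8)
print(scale.shape, zero_point.shape)  # one (scale, zero_point) pair per channel: (4,) (4,)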
Example #2
 def __init__(self,
              observer=MovingAverageMinMaxObserver,
              quant_min=0,
              quant_max=255,
              **observer_kwargs):
     super(FakeQuantize, self).__init__()
     assert quant_min <= quant_max, \
         'quant_min must be less than or equal to quant_max'
     self.quant_min = quant_min
     self.quant_max = quant_max
     # fake_quant_enabled and observer_enabled are buffers to support their
     # replication in DDP. Data type is uint8 because NCCL does not support
     # bool tensors.
     self.register_buffer('fake_quant_enabled',
                          torch.tensor([1], dtype=torch.uint8))
     self.register_buffer('observer_enabled',
                          torch.tensor([1], dtype=torch.uint8))
     self.activation_post_process = observer(**observer_kwargs)
     assert torch.iinfo(self.activation_post_process.dtype
                        ).min <= quant_min, 'quant_min out of bound'
     assert quant_max <= torch.iinfo(
         self.activation_post_process.dtype).max, 'quant_max out of bound'
     self.register_buffer('scale', torch.tensor([1.0]))
     self.register_buffer('zero_point', torch.tensor([0]))
     self.dtype = self.activation_post_process.dtype
     self.qscheme = self.activation_post_process.qscheme
     self.ch_axis = self.activation_post_process.ch_axis \
         if hasattr(self.activation_post_process, 'ch_axis') else -1
Example #3
 def __init__(self,
              observer=MovingAverageMinMaxObserver,
              quant_min=0,
              quant_max=255,
              **observer_kwargs):
     super().__init__()
     assert quant_min <= quant_max, \
         'quant_min must be less than or equal to quant_max'
     self.quant_min = quant_min
     self.quant_max = quant_max
     self.activation_post_process = observer(**observer_kwargs)
     assert torch.iinfo(self.activation_post_process.dtype
                        ).min <= quant_min, 'quant_min out of bound'
     assert quant_max <= torch.iinfo(
         self.activation_post_process.dtype).max, 'quant_max out of bound'
     self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
     self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
     self.dtype = self.activation_post_process.dtype
     self.qscheme = self.activation_post_process.qscheme
     self.ch_axis = self.activation_post_process.ch_axis \
         if hasattr(self.activation_post_process, 'ch_axis') else -1
     assert _is_per_channel(self.qscheme) or \
         _is_per_tensor(self.qscheme), \
         'Only per channel and per tensor quantization are supported in fake quantize,' + \
         ' got qscheme: ' + str(self.qscheme)
     self.is_per_channel = _is_per_channel(self.qscheme)
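For reference, a minimal sketch of constructing and calling the FakeQuantize module these __init__ snippets come from, assuming a PyTorch build that exposes it under torch.ao.quantization (older releases expose it as torch.quantization.FakeQuantize):

import torch
from torch.ao.quantization import FakeQuantize, MovingAverageMinMaxObserver

fq = FakeQuantize(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255)
x = torch.randn(4, 4)
y = fq(x)  # observes x, updates scale/zero_point, then fake-quantizes it
print(fq.scale, fq.zero_point)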
Example #4
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer,
                                 quant_min,
                                 quant_max,
                                 ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(
            X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)
Example #5
def convert_image_dtype(image: torch.Tensor,
                        dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    if image.dtype.is_floating_point:
        # float to float
        if dtype.is_floating_point:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype
                in (torch.int32, torch.int64)) or (image.dtype == torch.float64
                                                   and dtype == torch.int64):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        eps = 1e-3
        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
    else:
        # int to float
        if dtype.is_floating_point:
            max = torch.iinfo(image.dtype).max
            image = image.to(dtype)
            return image / max

        # int to int
        input_max = torch.iinfo(image.dtype).max
        output_max = torch.iinfo(dtype).max

        if input_max > output_max:
            factor = (input_max + 1) // (output_max + 1)
            image = image // factor
            return image.to(dtype)
        else:
            factor = (output_max + 1) // (input_max + 1)
            image = image.to(dtype)
            return image * factor
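A minimal round-trip sketch for the function above (this appears to be torchvision.transforms.functional.convert_image_dtype):

import torch

img_u8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
img_f32 = convert_image_dtype(img_u8, torch.float32)  # values scaled into [0.0, 1.0]
img_back = convert_image_dtype(img_f32, torch.uint8)  # values scaled back into [0, 255]
print(img_f32.min().item(), img_f32.max().item())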
Example #6
    def test_stable_sort_against_numpy(self, device, dtype):
        if dtype in torch.testing.floating_types_and(torch.float16):
            inf = float('inf')
            neg_inf = -float('inf')
            nan = float('nan')
        else:
            if dtype != torch.bool:
                # no torch.iinfo support for torch.bool
                inf = torch.iinfo(dtype).max
                neg_inf = torch.iinfo(dtype).min
            else:
                inf = True
                neg_inf = ~inf
            # no nan for integral types, we use inf instead for simplicity
            nan = inf

        def generate_samples():
            from itertools import chain, combinations

            def repeated_index_fill(t, dim, idxs, vals):
                res = t
                for idx, val in zip(idxs, vals):
                    res = res.index_fill(dim, idx, val)
                return res

            for sizes in [(1, 10), (10, 1), (10, 10), (10, 10, 10)]:
                size = min(*sizes)
                x = (torch.randn(*sizes, device=device) * size).to(dtype)
                yield (x, 0)

                # Generate tensors which are being filled at random locations
                # with values from the non-empty subsets of the set (inf, neg_inf, nan)
                # for each dimension.
                n_fill_vals = 3  # cardinality of (inf, neg_inf, nan)
                for dim in range(len(sizes)):
                    idxs = (torch.randint(high=size, size=(size // 10, ))
                            for i in range(n_fill_vals))
                    vals = (inf, neg_inf, nan)
                    subsets = chain.from_iterable(
                        combinations(list(zip(idxs, vals)), r)
                        for r in range(1, n_fill_vals + 1))
                    for subset in subsets:
                        idxs_subset, vals_subset = zip(*subset)
                        yield (repeated_index_fill(x, dim, idxs_subset,
                                                   vals_subset), dim)

            for sizes in [(100, ), (1000, ), (10000, )]:
                size = sizes[0]
                # binary strings
                yield (torch.tensor([0, 1] * size, dtype=dtype,
                                    device=device), 0)

        for sample, dim in generate_samples():
            _, idx_torch = sample.sort(dim=dim, stable=True)
            sample_numpy = sample.numpy()
            idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable')
            self.assertEqual(idx_torch, idx_numpy)
Example #7
 def test_topk_integral(self, device, dtype):
     a = torch.randint(torch.iinfo(dtype).min,
                       torch.iinfo(dtype).max,
                       size=(10, ),
                       dtype=dtype,
                       device=device)
     sort_topk = a.sort()[0][-5:].flip(0)
     topk = a.topk(5)
     self.assertEqual(sort_topk, topk[0])  # check values
     self.assertEqual(sort_topk, a[topk[1]])  # check indices
Example #8
    def _test_topk_dtype(self, device, dtype, integral, size):
        if integral:
            a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max,
                              size=(size,), dtype=dtype, device=device)
        else:
            a = torch.randn(size=(size,), dtype=dtype, device=device)

        sort_topk = a.sort()[0][-(size // 2):].flip(0)
        topk = a.topk(size // 2)
        self.assertEqual(sort_topk, topk[0])      # check values
        self.assertEqual(sort_topk, a[topk[1]])   # check indices
Example #9
 def test_topk_integral(self, device, dtype):
     small = 10
     large = 4096
     for curr_size in (small, large):
         a = torch.randint(torch.iinfo(dtype).min,
                           torch.iinfo(dtype).max,
                           size=(curr_size, ),
                           dtype=dtype,
                           device=device)
         sort_topk = a.sort()[0][-(curr_size // 2):].flip(0)
         topk = a.topk(curr_size // 2)
         self.assertEqual(sort_topk, topk[0])  # check values
         self.assertEqual(sort_topk, a[topk[1]])  # check indices
Example #10
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
Example #11
    def __init__(self,
                 observer,
                 quant_min=-128,
                 quant_max=127,
                 scale=1.,
                 zero_point=0.,
                 **observer_kwargs):
        """Adaround FakeQuantization module

        Args:
            observer (torch.quantization.observer): quantization observer used to initialize the quantizers
            quant_min (int, optional): minimum of the quantized value range. Defaults to -128.
            quant_max (int, optional): maximum of the quantized value range. Defaults to 127.
            scale (float, optional): initial scale. Defaults to 1.0.
            zero_point (float, optional): initial zero point. Defaults to 0.0.
        """
        super(AdaRound, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
               'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
               'quant_max out of bound'
        self.register_buffer('scale', torch.tensor([scale]))
        self.register_buffer('zero_point', torch.tensor([zero_point]))

        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        assert self.qscheme in (torch.per_channel_symmetric,
                                torch.per_tensor_symmetric)
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('observer_enabeled',
                             torch.tensor([1], dtype=torch.uint8))

        self.register_buffer('fake_quant_enabled',
                             torch.tensor([1], dtype=torch.uint8))

        self.register_buffer('soft_sigmoid_enabled',
                             torch.tensor([0], dtype=torch.uint8))
        self.register_buffer('hard_sigmoid_enabled',
                             torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())

        self.gamma, self.zeta = -0.1, 1.1
        self.beta = 2. / 3
Example #12
def _reduction_identity(op_name: str, input: Tensor, *args):
    """Return identity value as scalar tensor of a reduction operation on
    given input, or None, if the identity value cannot be uniquely
    defined for the given input.

    The identity value of the operation is defined as the initial
    value of the reduction that satisfies ``op(op_identity, value) ==
    value`` for any value in the domain of the operation. Put another
    way, including or excluding the identity value in a list of
    operands does not change the reduction result.

    See https://github.com/pytorch/rfcs/pull/27 for more information.

    """
    dtype: DType = input.dtype
    device = input.device
    op_name = op_name.rsplit('.', 1)[-1]  # lstrip module name when present
    if op_name == 'sum':
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name == 'prod':
        return torch.tensor(1, dtype=dtype, device=device)
    elif op_name == 'amax':
        if torch.is_floating_point(input):
            return torch.tensor(-torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).min,
                                dtype=dtype,
                                device=device)
    elif op_name == 'amin':
        if torch.is_floating_point(input):
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).max,
                                dtype=dtype,
                                device=device)
    elif op_name == 'mean':
        # Strictly speaking, the identity value of the mean operation
        # is the mean of the input. Since the mean value depends on
        # the dim argument and it may be a non-scalar tensor, we
        # consider the identity value of the mean operation ambiguous.
        # Moreover, the mean value of empty input is undefined.
        return None
    elif op_name == 'norm':
        ord = args[0] if args else 2
        if ord == float('-inf'):
            assert torch.is_floating_point(input), input.dtype
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name in {'var', 'std'}:
        return None
    raise NotImplementedError(f'identity of {op_name} on {dtype} input')
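A minimal usage sketch for the helper above, assuming the Tensor and DType annotations are aliases for torch.Tensor and torch.dtype as in the original module:

import torch

x = torch.arange(6, dtype=torch.int8)
print(_reduction_identity('sum', x))   # tensor(0, dtype=torch.int8)
print(_reduction_identity('amax', x))  # tensor(-128, dtype=torch.int8), i.e. torch.iinfo(torch.int8).min
print(_reduction_identity('mean', x))  # None: the identity of mean is considered ambiguous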
Example #13
    def _test_numerical_consistency(self, test_type):
        r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
        """
        torch.random.manual_seed(NP_RANDOM_SEED)
        torch_types = [torch.qint8, torch.quint8]
        float_types = [torch.float, torch.float16, torch.float64]
        zero_types = [torch.long]
        devices = [torch.device('cpu'),
                   torch.device('cuda')
                   ] if torch.cuda.is_available() else [torch.device('cpu')]
        axis = 1
        for i in range(20):
            for torch_type, float_type, device, zero_type in itertools.product(
                    torch_types, float_types, devices, zero_types):
                X = torch.randn(3, 3, device=device).to(float_type)
                scales = (10 * torch.randn(3, device=device)).abs()
                scale = scales.mean().to(float).item()
                zeros = (10 * torch.randn(3, device=device)).abs().to(
                    dtype=zero_type)
                zero = zeros.max().view(1).item()
                quant_min = torch.iinfo(torch_type).min
                quant_max = torch.iinfo(torch_type).max

                test_was_run = False
                if test_type == "per_tensor":
                    test_was_run = True
                    Y = torch.dequantize(
                        torch.quantize_per_tensor(
                            X.to('cpu').to(torch.float), scale, zero,
                            torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_tensor_affine(
                        X, scale, zero, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime,
                        "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor"
                    )

                if test_type == "per_channel":
                    test_was_run = True
                    Y = torch.dequantize(
                        torch.quantize_per_channel(
                            X.to('cpu').to(torch.float), scales.to('cpu'),
                            zeros.to('cpu'), axis,
                            torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_channel_affine(
                        X, scales, zeros, axis, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime,
                        "Difference found between dequant+quant_per_channel and fake_quantize_per_channel"
                    )
                self.assertTrue(test_was_run)
Example #14
    def test_numerical_consistency_per_tensor(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        # quantize_per_tensor and dequantize are only implemented in CPU
        Y = torch.dequantize(torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
Example #15
def wav_to_float(x):
    """
    Input in range [-2**15, 2**15 - 1] (the torch.int16 range)
    Output approximately in range [-1, 1]
    """
    assert x.dtype == torch.int16, f"got {x.dtype}"
    max_value = torch.iinfo(torch.int16).max
    min_value = torch.iinfo(torch.int16).min
    if not x.is_floating_point():
        x = x.to(torch.float)
    x = x - min_value
    x = x / ((max_value - min_value) / 2.0)
    x = x - 1.0
    return x
Example #16
def float_to_wav(x):
    """
    Input in range [-1, 1]
    Output in range [-2**15, 2**15 - 1] (the torch.int16 range)
    """
    assert x.dtype == torch.float
    max_value = torch.iinfo(torch.int16).max
    min_value = torch.iinfo(torch.int16).min

    x = x + 1.0
    x = x * (max_value - min_value) / 2.0
    x = x + min_value
    x = x.to(torch.int16)
    return x
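A minimal round-trip sketch combining the two helpers above; the reconstruction is exact only up to float32 rounding:

import torch

pcm = torch.randint(torch.iinfo(torch.int16).min, torch.iinfo(torch.int16).max,
                    (8,), dtype=torch.int16)
as_float = wav_to_float(pcm)   # approximately in [-1.0, 1.0]
back = float_to_wav(as_float)  # back to int16 PCM
print(pcm.tolist(), back.tolist())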
Example #17
    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        if self.shift_unit == "samples":
            min_shift_in_samples = self.min_shift
            max_shift_in_samples = self.max_shift

        elif self.shift_unit == "fraction":
            min_shift_in_samples = int(round(self.min_shift * samples.shape[-1]))
            max_shift_in_samples = int(round(self.max_shift * samples.shape[-1]))

        elif self.shift_unit == "seconds":
            min_shift_in_samples = int(round(self.min_shift * sample_rate))
            max_shift_in_samples = int(round(self.max_shift * sample_rate))

        else:
            raise ValueError("Invalid shift_unit")

        assert (
            torch.iinfo(torch.int32).min
            <= min_shift_in_samples
            <= torch.iinfo(torch.int32).max
        )
        assert (
            torch.iinfo(torch.int32).min
            <= max_shift_in_samples
            <= torch.iinfo(torch.int32).max
        )
        selected_batch_size = samples.size(0)
        if min_shift_in_samples == max_shift_in_samples:
            self.transform_parameters["num_samples_to_shift"] = torch.full(
                size=(selected_batch_size,),
                fill_value=min_shift_in_samples,
                dtype=torch.int32,
                device=samples.device,
            )

        else:
            self.transform_parameters["num_samples_to_shift"] = torch.randint(
                low=min_shift_in_samples,
                high=max_shift_in_samples + 1,
                size=(selected_batch_size,),
                dtype=torch.int32,
                device=samples.device,
            )
Example #18
 def get_castable_tensor(shape, dtype):
     if dtype.is_floating_point:
         dtype_info = torch.finfo(dtype)
         # can't directly use min and max, because for double, max - min
         # is greater than double range and sampling always gives inf.
         low = max(dtype_info.min, -1e10)
         high = min(dtype_info.max, 1e10)
         t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
     else:
         # can't directly use min and max, because for int64_t, max - min
         # is greater than int64_t range and triggers UB.
         low = max(torch.iinfo(dtype).min, int(-1e10))
         high = min(torch.iinfo(dtype).max, int(1e10))
         t = torch.empty(shape, dtype=torch.int64).random_(low, high)
     return t.to(dtype)
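A minimal usage sketch for the helper above:

import torch

t = get_castable_tensor((2, 3), torch.int16)
print(t.dtype, t.min().item(), t.max().item())  # torch.int16, values within the int16 range
f = get_castable_tensor((2, 3), torch.float64)
print(f.dtype, bool(f.abs().max() <= 1e10))     # torch.float64, True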
Example #19
 def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
              quant_min=0, quant_max=255):
     super(FakeQuantize, self).__init__()
     assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
     assert quant_min <= quant_max, \
         'quant_min must be less than or equal to quant_max'
     assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
     self.dtype = dtype
     self.qscheme = qscheme
     self.quant_min = quant_min
     self.quant_max = quant_max
     self.enabled = True
     self.observer = default_observer(dtype=dtype, qscheme=qscheme)
     self.scale = None
     self.zero_point = None
Example #20
def qtensor(draw, shapes, dtypes=None, float_min=None, float_max=None):
    # In case shape is a strategy
    if isinstance(shapes, SearchStrategy):
        shape = draw(shapes)
    else:
        shape = draw(st.sampled_from(shapes))
    # Resolve types
    if dtypes is None:
        dtypes = ALL_QINT_TYPES
    _dtypes = draw(st.sampled_from(dtypes))
    assert len(_dtypes) in [1, 2]
    if len(_dtypes) == 1:
        quantized_type = _dtypes[0]
        _zp_enforce = None
    elif len(_dtypes) == 2:
        quantized_type, _zp_enforce = _dtypes[:2]
    _qtype_info = torch.iinfo(quantized_type)
    qmin, qmax = _qtype_info.min, _qtype_info.max
    # Resolve zero_point
    if _zp_enforce is not None:
        zero_point = _zp_enforce
    else:
        zero_point = draw(st.integers(min_value=qmin, max_value=qmax))
    # finfo is needed below for eps even when explicit bounds are provided
    _float_type_info = torch.finfo(torch.float)
    if float_min is None or float_max is None:
        float_min = _float_type_info.min
        float_max = _float_type_info.max
    else:
        assert float_min <= float_max, 'float_min must be <= float_max'
    float_eps = _float_type_info.eps
    # Resolve scale
    scale = draw(st.floats(min_value=float_eps, max_value=float_max))

    adjustment = 1 + float_eps
    _long_type_info = torch.iinfo(torch.long)
    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
    # make sure intermediate results are within the range of long
    min_value = max((long_min - zero_point) * scale,
                    (long_min / scale + zero_point), float_min)
    max_value = min((long_max - zero_point) * scale,
                    (long_max / scale + zero_point), float_max)
    # Resolve the tensor
    Xhy = draw(
        stnp.arrays(dtype=np.float32,
                    elements=st.floats(min_value=min_value,
                                       max_value=max_value),
                    shape=shape))
    return Xhy, (scale, zero_point), (qmin, qmax), quantized_type
Example #21
    def __init__(self,
                 observer,
                 quant_min=0,
                 quant_max=255,
                 scale=1.,
                 zero_point=0.,
                 channel_len=-1,
                 use_grad_scaling=False,
                 **observer_kwargs):
        super(_LearnableFakeQuantize, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        # also pass quant_min and quant_max to observer
        observer_kwargs["quant_min"] = quant_min
        observer_kwargs["quant_max"] = quant_max
        self.use_grad_scaling = use_grad_scaling
        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(
                channel_len, int
            ) and channel_len > 0, "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(
                torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
            'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
            'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled',
                             torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled',
                             torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled',
                             torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.register_buffer('eps',
                             torch.tensor([torch.finfo(torch.float32).eps]))
Example #22
def qparams(draw,
            dtypes=None,
            scale_min=None,
            scale_max=None,
            zero_point_min=None,
            zero_point_max=None):
    if dtypes is None:
        dtypes = _ALL_QINT_TYPES
    if not isinstance(dtypes, (list, tuple)):
        dtypes = (dtypes, )
    quantized_type = draw(st.sampled_from(dtypes))

    _type_info = torch.iinfo(quantized_type)
    qmin, qmax = _type_info.min, _type_info.max

    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
    if _zp_enforced is not None:
        zero_point = _zp_enforced
    else:
        _zp_min = qmin if zero_point_min is None else zero_point_min
        _zp_max = qmax if zero_point_max is None else zero_point_max
        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))

    if scale_min is None:
        scale_min = torch.finfo(torch.float).eps
    if scale_max is None:
        scale_max = torch.finfo(torch.float).max
    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))

    return scale, zero_point, quantized_type
Example #23
 def sizeof(dtype):
     if dtype == torch.bool:
         return 1
     elif dtype.is_floating_point:
         return torch.finfo(dtype).bits // 8
     else:
         return torch.iinfo(dtype).bits // 8
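A minimal usage sketch for the helper above:

import torch

print(sizeof(torch.bool))     # 1
print(sizeof(torch.float16))  # 2, from torch.finfo(torch.float16).bits // 8
print(sizeof(torch.int64))    # 8, from torch.iinfo(torch.int64).bits // 8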
Example #24
 def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
     super(FakeQuantize, self).__init__()
     assert quant_min <= quant_max, \
         'quant_min must be less than or equal to quant_max'
     self.quant_min = quant_min
     self.quant_max = quant_max
     self.fake_quant_enabled = True
     self.observer_enabled = True
     self.activation_post_process = observer(**observer_kwargs)
     assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
     assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
     self.scale = torch.tensor([1.0])
     self.zero_point = torch.tensor([0])
     self.dtype = self.activation_post_process.dtype
     self.qscheme = self.activation_post_process.qscheme
     self.ch_axis = self.activation_post_process.ch_axis if hasattr(self.activation_post_process, 'ch_axis') else None
Example #25
def _argmax_helper_pairwise(enc_tensor, dim=None):
    """Returns 1 for all elements that have the highest value in the appropriate
    dimension of the tensor. Uses O(n^2) comparisons and a constant number of
    rounds of communication
    """
    dim = -1 if dim is None else dim
    row_length = enc_tensor.size(dim) if enc_tensor.size(dim) > 1 else 2

    # Copy each row (length - 1) times to compare to each other row
    a = enc_tensor.expand(row_length - 1, *enc_tensor.size())

    # Generate cyclic permutations for each row
    b = crypten.stack(
        [enc_tensor.roll(i + 1, dims=dim) for i in range(row_length - 1)])

    # Use either prod or sum & comparison depending on size
    if row_length - 1 < torch.iinfo(torch.long).bits * 2:
        pairwise_comparisons = a.ge(b, _scale=False)
        result = pairwise_comparisons.prod(0)
        result.share *= enc_tensor.encoder._scale
        result.encoder = enc_tensor.encoder
    else:
        # Sum of columns with all 1s will have value equal to (length - 1).
        # Using ge() since it is slightly faster than eq()
        pairwise_comparisons = a.ge(b)
        result = pairwise_comparisons.sum(0).ge(row_length - 1)
    return result, None
Example #26
 def __init__(self, observer=MinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
     super(FakeQuantize, self).__init__()
     assert quant_min <= quant_max, \
         'quant_min must be less than or equal to quant_max'
     self.quant_min = quant_min
     self.quant_max = quant_max
     self.fake_quant_enabled = True
     self.observer_enabled = True
     self.observer = observer(**observer_kwargs)
     assert torch.iinfo(self.observer.dtype).min <= quant_min, 'quant_min out of bound'
     assert quant_max <= torch.iinfo(self.observer.dtype).max, 'quant_max out of bound'
     self.scale = None
     self.zero_point = None
     self.dtype = self.observer.dtype
     self.qscheme = self.observer.qscheme
     self.ch_axis = self.observer.ch_axis if hasattr(self.observer, 'ch_axis') else 0
Example #27
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    num_bytes_per_value = torch.iinfo(torch_type).bits // 8
    # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
    # we need to reverse the bytes before we can read them with torch.frombuffer().
    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
    if needs_byte_reversal:
        parsed = parsed.flip(0)

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)
Example #28
def _B2A(binary_tensor, precision=None, bits=None):
    if bits is None:
        bits = torch.iinfo(torch.long).bits

    if bits == 1:
        binary_bit = binary_tensor & 1
        arithmetic_tensor = beaver.B2A_single_bit(binary_bit)
    else:
        binary_bits = BinarySharedTensor.stack(
            [binary_tensor >> i for i in range(bits)])
        binary_bits = binary_bits & 1
        arithmetic_bits = beaver.B2A_single_bit(binary_bits)

        multiplier = torch.cat([
            torch.tensor([1], dtype=torch.long, device=binary_tensor.device) <<
            i for i in range(bits)
        ])
        while multiplier.dim() < arithmetic_bits.dim():
            multiplier = multiplier.unsqueeze(1)

        arithmetic_tensor = arithmetic_bits.mul_(multiplier).sum(0)

    arithmetic_tensor.encoder = FixedPointEncoder(precision_bits=precision)
    scale = arithmetic_tensor.encoder._scale // binary_tensor.encoder._scale
    arithmetic_tensor *= scale
    return arithmetic_tensor
Example #29
 def _case_zero_transform(t):
     try:
         info = torch.iinfo(t.dtype)
         return torch.full_like(t, info.max)
     except TypeError as te:
         # for non-integer types fills with NaN
         return torch.full_like(t, float('nan'))
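A minimal usage sketch for the helper above:

import torch

print(_case_zero_transform(torch.zeros(3, dtype=torch.int32)))    # filled with torch.iinfo(torch.int32).max
print(_case_zero_transform(torch.zeros(3, dtype=torch.float32)))  # torch.iinfo raises TypeError, so filled with NaN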
Example #30
 def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
              quant_min=0, quant_max=255, reduce_range=False):
     super(FakeQuantize, self).__init__()
     assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
     assert quant_min <= quant_max, \
         'quant_min must be less than or equal to quant_max'
     assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
     self.dtype = dtype
     self.qscheme = qscheme
     self.quant_min = quant_min
     self.quant_max = quant_max
     self.fake_quant_enabled = True
     self.observer_enabled = True
     self.observer = MinMaxObserver.with_args(dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)()
     self.scale = None
     self.zero_point = None