def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        min_val = X.min()
        max_val = X.max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])
    return scale, zero_point
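# Hedged usage sketch (not from the original source): exercising the helper
# above on a random 2-D weight, producing one (scale, zero_point) pair per row.
# The dtype choice torch.quint8 is an assumption for illustration only.
import numpy as np
import torch

W = torch.randn(4, 16)
scale, zero_point = _calculate_dynamic_per_channel_qparams(W, torch.quint8)
assert scale.shape == (4,) and zero_point.shape == (4,)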
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
    super(FakeQuantize, self).__init__()
    assert quant_min <= quant_max, \
        'quant_min must be less than or equal to quant_max'
    self.quant_min = quant_min
    self.quant_max = quant_max
    # fake_quant_enabled and observer_enabled are buffers to support their
    # replication in DDP. Data type is uint8 because NCCL does not support
    # bool tensors.
    self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
    self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
    self.activation_post_process = observer(**observer_kwargs)
    assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
        'quant_min out of bound'
    assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
        'quant_max out of bound'
    self.register_buffer('scale', torch.tensor([1.0]))
    self.register_buffer('zero_point', torch.tensor([0]))
    self.dtype = self.activation_post_process.dtype
    self.qscheme = self.activation_post_process.qscheme
    self.ch_axis = self.activation_post_process.ch_axis \
        if hasattr(self.activation_post_process, 'ch_axis') else -1
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
    super().__init__()
    assert quant_min <= quant_max, \
        'quant_min must be less than or equal to quant_max'
    self.quant_min = quant_min
    self.quant_max = quant_max
    self.activation_post_process = observer(**observer_kwargs)
    assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
        'quant_min out of bound'
    assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
        'quant_max out of bound'
    self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
    self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
    self.dtype = self.activation_post_process.dtype
    self.qscheme = self.activation_post_process.qscheme
    self.ch_axis = self.activation_post_process.ch_axis \
        if hasattr(self.activation_post_process, 'ch_axis') else -1
    assert _is_per_channel(self.qscheme) or \
        _is_per_tensor(self.qscheme), \
        'Only per channel and per tensor quantization are supported in fake quantize' + \
        ' got qscheme: ' + str(self.qscheme)
    self.is_per_channel = _is_per_channel(self.qscheme)
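# Hedged sanity sketch for the bound checks used in the constructors above: the
# observer dtype fixes the admissible [quant_min, quant_max] window, e.g.
# torch.quint8 admits exactly [0, 255].
import torch

info = torch.iinfo(torch.quint8)
assert info.min == 0 and info.max == 255  # quant_max=256 would trip the assert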
def test_fq_module(self, device, X):
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, axis, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    X.requires_grad_()
    fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
    Y_prime = fq_module(X)
    assert fq_module.scale is not None
    assert fq_module.zero_point is not None
    Y = _fake_quantize_per_channel_affine_reference(
        X, fq_module.scale, fq_module.zero_point, axis, quant_min, quant_max)
    np.testing.assert_allclose(Y.cpu().detach().numpy(),
                               Y_prime.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)

    # Test backward
    dout = torch.rand(X.shape, dtype=torch.float, device=device)
    Y_prime.backward(dout)
    dX = _fake_quantize_per_channel_affine_grad_reference(
        dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min, quant_max)
    np.testing.assert_allclose(dX.cpu().numpy(),
                               X.grad.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    if image.dtype.is_floating_point:
        # float to float
        if dtype.is_floating_point:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        eps = 1e-3
        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
    else:
        # int to float
        if dtype.is_floating_point:
            max = torch.iinfo(image.dtype).max
            image = image.to(dtype)
            return image / max

        # int to int
        input_max = torch.iinfo(image.dtype).max
        output_max = torch.iinfo(dtype).max

        if input_max > output_max:
            factor = (input_max + 1) // (output_max + 1)
            image = image // factor
            return image.to(dtype)
        else:
            factor = (output_max + 1) // (input_max + 1)
            image = image.to(dtype)
            return image * factor
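# Hedged usage sketch for convert_image_dtype above: a uint8 image is rescaled
# into [0, 1] as float32 and back. Per the docstring note, the forward mapping
# is not exact at the top of the range, but the round trip recovers the input.
import torch

img_u8 = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
img_f32 = convert_image_dtype(img_u8, torch.float32)  # values in [0, 1]
img_back = convert_image_dtype(img_f32, torch.uint8)  # back in [0, 255]
assert torch.equal(img_u8, img_back)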
def test_stable_sort_against_numpy(self, device, dtype):
    if dtype in torch.testing.floating_types_and(torch.float16):
        inf = float('inf')
        neg_inf = -float('inf')
        nan = float('nan')
    else:
        if dtype != torch.bool:  # no torch.iinfo support for torch.bool
            inf = torch.iinfo(dtype).max
            neg_inf = torch.iinfo(dtype).min
        else:
            inf = True
            neg_inf = ~inf
        # no nan for integral types, we use inf instead for simplicity
        nan = inf

    def generate_samples():
        from itertools import chain, combinations

        def repeated_index_fill(t, dim, idxs, vals):
            res = t
            for idx, val in zip(idxs, vals):
                res = res.index_fill(dim, idx, val)
            return res

        for sizes in [(1, 10), (10, 1), (10, 10), (10, 10, 10)]:
            size = min(*sizes)
            x = (torch.randn(*sizes, device=device) * size).to(dtype)
            yield (x, 0)
            # Generate tensors which are being filled at random locations
            # with values from the non-empty subsets of the set (inf, neg_inf, nan)
            # for each dimension.
            n_fill_vals = 3  # cardinality of (inf, neg_inf, nan)
            for dim in range(len(sizes)):
                idxs = (torch.randint(high=size, size=(size // 10, )) for i in range(n_fill_vals))
                vals = (inf, neg_inf, nan)
                subsets = chain.from_iterable(
                    combinations(list(zip(idxs, vals)), r) for r in range(1, n_fill_vals + 1))
                for subset in subsets:
                    idxs_subset, vals_subset = zip(*subset)
                    yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim)

        for sizes in [(100, ), (1000, ), (10000, )]:
            size = sizes[0]
            # binary strings
            yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0)

    for sample, dim in generate_samples():
        _, idx_torch = sample.sort(dim=dim, stable=True)
        sample_numpy = sample.numpy()
        idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable')
        self.assertEqual(idx_torch, idx_numpy)
def test_topk_integral(self, device, dtype):
    a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, size=(10, ),
                      dtype=dtype, device=device)
    sort_topk = a.sort()[0][-5:].flip(0)
    topk = a.topk(5)
    self.assertEqual(sort_topk, topk[0])     # check values
    self.assertEqual(sort_topk, a[topk[1]])  # check indices
def _test_topk_dtype(self, device, dtype, integral, size):
    if integral:
        a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max,
                          size=(size,), dtype=dtype, device=device)
    else:
        a = torch.randn(size=(size,), dtype=dtype, device=device)

    sort_topk = a.sort()[0][-(size // 2):].flip(0)
    topk = a.topk(size // 2)
    self.assertEqual(sort_topk, topk[0])     # check values
    self.assertEqual(sort_topk, a[topk[1]])  # check indices
def test_topk_integral(self, device, dtype):
    small = 10
    large = 4096
    for curr_size in (small, large):
        a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max,
                          size=(curr_size, ), dtype=dtype, device=device)
        sort_topk = a.sort()[0][-(curr_size // 2):].flip(0)
        topk = a.topk(curr_size // 2)
        self.assertEqual(sort_topk, topk[0])     # check values
        self.assertEqual(sort_topk, a[topk[1]])  # check indices
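# Hedged stand-alone sketch of the invariant the topk tests above verify:
# topk values of an integer tensor equal a descending slice of its full sort,
# and indexing with the returned indices reproduces those values.
import torch

a = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max,
                  size=(10,), dtype=torch.int32)
values, indices = a.topk(5)
assert torch.equal(values, a.sort().values[-5:].flip(0))
assert torch.equal(values, a[indices])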
def test_forward_per_tensor(self, device, X):
    r"""Tests the forward path of the FakeQuantizePerTensorAffine op."""
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
    Y_prime = torch.fake_quantize_per_tensor_affine(
        X, scale, zero_point, quant_min, quant_max)
    np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
def __init__(self, observer, quant_min=-128, quant_max=127, scale=1., zero_point=0., **observer_kwargs):
    """AdaRound fake-quantization module.

    Args:
        observer (torch.quantization.observer): quantization observer used to initialize the quantizers
        quant_min (int, optional): quantized values range (min). Defaults to -128.
        quant_max (int, optional): quantized values range (max). Defaults to 127.
        scale (float, optional): Defaults to 1.0.
        zero_point (float, optional): Defaults to 0.0.
    """
    super(AdaRound, self).__init__()
    assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
    self.quant_min = quant_min
    self.quant_max = quant_max
    self.activation_post_process = observer(**observer_kwargs)
    assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
        'quant_min out of bound'
    assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
        'quant_max out of bound'
    self.register_buffer('scale', torch.tensor([scale]))
    self.register_buffer('zero_point', torch.tensor([zero_point]))
    self.dtype = self.activation_post_process.dtype
    self.qscheme = self.activation_post_process.qscheme
    assert self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric)
    self.ch_axis = self.activation_post_process.ch_axis \
        if hasattr(self.activation_post_process, 'ch_axis') else -1
    self.register_buffer('observer_enabeled', torch.tensor([1], dtype=torch.uint8))
    self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
    self.register_buffer('soft_sigmoid_enabled', torch.tensor([0], dtype=torch.uint8))
    self.register_buffer('hard_sigmoid_enabled', torch.tensor([0], dtype=torch.uint8))
    bitrange = torch.tensor(quant_max - quant_min + 1).double()
    self.bitwidth = int(torch.log2(bitrange).item())
    self.gamma, self.zeta = -0.1, 1.1
    self.beta = 2. / 3
def _reduction_identity(op_name: str, input: Tensor, *args):
    """Return identity value as scalar tensor of a reduction operation on
    given input, or None, if the identity value cannot be uniquely defined
    for the given input.

    The identity value of the operation is defined as the initial
    value to reduction operation that has a property ``op(op_identity,
    value) == value`` for any value in the domain of the operation.
    Or put it another way, including or excluding the identity value in
    a list of operands will not change the reduction result.

    See https://github.com/pytorch/rfcs/pull/27 for more information.
    """
    dtype: DType = input.dtype
    device = input.device
    op_name = op_name.rsplit('.', 1)[-1]  # lstrip module name when present
    if op_name == 'sum':
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name == 'prod':
        return torch.tensor(1, dtype=dtype, device=device)
    elif op_name == 'amax':
        if torch.is_floating_point(input):
            return torch.tensor(-torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
    elif op_name == 'amin':
        if torch.is_floating_point(input):
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
    elif op_name == 'mean':
        # Strictly speaking, the identity value of the mean operation
        # is the mean of the input. Since the mean value depends on
        # the dim argument and it may be a non-scalar tensor, we
        # consider the identity value of the mean operation ambiguous.
        # Moreover, the mean value of empty input is undefined.
        return None
    elif op_name == 'norm':
        ord = args[0] if args else 2
        if ord == float('-inf'):
            assert torch.is_floating_point(input), input.dtype
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name in {'var', 'std'}:
        return None
    raise NotImplementedError(f'identity of {op_name} on {dtype} input')
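# Hedged usage sketch: the identity returned above can be appended to the
# operands of a reduction without changing its result, which is the defining
# property spelled out in the docstring.
import torch

x = torch.tensor([3, -1, 7], dtype=torch.int32)
ident = _reduction_identity('amax', x)  # iinfo(int32).min as a 0-dim tensor
assert torch.amax(torch.cat([x, ident.reshape(1)])) == torch.amax(x)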
def _test_numerical_consistency(self, test_type):
    r"""Comparing numerical consistency between quantize/dequantize op and the
    fake quantize op across devices and dtypes
    """
    torch.random.manual_seed(NP_RANDOM_SEED)
    torch_types = [torch.qint8, torch.quint8]
    float_types = [torch.float, torch.float16, torch.float64]
    zero_types = [torch.long]
    devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')]
    axis = 1
    for i in range(20):
        for torch_type, float_type, device, zero_type in itertools.product(
                torch_types, float_types, devices, zero_types):
            X = torch.randn(3, 3, device=device).to(float_type)
            scales = (10 * torch.randn(3, device=device)).abs()
            scale = scales.mean().to(float).item()
            zeros = (10 * torch.randn(3, device=device)).abs().to(dtype=zero_type)
            zero = zeros.max().view(1).item()
            quant_min = torch.iinfo(torch_type).min
            quant_max = torch.iinfo(torch_type).max

            test_was_run = False
            if test_type == "per_tensor":
                test_was_run = True
                Y = torch.dequantize(
                    torch.quantize_per_tensor(
                        X.to('cpu').to(torch.float), scale, zero,
                        torch_type)).to(device).to(float_type)
                Y_prime = torch.fake_quantize_per_tensor_affine(
                    X, scale, zero, quant_min, quant_max)
                self.assertEqual(
                    Y, Y_prime,
                    "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor")

            if test_type == "per_channel":
                test_was_run = True
                Y = torch.dequantize(
                    torch.quantize_per_channel(
                        X.to('cpu').to(torch.float), scales.to('cpu'), zeros.to('cpu'),
                        axis, torch_type)).to(device).to(float_type)
                Y_prime = torch.fake_quantize_per_channel_affine(
                    X, scales, zeros, axis, quant_min, quant_max)
                self.assertEqual(
                    Y, Y_prime,
                    "Difference found between dequant+quant_per_channel and fake_quantize_per_channel")
            self.assertTrue(test_was_run)
def test_numerical_consistency_per_tensor(self, device, X):
    r"""Comparing numerical consistency between CPU quantize/dequantize op and
    the CPU fake quantize op
    """
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    # quantize_per_tensor and dequantize are only implemented on CPU
    Y = torch.dequantize(torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
    Y_prime = torch.fake_quantize_per_tensor_affine(
        X, scale, zero_point, quant_min, quant_max)
    np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
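# Hedged stand-alone sketch of the equivalence the consistency tests above rely
# on: dequantize(quantize_per_tensor(x)) should match fake_quantize with the
# same qparams. The scale/zero_point below are chosen purely for illustration.
import torch

x = torch.randn(4)
scale, zero_point = 0.1, 0
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -128, 127)
assert torch.allclose(torch.dequantize(q), fq)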
def wav_to_float(x):
    """
    Input in range -2**15, 2**15 (or what is determined from dtype)
    Output in range -1, 1
    """
    assert x.dtype == torch.int16, f"got {x.dtype}"
    max_value = torch.iinfo(torch.int16).max
    min_value = torch.iinfo(torch.int16).min
    if not x.is_floating_point():
        x = x.to(torch.float)
    x = x - min_value
    x = x / ((max_value - min_value) / 2.0)
    x = x - 1.0
    return x
def float_to_wav(x):
    """
    Input in range -1, 1
    Output in range -2**15, 2**15 (or what is determined from dtype)
    """
    assert x.dtype == torch.float
    max_value = torch.iinfo(torch.int16).max
    min_value = torch.iinfo(torch.int16).min
    x = x + 1.0
    x = x * (max_value - min_value) / 2.0
    x = x + min_value
    x = x.to(torch.int16)
    return x
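# Hedged round-trip sketch for the two helpers above: int16 PCM to [-1, 1] and
# back. Exact bit-for-bit recovery is not guaranteed, since float32 rounding
# followed by the truncating .to(torch.int16) can be off by one LSB.
import torch

pcm = torch.randint(-2 ** 15, 2 ** 15, (8,), dtype=torch.int16)
recovered = float_to_wav(wav_to_float(pcm))
assert recovered.dtype == torch.int16 and recovered.shape == pcm.shape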
def randomize_parameters(
    self,
    samples: Tensor = None,
    sample_rate: Optional[int] = None,
    targets: Optional[Tensor] = None,
    target_rate: Optional[int] = None,
):
    if self.shift_unit == "samples":
        min_shift_in_samples = self.min_shift
        max_shift_in_samples = self.max_shift
    elif self.shift_unit == "fraction":
        min_shift_in_samples = int(round(self.min_shift * samples.shape[-1]))
        max_shift_in_samples = int(round(self.max_shift * samples.shape[-1]))
    elif self.shift_unit == "seconds":
        min_shift_in_samples = int(round(self.min_shift * sample_rate))
        max_shift_in_samples = int(round(self.max_shift * sample_rate))
    else:
        raise ValueError("Invalid shift_unit")

    assert (
        torch.iinfo(torch.int32).min <= min_shift_in_samples <= torch.iinfo(torch.int32).max
    )
    assert (
        torch.iinfo(torch.int32).min <= max_shift_in_samples <= torch.iinfo(torch.int32).max
    )

    selected_batch_size = samples.size(0)
    if min_shift_in_samples == max_shift_in_samples:
        self.transform_parameters["num_samples_to_shift"] = torch.full(
            size=(selected_batch_size,),
            fill_value=min_shift_in_samples,
            dtype=torch.int32,
            device=samples.device,
        )
    else:
        self.transform_parameters["num_samples_to_shift"] = torch.randint(
            low=min_shift_in_samples,
            high=max_shift_in_samples + 1,
            size=(selected_batch_size,),
            dtype=torch.int32,
            device=samples.device,
        )
def get_castable_tensor(shape, dtype):
    if dtype.is_floating_point:
        dtype_info = torch.finfo(dtype)
        # can't directly use min and max, because for double, max - min
        # is greater than double range and sampling always gives inf.
        low = max(dtype_info.min, -1e10)
        high = min(dtype_info.max, 1e10)
        t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
    else:
        # can't directly use min and max, because for int64_t, max - min
        # is greater than int64_t range and triggers UB.
        low = max(torch.iinfo(dtype).min, int(-1e10))
        high = min(torch.iinfo(dtype).max, int(1e10))
        t = torch.empty(shape, dtype=torch.int64).random_(low, high)
    return t.to(dtype)
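# Hedged usage sketch: the helper above draws values that survive the final
# cast to the requested dtype without overflow.
import torch

t = get_castable_tensor((4, 4), torch.int16)
assert t.dtype == torch.int16
assert t.min() >= torch.iinfo(torch.int16).min
assert t.max() <= torch.iinfo(torch.int16).max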
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255):
    super(FakeQuantize, self).__init__()
    assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
    assert quant_min <= quant_max, \
        'quant_min must be less than or equal to quant_max'
    assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
    self.dtype = dtype
    self.qscheme = qscheme
    self.quant_min = quant_min
    self.quant_max = quant_max
    self.enabled = True
    self.observer = default_observer(dtype=dtype, qscheme=qscheme)
    self.scale = None
    self.zero_point = None
def qtensor(draw, shapes, dtypes=None, float_min=None, float_max=None):
    # In case shape is a strategy
    if isinstance(shapes, SearchStrategy):
        shape = draw(shapes)
    else:
        shape = draw(st.sampled_from(shapes))

    # Resolve types
    if dtypes is None:
        dtypes = ALL_QINT_TYPES
    _dtypes = draw(st.sampled_from(dtypes))
    assert len(_dtypes) in [1, 2]
    if len(_dtypes) == 1:
        quantized_type = _dtypes[0]
        _zp_enforce = None
    elif len(_dtypes) == 2:
        quantized_type, _zp_enforce = _dtypes[:2]
    _qtype_info = torch.iinfo(quantized_type)
    qmin, qmax = _qtype_info.min, _qtype_info.max

    # Resolve zero_point
    if _zp_enforce is not None:
        zero_point = _zp_enforce
    else:
        zero_point = draw(st.integers(min_value=qmin, max_value=qmax))

    # float32 eps is needed below regardless of whether a custom range is given
    _float_type_info = torch.finfo(torch.float)
    if float_min is None or float_max is None:
        float_min = _float_type_info.min
        float_max = _float_type_info.max
    else:
        assert float_min <= float_max, 'float_min must be <= float_max'
    float_eps = _float_type_info.eps

    # Resolve scale
    scale = draw(st.floats(min_value=float_eps, max_value=float_max))

    adjustment = 1 + float_eps
    _long_type_info = torch.iinfo(torch.long)
    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
    # make sure intermediate results are within the range of long
    min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point), float_min)
    max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point), float_max)

    # Resolve the tensor
    Xhy = draw(
        stnp.arrays(dtype=np.float32,
                    elements=st.floats(min_value=min_value, max_value=max_value),
                    shape=shape))
    return Xhy, (scale, zero_point), (qmin, qmax), quantized_type
def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
             use_grad_scaling=False, **observer_kwargs):
    super(_LearnableFakeQuantize, self).__init__()
    assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
    self.quant_min = quant_min
    self.quant_max = quant_max
    # also pass quant_min and quant_max to observer
    observer_kwargs["quant_min"] = quant_min
    observer_kwargs["quant_max"] = quant_max
    self.use_grad_scaling = use_grad_scaling

    if channel_len == -1:
        self.scale = Parameter(torch.tensor([scale]))
        self.zero_point = Parameter(torch.tensor([zero_point]))
    else:
        assert isinstance(channel_len, int) and channel_len > 0, \
            "Channel size must be a positive integer."
        self.scale = Parameter(torch.tensor([scale] * channel_len))
        self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

    self.activation_post_process = observer(**observer_kwargs)
    assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
        'quant_min out of bound'
    assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
        'quant_max out of bound'
    self.dtype = self.activation_post_process.dtype
    self.qscheme = self.activation_post_process.qscheme
    self.ch_axis = self.activation_post_process.ch_axis \
        if hasattr(self.activation_post_process, 'ch_axis') else -1
    self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
    self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
    self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))

    bitrange = torch.tensor(quant_max - quant_min + 1).double()
    self.bitwidth = int(torch.log2(bitrange).item())
    self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
def qparams(draw, dtypes=None, scale_min=None, scale_max=None, zero_point_min=None, zero_point_max=None):
    if dtypes is None:
        dtypes = _ALL_QINT_TYPES
    if not isinstance(dtypes, (list, tuple)):
        dtypes = (dtypes, )
    quantized_type = draw(st.sampled_from(dtypes))

    _type_info = torch.iinfo(quantized_type)
    qmin, qmax = _type_info.min, _type_info.max

    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
    if _zp_enforced is not None:
        zero_point = _zp_enforced
    else:
        _zp_min = qmin if zero_point_min is None else zero_point_min
        _zp_max = qmax if zero_point_max is None else zero_point_max
        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))

    if scale_min is None:
        scale_min = torch.finfo(torch.float).eps
    if scale_max is None:
        scale_max = torch.finfo(torch.float).max
    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))

    return scale, zero_point, quantized_type
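# Hedged usage sketch (assumes hypothesis is installed, qparams is wrapped with
# @st.composite as its draw-first signature suggests, and the module-level
# _ALL_QINT_TYPES / _ENFORCED_ZERO_POINT tables exist): draw one sample and
# check the zero_point stays inside the dtype's iinfo range.
import torch

scale, zero_point, dtype = qparams().example()
info = torch.iinfo(dtype)
assert info.min <= zero_point <= info.max
assert scale > 0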
def sizeof(dtype):
    if dtype == torch.bool:
        return 1
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits // 8
    else:
        return torch.iinfo(dtype).bits // 8
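# Hedged spot checks for sizeof above: bits // 8 gives the element size in bytes.
import torch

assert sizeof(torch.int64) == 8
assert sizeof(torch.float16) == 2
assert sizeof(torch.bool) == 1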
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
    super(FakeQuantize, self).__init__()
    assert quant_min <= quant_max, \
        'quant_min must be less than or equal to quant_max'
    self.quant_min = quant_min
    self.quant_max = quant_max
    self.fake_quant_enabled = True
    self.observer_enabled = True
    self.activation_post_process = observer(**observer_kwargs)
    assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
    assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
    self.scale = torch.tensor([1.0])
    self.zero_point = torch.tensor([0])
    self.dtype = self.activation_post_process.dtype
    self.qscheme = self.activation_post_process.qscheme
    self.ch_axis = self.activation_post_process.ch_axis \
        if hasattr(self.activation_post_process, 'ch_axis') else None
def _argmax_helper_pairwise(enc_tensor, dim=None):
    """Returns 1 for all elements that have the highest value in the appropriate
    dimension of the tensor. Uses O(n^2) comparisons and a constant number of
    rounds of communication.
    """
    dim = -1 if dim is None else dim
    row_length = enc_tensor.size(dim) if enc_tensor.size(dim) > 1 else 2

    # Copy each row (length - 1) times to compare to each other row
    a = enc_tensor.expand(row_length - 1, *enc_tensor.size())

    # Generate cyclic permutations for each row
    b = crypten.stack([enc_tensor.roll(i + 1, dims=dim) for i in range(row_length - 1)])

    # Use either prod or sum & comparison depending on size
    if row_length - 1 < torch.iinfo(torch.long).bits * 2:
        pairwise_comparisons = a.ge(b, _scale=False)
        result = pairwise_comparisons.prod(0)
        result.share *= enc_tensor.encoder._scale
        result.encoder = enc_tensor.encoder
    else:
        # Sum of columns with all 1s will have value equal to (length - 1).
        # Using ge() since it is slightly faster than eq()
        pairwise_comparisons = a.ge(b)
        result = pairwise_comparisons.sum(0).ge(row_length - 1)
    return result, None
def __init__(self, observer=MinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
    super(FakeQuantize, self).__init__()
    assert quant_min <= quant_max, \
        'quant_min must be less than or equal to quant_max'
    self.quant_min = quant_min
    self.quant_max = quant_max
    self.fake_quant_enabled = True
    self.observer_enabled = True
    self.observer = observer(**observer_kwargs)
    assert torch.iinfo(self.observer.dtype).min <= quant_min, 'quant_min out of bound'
    assert quant_max <= torch.iinfo(self.observer.dtype).max, 'quant_max out of bound'
    self.scale = None
    self.zero_point = None
    self.dtype = self.observer.dtype
    self.qscheme = self.observer.qscheme
    self.ch_axis = self.observer.ch_axis if hasattr(self.observer, 'ch_axis') else 0
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()

    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]

    num_bytes_per_value = torch.iinfo(torch_type).bits // 8
    # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
    # we need to reverse the bytes before we can read them with torch.frombuffer().
    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
    if needs_byte_reversal:
        parsed = parsed.flip(0)

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)
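# Hedged sketch of the header arithmetic used above: for an MNIST idx3-ubyte
# image file the magic number is 0x00000803, so the dimension count nd and the
# type code ty decode as follows.
magic = 0x00000803
assert magic % 256 == 3   # nd: three dimensions (count, rows, cols)
assert magic // 256 == 8  # ty: 8 maps to torch.uint8 in SN3_PASCALVINCENT_TYPEMAP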
def _B2A(binary_tensor, precision=None, bits=None):
    if bits is None:
        bits = torch.iinfo(torch.long).bits

    if bits == 1:
        binary_bit = binary_tensor & 1
        arithmetic_tensor = beaver.B2A_single_bit(binary_bit)
    else:
        binary_bits = BinarySharedTensor.stack([binary_tensor >> i for i in range(bits)])
        binary_bits = binary_bits & 1
        arithmetic_bits = beaver.B2A_single_bit(binary_bits)

        multiplier = torch.cat([
            torch.tensor([1], dtype=torch.long, device=binary_tensor.device) << i
            for i in range(bits)
        ])
        while multiplier.dim() < arithmetic_bits.dim():
            multiplier = multiplier.unsqueeze(1)
        arithmetic_tensor = arithmetic_bits.mul_(multiplier).sum(0)

    arithmetic_tensor.encoder = FixedPointEncoder(precision_bits=precision)
    scale = arithmetic_tensor.encoder._scale // binary_tensor.encoder._scale
    arithmetic_tensor *= scale
    return arithmetic_tensor
def _case_zero_transform(t):
    try:
        info = torch.iinfo(t.dtype)
        return torch.full_like(t, info.max)
    except TypeError as te:
        # for non-integer types fills with NaN
        return torch.full_like(t, float('nan'))
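# Hedged usage sketch: integer inputs are filled with their dtype's iinfo max,
# while floating-point inputs hit the TypeError branch and become NaN.
import torch

assert _case_zero_transform(torch.zeros(2, dtype=torch.int8)).eq(127).all()
assert _case_zero_transform(torch.zeros(2, dtype=torch.float32)).isnan().all()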
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, quant_min=0, quant_max=255,
             reduce_range=False):
    super(FakeQuantize, self).__init__()
    assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
    assert quant_min <= quant_max, \
        'quant_min must be less than or equal to quant_max'
    assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
    self.dtype = dtype
    self.qscheme = qscheme
    self.quant_min = quant_min
    self.quant_max = quant_max
    self.fake_quant_enabled = True
    self.observer_enabled = True
    self.observer = MinMaxObserver.with_args(dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)()
    self.scale = None
    self.zero_point = None