def test_scale_invariance_signed_activation_quantization(self, prec):
  """Symmetric fake quantization should commute with power-of-2 scaling.

  Multiplies both the activations and the clipping bound by the same factor
  (8) and expects the fake-quantized output to be multiplied by exactly that
  factor too.

  Args:
    prec: Precision forwarded to `QuantOps.ActHParams`.
  """
  scale_factor = 8.
  base_bound = 6.

  def fake_quant(values, bound):
    # Statistics and bound updates are both switched off for this call.
    return QuantOps.create_inputs_fake_quant(
        inputs=values,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bound,
            prec=prec))

  raw = random.uniform(random.PRNGKey(0), (10, 1))
  raw_scaled = raw * scale_factor

  quantized = fake_quant(raw, base_bound)
  quantized_scaled = fake_quant(raw_scaled, base_bound * scale_factor)

  # Exact equality (not approximate) is required between the two paths.
  onp.testing.assert_array_equal(quantized * scale_factor, quantized_scaled)
def __call__(self, lhs, rhs):
  """Applies `quantized_dynamic_dot_general` to `lhs` and `rhs`.

  Both operands get 8-bit 'symmetric' activation hparams whose bounds come
  from a `GetBounds.Hyper` (initial bound 10.0 for lhs, 5.0 for rhs). Stats
  updates are enabled and bound updates are disabled for both operands.

  Args:
    lhs: Left-hand operand, passed as `lhs_act`.
    rhs: Right-hand operand, passed as `rhs_act`.

  Returns:
    Whatever `quantization.quantized_dynamic_dot_general` returns.
  """
  # Settings shared by both operands' bound configs; only the initial bound
  # differs between them.
  common_hyper = dict(
      stddev_coeff=0,
      absdev_coeff=0,
      mix_coeff=0,
      granularity=quant_config.QuantGranularity.per_tensor)
  lhs_bounds = GetBounds.Hyper(initial_bound=10.0, **common_hyper)
  rhs_bounds = GetBounds.Hyper(initial_bound=5.0, **common_hyper)

  lhs_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=lhs_bounds, prec=8)
  rhs_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=rhs_bounds, prec=8)

  lhs_bounds_params = get_bounds.GetBounds.Params(
      update_stats=True, update_bounds=False, module_name='lhs')
  rhs_bounds_params = get_bounds.GetBounds.Params(
      update_stats=True, update_bounds=False, module_name='rhs')

  # NOTE(review): dimension numbers look like the lax.dot_general layout
  # (contract lhs dim 1 with rhs dim 0, no batch dims) -- confirm.
  return quantization.quantized_dynamic_dot_general(
      lhs_act=lhs,
      rhs_act=rhs,
      lhs_act_hparams=lhs_hparams,
      rhs_act_hparams=rhs_hparams,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=QuantType.aqt,
      lhs_get_bounds_params=lhs_bounds_params,
      rhs_get_bounds_params=rhs_bounds_params)
def __call__(self, inputs):
  """Forwards `inputs` to `QuantOps.create_input_ops`.

  `hparams` is not a parameter; it is read from the enclosing scope. Both
  stats and bound updates are disabled.

  Args:
    inputs: Value passed positionally to `QuantOps.create_input_ops`.

  Returns:
    The result of `QuantOps.create_input_ops`.
  """
  frozen_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)
  return QuantOps.create_input_ops(
      inputs, hparams=hparams, get_bounds_params=frozen_bounds_params)
def test_int_symmetric_act_quantization(self, prec):
  """Integer activations inside the symmetric bound must pass unchanged.

  The bound is `2^(prec - 1) - 1`, and the fake-quantized output is expected
  to equal the integer inputs exactly.

  Args:
    prec: Precision forwarded to `QuantOps.ActHParams`; also sets the bound.
  """
  max_abs = 2**(prec - 1) - 1

  # NOTE(review): assuming `random` is jax.random, maxval is exclusive, so
  # samples lie in [-max_abs, max_abs) -- confirm.
  int_inputs = random.randint(random.PRNGKey(0), (10, 1), -max_abs, max_abs)

  frozen_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)
  symmetric_hparams = QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      bounds=max_abs,
      prec=prec)
  quantized = QuantOps.create_inputs_fake_quant(
      inputs=int_inputs,
      get_bounds_params=frozen_bounds_params,
      hparams=symmetric_hparams)

  onp.testing.assert_array_equal(int_inputs, quantized)
def test_int_positive_act_quantization(self, prec):
  """Non-negative integer activations below the bound must pass unchanged.

  The upper bound is `2^(prec - 3)`, a power of two whose exponent is below
  `prec`. The fake-quantized output under the `positive` input distribution
  is expected to equal the integer inputs exactly.

  Args:
    prec: Precision forwarded to `QuantOps.ActHParams`; also sets the bound.
  """
  ceiling = 2**(prec - 3)
  int_inputs = random.randint(random.PRNGKey(0), (10, 1), 0, ceiling)

  frozen_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)
  positive_hparams = QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.positive,
      bounds=ceiling,
      prec=prec)
  quantized = QuantOps.create_inputs_fake_quant(
      inputs=int_inputs,
      get_bounds_params=frozen_bounds_params,
      hparams=positive_hparams)

  onp.testing.assert_array_equal(int_inputs, quantized)
def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
    self,
    mock_act_fq,
    lhs_act_prec,
    rhs_act_prec,
    strategy=QuantType.fake_quant):
  """Checks which activation-quantization calls the dynamic dot issues.

  Runs `quantization.quantized_dynamic_dot_general` with a pass-through mock
  and verifies that the mock received one matching call per operand whose
  precision is not None.

  Args:
    mock_act_fq: Mock for the activation quantization function (presumably
      injected by a patch decorator outside this block).
    lhs_act_prec: Precision for the lhs activation; falsy means no hparams.
    rhs_act_prec: Precision for the rhs activation; falsy means no hparams.
    strategy: Quantization type passed as `quant_type`.
  """

  def passthrough(inputs, hparams, get_bounds_params):
    del hparams, get_bounds_params  # Unused: the mock echoes its input.
    return inputs

  mock_act_fq.side_effect = passthrough

  def make_hparams(prec):
    # A falsy precision yields no hparams for that operand.
    if not prec:
      return None
    return QuantOps.ActHParams(
        bounds=6.,
        prec=prec,
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        half_shift=False)

  lhs_hparams = make_hparams(lhs_act_prec)
  rhs_hparams = make_hparams(rhs_act_prec)
  bounds_params = GetBounds.Params(update_stats=False, update_bounds=False)

  quantization.quantized_dynamic_dot_general(
      lhs_act=self.lhs_act,
      rhs_act=self.rhs_act,
      quant_type=strategy,
      dot_dimension_numbers=self.dimension_numbers,
      lhs_act_hparams=lhs_hparams,
      lhs_get_bounds_params=bounds_params,
      rhs_act_hparams=rhs_hparams,
      rhs_get_bounds_params=bounds_params,
  )

  # One expected call per operand whose precision is not None; the input
  # distribution is matched loosely with mock.ANY.
  expected_calls = [
      mock.call(
          mock.ANY,
          hparams=QuantOps.ActHParams(
              bounds=6.,
              prec=prec,
              input_distribution=mock.ANY,
              half_shift=False),
          get_bounds_params=bounds_params)
      for prec in (lhs_act_prec, rhs_act_prec)
      if prec is not None
  ]
  self.assertLen(expected_calls, mock_act_fq.call_count)
  mock_act_fq.assert_has_calls(expected_calls, any_order=True)
def test_inputs_scale_shape_is_expected(self):
  """Runs input fake quantization with an expected bounds shape of `()`.

  The test makes no assertion of its own. It passes
  `expected_bounds_shape=()` to `GetBounds.Params` and calls
  `QuantOps.create_inputs_fake_quant`; presumably the callee validates the
  shape (confirm against the implementation).
  """
  random_inputs = jnp.array(
      fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
  clip_bound = 6.0
  scalar_shape = ()

  symmetric_hparams = QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      bounds=clip_bound,
      prec=8.0)
  shape_checked_params = GetBounds.Params(
      update_stats=False,
      update_bounds=False,
      expected_bounds_shape=scalar_shape)

  # The return value is intentionally discarded.
  QuantOps.create_inputs_fake_quant(
      inputs=random_inputs,
      hparams=symmetric_hparams,
      get_bounds_params=shape_checked_params)
def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
    self,
    mock_act_fq,
    mock_w_fq,
    weight_prec,
    act_prec,
    strategy=QuantType.fake_quant):
  """Checks that `quantized_dot` calls weight and input quantization.

  Both mocks are set to return their first argument. After calling
  `quantization.quantized_dot`, the weight mock must have been called with
  the weight params and the jax type of `strategy`. The activation mock must
  have been called with the activation hparams when they exist, and not
  called at all otherwise.

  Args:
    mock_act_fq: Mock for the activation quantization function (presumably
      injected by a patch decorator outside this block).
    mock_w_fq: Mock for the weight quantization function (presumably
      injected the same way).
    weight_prec: Precision for `QuantOps.WeightParams`.
    act_prec: Precision for the activation; falsy means no hparams.
    strategy: Quantization type passed as `quant_type`.
  """

  def passthrough(inputs, **_):
    return inputs

  mock_w_fq.side_effect = passthrough
  mock_act_fq.side_effect = passthrough

  weight_params = QuantOps.WeightParams(
      prec=weight_prec, axis=None, half_shift=False)

  if act_prec:
    act_hparams = QuantOps.ActHParams(
        bounds=6.,
        prec=act_prec,
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        half_shift=False)
  else:
    act_hparams = None

  bounds_params = GetBounds.Params(update_stats=False, update_bounds=False)

  quantization.quantized_dot(
      w=self.weight,
      act=self.act,
      quant_type=strategy,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=bounds_params,
      prefer_int8_to_int32_dot=True)

  expected_type = strategy.to_jax_type()
  mock_w_fq.assert_called_with(
      mock.ANY,
      weight_params=weight_params,
      quantized_type=expected_type,
      fake_dependency=mock.ANY)

  # Without activation hparams the activation mock must stay untouched.
  if not act_hparams:
    mock_act_fq.assert_not_called()
    return

  mock_act_fq.assert_called_with(
      mock.ANY, hparams=act_hparams, get_bounds_params=bounds_params)
def __call__(self, inputs):
  """Forwards `inputs` to `quantization.QuantOps.create_inputs_fake_quant`.

  `hparams` is not a parameter; it is read from the enclosing scope. Stats
  updates are enabled and bound updates are disabled.

  Args:
    inputs: Value passed positionally to `create_inputs_fake_quant`.

  Returns:
    The result of `create_inputs_fake_quant`.
  """
  stats_only_params = GetBounds.Params(
      update_stats=True, update_bounds=False)
  return quantization.QuantOps.create_inputs_fake_quant(
      inputs, hparams=hparams, get_bounds_params=stats_only_params)