def test_quantized_dynamic_dot_general(self, lhs_prec, rhs_prec):
  lhs_bounds = 2.0
  rhs_bounds = 1.5
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=lhs_bounds, prec=lhs_prec)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=rhs_bounds, prec=rhs_prec)

  def quantized_matmul(quant_type):
    return quantization.quantized_dynamic_dot_general(
        lhs_act=self.lhs,
        rhs_act=self.rhs,
        lhs_act_hparams=lhs_params,
        rhs_act_hparams=rhs_params,
        lhs_get_bounds_params=None,
        rhs_get_bounds_params=None,
        dot_dimension_numbers=(((1,), (0,)), ((), ())),
        quant_type=quant_type)

  aqt_result = quantized_matmul(QuantType.aqt)
  fakequant_result = quantized_matmul(QuantType.fake_quant)
  onp.testing.assert_allclose(
      aqt_result,
      fakequant_result,
      rtol=1e-2,
      err_msg='AQT and fakequant significantly disagree')
def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
  get_bounds_hyper = get_bounds.GetBounds.Hyper(
      initial_bound=10.0,
      stddev_coeff=0,
      absdev_coeff=0,
      mix_coeff=0,
      granularity=quant_config.QuantGranularity.per_tensor)
  lhs_act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=get_bounds_hyper,
      prec=lhs_prec,
      half_shift=False)
  rhs_act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=get_bounds_hyper,
      prec=rhs_prec,
      half_shift=False)
  lhs_get_bounds_params = get_bounds.GetBounds.Params(
      update_stats=False, update_bounds=False, module_name='lhs')
  rhs_get_bounds_params = get_bounds.GetBounds.Params(
      update_stats=False, update_bounds=False, module_name='rhs')
  output = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_act_hparams,
      rhs_act_hparams=rhs_act_hparams,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=QuantType.aqt,
      lhs_get_bounds_params=lhs_get_bounds_params,
      rhs_get_bounds_params=rhs_get_bounds_params)
  return output
def test_quantized_dot_aqt(self, act_bounds, weight_prec, weight_axis):
  # With a high enough precision, we expect results from fakequant and AQT to
  # be very similar.
  weight_params = QuantOps.WeightParams(prec=weight_prec, axis=weight_axis)
  if act_bounds is None:
    act_params = None
  else:
    act_params = QuantOps.ActHParams(
        input_distribution='symmetric', bounds=jnp.array(act_bounds), prec=16)

  def quantized_matmul(quant_type):
    return quantization.quantized_dot(
        w=self.rhs,
        act=self.lhs,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=quant_type,
        prefer_int8_to_int32_dot=True)

  aqt_result = quantized_matmul(QuantType.aqt)
  fakequant_result = quantized_matmul(QuantType.fake_quant)
  onp.testing.assert_allclose(
      aqt_result,
      fakequant_result,
      rtol=1e-2,
      err_msg='AQT and fakequant significantly disagree')
def test_lax_dot_has_integer_inputs_in_dynamic_dot_general(
    self, mock_dot_general, lhs_distribution, rhs_distribution):
  lhs_params = QuantOps.ActHParams(
      input_distribution=lhs_distribution, bounds=2.0, prec=4)
  rhs_params = QuantOps.ActHParams(
      input_distribution=rhs_distribution, bounds=1.5, prec=4)
  lhs_act = self.lhs
  if lhs_distribution == 'positive':
    lhs_act = jnp.abs(lhs_act)
  rhs_act = self.rhs
  if rhs_distribution == 'positive':
    rhs_act = jnp.abs(rhs_act)
  quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=QuantType.aqt)
  lhs_inputs, rhs_inputs = mock_dot_general.call_args[0]
  self.assert_is_integer_in_range(
      lhs_inputs, prec=4, distribution=lhs_distribution)
  self.assert_is_integer_in_range(
      rhs_inputs, prec=4, distribution=rhs_distribution)
def test_scale_invariance_signed_activation_quantization(self, prec):
  # Scaling the activations by a power of 2 and the bounds by the same factor
  # should scale the output by that same factor.
  activations = random.uniform(random.PRNGKey(0), (10, 1))
  act_scale = 8.
  scaled_activations = activations * act_scale
  bounds = 6.
  activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=prec))
  scaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=scaled_activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds * act_scale,
          prec=prec))
  onp.testing.assert_array_equal(activations * act_scale, scaled_activations)
def quantized_softmax(a):
  # We compute softmax as exp(x - max(x)) / sum_i(exp(x_i - max(x))),
  # quantizing intermediate values. Note this differs from the log-domain
  # implementation of softmax used above.
  quant_hparams = softmax_hparams.quant_hparams
  fp_quant_config = QuantOps.FloatQuant(
      is_scaled=False, fp_spec=quant_hparams.prec)
  quant_ops = QuantOps.create_symmetric_fp(
      fp_quant=fp_quant_config, bounds=None)

  a = quant_ops.to_quantized(a, dtype=dtype)
  # Note that the max of a quantized vector is necessarily also quantized to
  # the same precision since the max of a vector must be an existing element
  # of the vector, so we don't need to explicitly insert a quantization
  # operator on the output of the max reduction.
  a_max = jnp.max(a, axis=norm_dims, keepdims=True)
  a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
  a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

  sum_exp_quantized_reduction = quantization.quantized_sum(
      a_exp, axis=norm_dims, keepdims=True, prec=quant_hparams.reduction_prec)
  sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

  inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
  a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

  return a_softmax.astype(dtype)
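# For reference, a minimal unquantized sketch of the same softmax expression,
# exp(x - max(x)) / sum_i(exp(x_i - max(x))). The helper name
# `reference_softmax` and its `norm_dims` default are ours, not part of the
# library; it only illustrates the computation that `quantized_softmax` above
# reproduces with a `to_quantized` call after each intermediate step.
import jax.numpy as jnp


def reference_softmax(a, norm_dims=(-1,)):
  a_max = jnp.max(a, axis=norm_dims, keepdims=True)
  a_exp = jnp.exp(a - a_max)  # subtract the max for numerical stability
  return a_exp / jnp.sum(a_exp, axis=norm_dims, keepdims=True)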
def quantized_layernorm(x):
  prec = hparams.quant_hparams.prec
  fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
  quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

  def to_quantized(x):
    return quant_ops.to_quantized(x, dtype=dtype)

  # If epsilon is too small to represent in the quantized format, we set it
  # to the minimal representable non-zero value to avoid the possibility of
  # dividing by zero.
  fp_bounds = quantization.fp_cast.get_bounds(prec.exp_min, prec.exp_max,
                                              prec.sig_bits)
  epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
  quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

  # If the reciprocal of the quantized number of features is too small to
  # represent in the quantized format, we set it to the minimal
  # representable nonzero value so that the mean and variance are not
  # trivially 0.
  num_features_quantized = to_quantized(jnp.array(num_features, dtype=dtype))
  num_features_recip_quantized = to_quantized(
      jnp.reciprocal(num_features_quantized))
  num_features_recip_quantized = jax.lax.cond(
      jax.lax.eq(num_features_recip_quantized, 0.0),
      lambda _: quantized_epsilon,
      lambda _: num_features_recip_quantized,
      None)

  x_quantized = to_quantized(x)
  x_sum_quantized_reduction = quantization.quantized_sum(
      x_quantized,
      axis=-1,
      keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sum = to_quantized(x_sum_quantized_reduction)
  mean = to_quantized(x_sum * num_features_recip_quantized)
  x_minus_mean = to_quantized(x - mean)
  x_sq = to_quantized(lax.square(x_minus_mean))
  x_sq_sum_quantized_reduction = quantization.quantized_sum(
      x_sq,
      axis=-1,
      keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
  var = to_quantized(x_sq_sum * num_features_recip_quantized)
  # Prevent division by zero.
  var_plus_epsilon = to_quantized(var + quantized_epsilon)
  mul = to_quantized(lax.rsqrt(var_plus_epsilon))
  if self.use_scale:
    quantized_scale_param = to_quantized(scale_param)
    mul = to_quantized(mul * quantized_scale_param)
  y = to_quantized(x_minus_mean * mul)
  if self.use_bias:
    quantized_bias_param = to_quantized(bias_param)
    y = to_quantized(y + quantized_bias_param)
  return y.astype(self.dtype)
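# An unquantized reference for the layer norm above, following the same order
# of operations (mean, centered square, variance, rsqrt, optional scale/bias)
# but without the intermediate `to_quantized` calls or the flush-to-zero
# epsilon handling. A standalone sketch; the name `reference_layernorm` and
# its epsilon default are ours, not the module's API.
import jax.numpy as jnp
from jax import lax


def reference_layernorm(x, scale=None, bias=None, epsilon=1e-6):
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.mean(lax.square(x - mean), axis=-1, keepdims=True)
  y = (x - mean) * lax.rsqrt(var + epsilon)
  if scale is not None:
    y = y * scale
  if bias is not None:
    y = y + bias
  return y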
def test_float_weights_quantization(self, prec):
  # Tests that quantized and rescaled float weights are close to original
  # weights.
  weights = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
  rescaled_weights = QuantOps.create_weights_fake_quant(
      w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=None))
  test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)
def test_attributes_create_acts_op_fp(
    self,
    act_distribution,
    use_hparams_bounds,
):
  inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
  fp_quant = QuantOps.FloatQuant(
      is_scaled=True,
      fp_spec=QuantOps.FloatQuant.FloatPrec(
          exp_min=-15,
          exp_max=15,
          sig_bits=2,
      ),
  )
  if use_hparams_bounds:
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=6.0,
        stddev_coeff=1,
        absdev_coeff=0,
        mix_coeff=1,
        reset_stats=True,
        ema_coeff=None,
        use_cams=False,
        granularity=quant_config.QuantGranularity.per_tensor)
  else:
    bounds = 6.0

  hparams = QuantOps.ActHParams(
      input_distribution=act_distribution,
      bounds=bounds,
      prec=fp_quant,
      half_shift=False)

  class TestModule(nn.Module):
    hparams: QuantOps.ActHParams

    @nn.compact
    def __call__(self, inputs):
      return QuantOps.create_input_ops(
          inputs,
          hparams=hparams,
          get_bounds_params=GetBounds.Params(
              update_stats=False, update_bounds=False))

  test_module = TestModule(hparams=hparams)
  state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
  act_quant_op = test_module.apply(state, inputs=inputs)

  act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
  act_quant_expected = fp_cast.downcast_sat_ftz(
      act_scaled,
      fp_quant.fp_spec.exp_min,
      fp_quant.fp_spec.exp_max,
      fp_quant.fp_spec.sig_bits,
  )
  act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)
def test_weight_scale_shape_is_expected(self, axis):
  # Tests that the scale has the expected shape for weight quantization.
  num_features = 4
  expected_scale_shape = (1, 1) if axis is None else (1, num_features)
  # Weight quantization
  weights = jnp.array(
      fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, num_features))))
  _ = QuantOps.create_weights_fake_quant(
      w=weights,
      weight_params=QuantOps.WeightParams(
          prec=8.0, axis=axis, expected_scale_shape=expected_scale_shape))
def test_per_feature_dim_scale_invariance_weight_quantization(self, prec):
  # Scaling each channel of the weights by a different power of 2 should scale
  # the corresponding channel of the output by the same factor.
  weights = random.uniform(random.PRNGKey(0), (3, 4))
  weight_scale = 2**jnp.arange(4)[jnp.newaxis, :]
  scaled_weights = weights * weight_scale
  weights = quantization.QuantOps.create_weights_fake_quant(
      w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=0))
  scaled_weights = quantization.QuantOps.create_weights_fake_quant(
      w=scaled_weights, weight_params=QuantOps.WeightParams(prec=prec, axis=0))
  onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=jnp.array([[3.0, 1.5]]), prec=4)
  act = self.lhs.astype(jnp.bfloat16)
  w = self.rhs.astype(jnp.float32)
  with self.assertRaises(TypeError):
    quantization.quantized_dot(
        w=w,
        act=act,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=quant_type,
        prefer_int8_to_int32_dot=True)
def test_scale_invariance_weight_quantization(self, prec):
  # Scaling the weights by a power of 2 should scale the output by the same
  # factor.
  weights = random.uniform(random.PRNGKey(0), (10, 1))
  weight_scale = 16
  scaled_weights = weights * weight_scale
  weights = QuantOps.create_weights_fake_quant(
      w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=None))
  scaled_weights = QuantOps.create_weights_fake_quant(
      w=scaled_weights,
      weight_params=QuantOps.WeightParams(prec=prec, axis=None))
  onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
def test_quantized_dot_no_quant(self):
  act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=-1.0, prec=4)
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act = jnp.array([[-5.0]])
  w = jnp.array([[-4.99]])
  res = quantization.quantized_dot(
      w=w,
      act=act,
      quant_type=quantization.QuantType.aqt,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=None,
      prefer_int8_to_int32_dot=True)
  onp.testing.assert_allclose(res, act * w)
def test_int_positive_act_quantization(self, prec):
  # Integer activations within upper_bound, where upper_bound == 2^i for some
  # i < prec, quantize correctly.
  upper_bound = 2**(prec - 3)
  activations = random.randint(random.PRNGKey(0), (10, 1), 0, upper_bound)
  rescaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.positive,
          bounds=upper_bound,
          prec=prec))
  onp.testing.assert_array_equal(activations, rescaled_activations)
def test_attributes_create_weights_ops(self, weight_range, weight_shape, prec):
  weights = jnp.array(
      fp32(
          onp.random.uniform(
              weight_range[0], weight_range[1], size=weight_shape)))
  axis = 0 if weight_shape[1] != 1 else None
  weights_quant = QuantOps.create_weights_ops(
      w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=axis))
  max_weight = onp.max(abs(weights), axis=0)
  onp.testing.assert_array_equal(
      jnp.squeeze(weights_quant._scale), (2**(prec - 1) - 1) / max_weight)
  self.assertEqual(weights_quant._symmetric, True)
  self.assertEqual(weights_quant._prec, prec)
def __call__(self, inputs):
  return QuantOps.create_input_ops(
      inputs,
      hparams=hparams,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False))
def test_int_symmetric_act_quantization(self, prec):
  # Integer activations within bounds, where abs(bounds) == 2^(prec - 1) - 1,
  # quantize correctly.
  bounds = 2**(prec - 1) - 1
  activations = random.randint(random.PRNGKey(0), (10, 1), -bounds, bounds)
  rescaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=prec))
  onp.testing.assert_array_equal(activations, rescaled_activations)
def test_attributes_create_symmetric(self, bounds, prec):
  bounds = jnp.array(bounds)
  act_signed = QuantOps.create_symmetric(bounds=bounds, prec=prec)
  onp.testing.assert_array_equal(act_signed._scale,
                                 (2**(prec - 1) - 1) / bounds)
  self.assertEqual(act_signed._symmetric, True)
  self.assertEqual(act_signed._prec, prec)
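# Worked numbers for the symmetric scale asserted above (illustrative only):
# with prec=8 and bounds=6.0, the largest representable integer magnitude is
# 2**(8 - 1) - 1 = 127, so scale = 127 / 6.0 ~= 21.17 and a value at the bound
# maps to the integer 127.
_prec, _bounds = 8, 6.0
_scale = (2**(_prec - 1) - 1) / _bounds  # ~= 21.1667
assert round(_bounds * _scale) == 2**(_prec - 1) - 1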
def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                         quant_type):
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act_params = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=jnp.array([[3.0, 1.5]]),
      prec=act_prec)
  act = self.lhs.astype(input_dtype)
  w = self.rhs.astype(input_dtype)
  output = quantization.quantized_dot(
      w=w,
      act=act,
      weight_params=weight_params,
      act_hparams=act_params,
      get_bounds_params=None,
      quant_type=quant_type,
      prefer_int8_to_int32_dot=True)
  self.assertEqual(output.dtype, input_dtype)
def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
    self,
    mock_act_fq,
    lhs_act_prec,
    rhs_act_prec,
    strategy=QuantType.fake_quant):
  mock_act_fq.side_effect = lambda inputs, hparams, get_bounds_params: inputs

  # pylint: disable=g-long-ternary
  lhs_act_hparams = QuantOps.ActHParams(
      bounds=6.,
      prec=lhs_act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if lhs_act_prec else None
  rhs_act_hparams = QuantOps.ActHParams(
      bounds=6.,
      prec=rhs_act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if rhs_act_prec else None
  # pylint: enable=g-long-ternary

  get_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)

  quantization.quantized_dynamic_dot_general(
      lhs_act=self.lhs_act,
      rhs_act=self.rhs_act,
      quant_type=strategy,
      dot_dimension_numbers=self.dimension_numbers,
      lhs_act_hparams=lhs_act_hparams,
      lhs_get_bounds_params=get_bounds_params,
      rhs_act_hparams=rhs_act_hparams,
      rhs_get_bounds_params=get_bounds_params,
  )

  calls = []
  for prec in [lhs_act_prec, rhs_act_prec]:
    if prec is not None:
      act_hparams = QuantOps.ActHParams(
          bounds=6., prec=prec, input_distribution=mock.ANY, half_shift=False)
      calls.append(
          mock.call(
              mock.ANY,
              hparams=act_hparams,
              get_bounds_params=get_bounds_params))
  self.assertLen(calls, mock_act_fq.call_count)
  mock_act_fq.assert_has_calls(calls, any_order=True)
def test_full_range_int_weight_quantization(self, prec):
  # Integer weights in the full signed range [-max_signed_int, max_signed_int]
  # quantize correctly.
  minval = -2**(prec - 1) + 1
  maxval = 2**(prec - 1) - 1
  weights = random.randint(random.PRNGKey(0), (10, 1), minval, maxval + 1)
  weights = weights.at[0, :].set(maxval)
  weight_quant = QuantOps.create_weights_ops(
      w=weights,
      weight_params=QuantOps.WeightParams(
          prec=prec, axis=None, half_shift=False))
  quantized_weights = weight_quant.to_quantized(weights, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(quantized_weights[0],
                                 (2**(prec - 1.0) - 1.0))
  rescaled_weights = weight_quant.from_quantized(
      quantized_weights, dtype=jnp.float32)
  onp.testing.assert_array_equal(weights, rescaled_weights)
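# Illustration of the "full range" used above (illustrative numbers only): for
# prec=4 the symmetric signed range is [-(2**3 - 1), 2**3 - 1] = [-7, 7], so
# integer weights drawn from that range are exactly representable and survive
# quantize -> dequantize unchanged.
_prec = 4
assert (-2**(_prec - 1) + 1, 2**(_prec - 1) - 1) == (-7, 7)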
def test_dynamic_quantized_dot_general_raises_with_mixed_dtype(self):
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=2.0, prec=4, half_shift=False)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=1.5, prec=4, half_shift=False)
  lhs_act = self.lhs.astype(jnp.bfloat16)
  rhs_act = self.rhs.astype(jnp.float32)
  with self.assertRaises(TypeError):
    quantization.quantized_dynamic_dot_general(
        lhs_act=lhs_act,
        rhs_act=rhs_act,
        lhs_act_hparams=lhs_params,
        rhs_act_hparams=rhs_params,
        lhs_get_bounds_params=None,
        rhs_get_bounds_params=None,
        dot_dimension_numbers=(((1,), (0,)), ((), ())),
        quant_type=QuantType.aqt)
def test_inputs_scale_shape_is_expected(self):
  # Inputs quantization
  inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
  bounds = 6.0
  expected_inputs_scale_shape = ()
  _ = QuantOps.create_inputs_fake_quant(
      inputs=inputs,
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=8.0),
      get_bounds_params=GetBounds.Params(
          update_stats=False,
          update_bounds=False,
          expected_bounds_shape=expected_inputs_scale_shape))
def test_lax_dot_has_integer_inputs_in_quantized_dot(
    self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot, prec):
  weight_params = QuantOps.WeightParams(
      prec=prec, axis=(0,), half_shift=False)
  act_params = QuantOps.ActHParams(
      input_distribution=act_distribution,
      bounds=jnp.array([[3.0, 1.5]]),
      prec=prec,
      half_shift=False)
  act = self.lhs
  if act_distribution == 'positive':
    act = jnp.abs(act)
  # We need this context manager to stop Jax from trying to compile the arms
  # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
  # try to compile the functions passed to `lax.cond`, even if outside of a
  # JITed context. JIT compilation is incompatible with using a mock for the
  # call to 'dot_general' because during compilation Jax will expect
  # 'dot_general' to return a tracer and will throw an error if it returns a
  # mock instead. By explicitly using jax.disable_jit, Jax will not try to
  # compile the arms of lax.cond and so using a mock will work fine.
  with jax.disable_jit():
    quantization.quantized_dot(
        w=self.rhs,
        act=act,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=QuantType.aqt,
        prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
  act_inputs, weight_inputs = mock_dot_general.call_args[0]
  self.assert_is_integer_in_range(
      act_inputs, prec=prec, distribution=act_distribution)
  self.assert_is_integer_in_range(
      weight_inputs, prec=prec, distribution='symmetric')
  if prefer_int8_to_int32_dot and not (act_distribution == 'positive' and
                                       prec == 8):
    expected_input_dtype = jnp.int8
  else:
    expected_input_dtype = jnp.float32
  self.assertEqual(act_inputs.dtype, expected_input_dtype)
  self.assertEqual(weight_inputs.dtype, expected_input_dtype)
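# A standalone sketch of the mocking pattern described in the comment above:
# wrap `lax.dot_general` in a mock inside `jax.disable_jit()` so that the
# mocked call is never traced or compiled. The patch target and the toy shapes
# here are our own illustration, not the decorator this test class actually
# uses.
from unittest import mock

import jax
import jax.numpy as jnp
from jax import lax


def sketch_mocked_dot_general():
  with jax.disable_jit(), mock.patch.object(
      lax, 'dot_general', wraps=lax.dot_general) as mock_dot_general:
    lhs = jnp.ones((2, 3))
    rhs = jnp.ones((3, 4))
    _ = lax.dot_general(lhs, rhs, (((1,), (0,)), ((), ())))
  # Inspect the operands the (wrapped) primitive was actually called with.
  called_lhs, called_rhs, _ = mock_dot_general.call_args[0]
  return called_lhs.shape, called_rhs.shape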
def test_positive_activation_quantization_clips_outside_bounds(self, prec):
  # Activation values less than 0 get clipped to 0, and values greater than
  # upper_bound get clipped to upper_bound.
  relu6 = QuantOps.create_positive(bounds=6.0, prec=prec)
  activation = jnp.array(fp32([-0.5, 6.2, 3.141]))
  quantized_activations = relu6.to_quantized(activation, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(quantized_activations[0:2],
                                 [0.0, 2**prec - 1])
  activations = relu6.from_quantized(quantized_activations, dtype=jnp.float32)
  max_clipped_val = (2**prec - 1) * (6.0 / 2**prec)
  onp.testing.assert_array_equal(activations[0:2], [0.0, max_clipped_val])
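# Worked numbers for the clipping check above (illustrative only, re-using the
# test's own formula with prec=8 and bound 6.0): values below 0 quantize to 0,
# values above the bound saturate at 2**8 - 1 = 255, and dequantizing the
# saturated value gives 255 * (6.0 / 256) = 5.9765625, the `max_clipped_val`
# asserted above.
_prec = 8
_max_clipped_val = (2**_prec - 1) * (6.0 / 2**_prec)
assert _max_clipped_val == 5.9765625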
def test_dynamic_quantized_dot_general_has_correct_dtype(
    self, input_dtype, act_prec, quant_type):
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=2.0, prec=act_prec)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=1.5, prec=act_prec)
  lhs_act = self.lhs.astype(input_dtype)
  rhs_act = self.rhs.astype(input_dtype)
  output = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=quant_type)
  self.assertEqual(output.dtype, input_dtype)
def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
    self,
    mock_act_fq,
    mock_w_fq,
    weight_prec,
    act_prec,
    strategy=QuantType.fake_quant):
  mock_w_fq.side_effect = lambda inputs, **_: inputs
  mock_act_fq.side_effect = lambda inputs, **_: inputs

  weight_params = QuantOps.WeightParams(
      prec=weight_prec, axis=None, half_shift=False)
  act_hparams = QuantOps.ActHParams(  # pylint: disable=g-long-ternary
      bounds=6.,
      prec=act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if act_prec else None
  get_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)

  quantization.quantized_dot(
      w=self.weight,
      act=self.act,
      quant_type=strategy,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=get_bounds_params,
      prefer_int8_to_int32_dot=True)

  quantized_type = strategy.to_jax_type()
  mock_w_fq.assert_called_with(
      mock.ANY,
      weight_params=weight_params,
      quantized_type=quantized_type,
      fake_dependency=mock.ANY)
  if act_hparams:
    mock_act_fq.assert_called_with(
        mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
  else:
    mock_act_fq.assert_not_called()
def test_per_feature_dim_scale_invariance_pos_activation_quantization(
    self, prec):
  # Scaling each channel of the activations by a different power of 2, with
  # the upper bound scaled by the same factor, should scale the corresponding
  # channel of the output by that same factor.
  activations = random.uniform(random.PRNGKey(0), (3, 4))
  act_scale = 2**jnp.arange(4)
  scaled_activations = activations * act_scale[jnp.newaxis, :]
  upper_bound = 6.0 * jnp.ones((3, 4), jnp.float32)
  act_quant_ops = QuantOps.create_positive(bounds=upper_bound, prec=prec)
  activations = act_quant_ops.fake_quant(
      activations, quantized_type=SCALE_DTYPE)
  scaled_act_quant_ops = QuantOps.create_positive(
      bounds=upper_bound * act_scale[jnp.newaxis, :], prec=prec)
  scaled_activations = scaled_act_quant_ops.fake_quant(
      scaled_activations, quantized_type=SCALE_DTYPE)
  onp.testing.assert_array_equal(activations * act_scale[jnp.newaxis, :],
                                 scaled_activations)
def test_quantized_dot_general_aqt(self, act_bounds, weight_prec,
                                   weight_axis):
  # With a high enough precision, we expect results from fakequant and AQT to
  # be very similar.
  weight_params = QuantOps.WeightParams(
      prec=weight_prec, axis=weight_axis, half_shift=False)
  if act_bounds is None:
    act_params = None
  else:
    act_params = QuantOps.ActHParams(
        input_distribution='symmetric',
        bounds=jnp.array(act_bounds),
        prec=16,
        half_shift=False)
  lhs_ndims_3 = jnp.array(
      fp32(2.0 * onp.random.uniform(0, 1.0, size=(4, 3, 2))))

  def quantized_matmul(quant_type):
    return quantization.quantized_dot_general(
        w=self.rhs,
        act=lhs_ndims_3,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=quant_type,
        dimension_numbers=(((lhs_ndims_3.ndim - 1,), (0,)), ((), ())),
        prefer_int8_to_int32_dot=True)

  aqt_result = quantized_matmul(QuantType.aqt)
  self.assertEqual(aqt_result.shape, (4, 3, 4))
  fakequant_result = quantized_matmul(QuantType.fake_quant)
  onp.testing.assert_allclose(
      aqt_result,
      fakequant_result,
      rtol=1e-2,
      err_msg='AQT and fakequant significantly disagree')