def quantized_softmax(a):
  # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
  # intermediate values. Note this differs from the log-domain
  # implementation of softmax used above.
  quant_hparams = softmax_hparams.quant_hparams
  fp_quant_config = QuantOps.FloatQuant(
      is_scaled=False, fp_spec=quant_hparams.prec)
  quant_ops = QuantOps.create_symmetric_fp(
      fp_quant=fp_quant_config, bounds=None)

  a = quant_ops.to_quantized(a, dtype=dtype)
  # Note that the max of a quantized vector is necessarily also quantized to
  # the same precision since the max of a vector must be an existing element
  # of the vector, so we don't need to explicitly insert a quantization
  # operator to the output of the max reduction.
  a_max = jnp.max(a, axis=norm_dims, keepdims=True)
  a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
  a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

  sum_exp_quantized_reduction = quantization.quantized_sum(
      a_exp, axis=norm_dims, keepdims=True, prec=quant_hparams.reduction_prec)
  sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

  inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
  a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

  return a_softmax.astype(dtype)
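# For comparison, a minimal sketch of the unquantized, numerically stable
# softmax that quantized_softmax mirrors. The name reference_softmax and the
# norm_dims default below are illustrative assumptions, not part of the
# library.
import jax.numpy as jnp


def reference_softmax(a, norm_dims=-1):
  # Subtract the max before exponentiating for numerical stability.
  a_max = jnp.max(a, axis=norm_dims, keepdims=True)
  a_exp = jnp.exp(a - a_max)
  return a_exp / jnp.sum(a_exp, axis=norm_dims, keepdims=True)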
def quantized_layernorm(x):
  prec = hparams.quant_hparams.prec
  fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
  quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

  def to_quantized(x):
    return quant_ops.to_quantized(x, dtype=dtype)

  # If epsilon is too small to represent in the quantized format, we set it
  # to the minimal representable non-zero value to avoid the possibility of
  # dividing by zero.
  fp_bounds = quantization.fp_cast.get_bounds(prec.exp_min, prec.exp_max,
                                              prec.sig_bits)
  epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
  quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

  # If the reciprocal of the quantized number of features is too small to
  # represent in the quantized format, we set it to the minimal
  # representable non-zero value so that the mean and variance are not
  # trivially 0.
  num_features_quantized = to_quantized(jnp.array(num_features, dtype=dtype))
  num_features_recip_quantized = to_quantized(
      jnp.reciprocal(num_features_quantized))
  num_features_recip_quantized = jax.lax.cond(
      jax.lax.eq(num_features_recip_quantized, 0.0),
      lambda _: quantized_epsilon,
      lambda _: num_features_recip_quantized, None)

  x_quantized = to_quantized(x)
  x_sum_quantized_reduction = quantization.quantized_sum(
      x_quantized,
      axis=-1,
      keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sum = to_quantized(x_sum_quantized_reduction)
  mean = to_quantized(x_sum * num_features_recip_quantized)
  x_minus_mean = to_quantized(x - mean)
  x_sq = to_quantized(lax.square(x_minus_mean))
  x_sq_sum_quantized_reduction = quantization.quantized_sum(
      x_sq,
      axis=-1,
      keepdims=True,
      prec=hparams.quant_hparams.reduction_prec)
  x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
  var = to_quantized(x_sq_sum * num_features_recip_quantized)
  # Prevent division by zero.
  var_plus_epsilon = to_quantized(var + quantized_epsilon)
  mul = to_quantized(lax.rsqrt(var_plus_epsilon))
  if self.use_scale:
    quantized_scale_param = to_quantized(scale_param)
    mul = to_quantized(mul * quantized_scale_param)
  y = to_quantized(x_minus_mean * mul)
  if self.use_bias:
    quantized_bias_param = to_quantized(bias_param)
    y = to_quantized(y + quantized_bias_param)
  return y.astype(self.dtype)
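# For reference, a minimal sketch of the unquantized layer norm that
# quantized_layernorm parallels. reference_layernorm is a hypothetical name,
# and scale/bias stand in for the layer's learned parameters; this is not the
# library's implementation.
import jax.numpy as jnp


def reference_layernorm(x, scale, bias, epsilon=1e-6):
  # Normalize over the feature (last) axis, then apply scale and bias.
  mean = jnp.mean(x, axis=-1, keepdims=True)
  var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
  return (x - mean) / jnp.sqrt(var + epsilon) * scale + bias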
class WeightQuantizationTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(testcase_name='per_layer_quant', axis=None),
      dict(testcase_name='per_channel_quant', axis=(0,)))
  def test_weight_scale_shape_is_expected(self, axis):
    # Tests that the scale has the expected shape for weight quantization.
    num_features = 4
    expected_scale_shape = (1, 1) if axis is None else (1, num_features)
    weights = jnp.array(
        fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, num_features))))
    _ = QuantOps.create_weights_fake_quant(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=8.0,
            axis=axis,
            expected_scale_shape=expected_scale_shape))
  @parameterized.named_parameters(
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_float_weights_quantization(self, prec):
    # Tests that quantized and rescaled float weights are close to the
    # original weights.
    weights = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
    rescaled_weights = QuantOps.create_weights_fake_quant(
        w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=None))
    test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_full_range_int_weight_quantization(self, prec):
    # Integer weights spanning the full signed range
    # [-(2**(prec - 1) - 1), 2**(prec - 1) - 1] quantize exactly.
    minval = -2**(prec - 1) + 1
    maxval = 2**(prec - 1) - 1
    weights = random.randint(random.PRNGKey(0), (10, 1), minval, maxval + 1)
    # Pin the first row to maxval so the maximum of the range is attained.
    weights = weights.at[0, :].set(maxval)
    weight_quant = QuantOps.create_weights_ops(
        w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=None))
    quantized_weights = weight_quant.to_quantized(weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(quantized_weights[0],
                                   (2**(prec - 1.0) - 1.0))
    rescaled_weights = weight_quant.from_quantized(
        quantized_weights, dtype=jnp.float32)
    onp.testing.assert_array_equal(weights, rescaled_weights)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_scale_invariance_weight_quantization(self, prec):
    # Scaling weights by a power of 2 should scale the output by the same
    # factor.
    weights = random.uniform(random.PRNGKey(0), (10, 1))
    weight_scale = 16
    scaled_weights = weights * weight_scale

    weights = QuantOps.create_weights_fake_quant(
        w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=None))

    scaled_weights = QuantOps.create_weights_fake_quant(
        w=scaled_weights,
        weight_params=QuantOps.WeightParams(prec=prec, axis=None))

    onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_per_feature_dim_scale_invariance_weight_quantization(self, prec):
    # Scaling each channel of the weights by a different power of 2 should
    # scale the respective channel of the output by the same factor.
    weights = random.uniform(random.PRNGKey(0), (3, 4))
    weight_scale = 2**jnp.arange(4)[jnp.newaxis, :]
    scaled_weights = weights * weight_scale

    weights = quantization.QuantOps.create_weights_fake_quant(
        w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=0))

    scaled_weights = quantization.QuantOps.create_weights_fake_quant(
        w=scaled_weights,
        weight_params=QuantOps.WeightParams(prec=prec, axis=0))

    onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
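# A standalone sketch of the symmetric fake-quant round trip the weight tests
# above exercise: scale = (2^(prec-1) - 1) / max|w|, quantize by rounding,
# then divide the scale back out. symmetric_fake_quant_sketch is a
# hypothetical helper, not the library implementation
# (QuantOps.create_weights_fake_quant).
import numpy as onp


def symmetric_fake_quant_sketch(w, prec, axis=None):
  # Per-tensor scale when axis is None, per-channel scale otherwise.
  max_weight = onp.max(onp.abs(w), axis=axis, keepdims=axis is not None)
  scale = (2**(prec - 1) - 1) / max_weight
  return onp.round(w * scale) / scale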
class QuantOpsTest(parameterized.TestCase):

  def setUp(self):
    super(QuantOpsTest, self).setUp()
    quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = True

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', bounds=6.0, prec=2),
      dict(testcase_name='prec_4', bounds=6.0, prec=4),
      dict(testcase_name='prec_8', bounds=6.0, prec=8),
      dict(testcase_name='2_features_prec_8', bounds=[6., 12.], prec=8),
  )
  def test_attributes_create_positive(self, bounds, prec):
    bounds = jnp.array(bounds)
    relu6 = QuantOps.create_positive(bounds=bounds, prec=prec)
    onp.testing.assert_array_equal(relu6._scale, 2**prec / bounds)
    self.assertEqual(relu6._symmetric, False)
    self.assertEqual(relu6._prec, prec)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', bounds=6.0, prec=2),
      dict(testcase_name='prec_4', bounds=6.0, prec=4),
      dict(testcase_name='prec_8', bounds=6.0, prec=8),
      dict(testcase_name='2_features_prec_8', bounds=[6., 12.], prec=8),
  )
  def test_attributes_create_symmetric(self, bounds, prec):
    bounds = jnp.array(bounds)
    act_signed = QuantOps.create_symmetric(
        bounds=bounds, prec=prec, half_shift=False)
    onp.testing.assert_array_equal(act_signed._scale,
                                   (2**(prec - 1) - 1) / bounds)
    self.assertEqual(act_signed._symmetric, True)
    self.assertEqual(act_signed._prec, prec)

  @parameterized.named_parameters(
      dict(
          testcase_name='fp8_143',
          weight_range=[2.0, 64.0],
          weight_shape=(10, 1),
          fp_quant=QuantOps.FloatQuant(
              is_scaled=True,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-11,
                  exp_max=4,
                  sig_bits=3,
              ),
          ),
      ),
      dict(
          testcase_name='fp8_152',
          weight_range=[2.0, 64.0],
          weight_shape=(10, 1),
          fp_quant=QuantOps.FloatQuant(
              is_scaled=True,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-23,
                  exp_max=8,
                  sig_bits=2,
              ),
          ),
      ),
  )
  def test_attributes_create_weights_op_fp(
      self,
      weight_range,
      weight_shape,
      fp_quant,
  ):
    weights = jnp.array(
        fp32(onp.random.uniform(*weight_range, size=weight_shape)))
    axis = None if weight_shape[1] == 1 else 0
    weights_quant_op = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=fp_quant, axis=axis, half_shift=False))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant_op._scale),
        jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
    self.assertEqual(weights_quant_op._symmetric, True)
    self.assertIs(weights_quant_op._prec, fp_quant)
    weights_scaled = (weights * weights_quant_op._scale).astype(weights.dtype)
    weights_quant_expected = fp_cast.downcast_sat_ftz(
        weights_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    weights_quant_calculated = weights_quant_op.to_quantized(
        weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(weights_quant_expected,
                                   weights_quant_calculated)
    # Test that the lower (23 - fp_quant.fp_spec.sig_bits) bits of the
    # calculated quantized weights are zero.
    sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
    onp.testing.assert_array_equal(
        weights_quant_calculated.view(jnp.int32) & sig_mask,
        jnp.zeros_like(weights))

  @parameterized.named_parameters(
      dict(
          testcase_name='fp_act_symmetric',
          act_distribution='symmetric',
          use_hparams_bounds=False,
      ),
      # TODO(b/193561347): FP quantization with positive input distribution
      # is not supported yet.
      dict(
          testcase_name='fp_act_positive',
          act_distribution='positive',
          use_hparams_bounds=False,
      ),
      dict(
          testcase_name='fp_act_symmetric_hyper_bounds',
          act_distribution='symmetric',
          use_hparams_bounds=True,
      ),
      dict(
          testcase_name='fp_act_positive_hyper_bounds',
          act_distribution='positive',
          use_hparams_bounds=True,
      ),
  )
  def test_attributes_create_acts_op_fp(
      self,
      act_distribution,
      use_hparams_bounds,
  ):
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    fp_quant = QuantOps.FloatQuant(
        is_scaled=True,
        fp_spec=QuantOps.FloatQuant.FloatPrec(
            exp_min=-15,
            exp_max=15,
            sig_bits=2,
        ),
    )
    if use_hparams_bounds:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=1,
          absdev_coeff=0,
          mix_coeff=1,
          reset_stats=True,
          ema_coeff=None,
          use_cams=False,
          granularity=quant_config.QuantGranularity.per_tensor)
    else:
      bounds = 6.0

    hparams = QuantOps.ActHParams(
        input_distribution=act_distribution,
        bounds=bounds,
        prec=fp_quant,
        half_shift=False)

    class TestModule(nn.Module):

      hparams: QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        return QuantOps.create_input_ops(
            inputs,
            hparams=hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=False, update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    act_quant_op = test_module.apply(state, inputs=inputs)

    act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
    act_quant_expected = fp_cast.downcast_sat_ftz(
        act_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)

  @parameterized.named_parameters(
      dict(
          testcase_name='pos_weight_prec_2',
          weight_range=[2.0, 10.0],
          weight_shape=(10, 1),
          prec=2),
      dict(
          testcase_name='pos_weight_prec_4',
          weight_range=[2.0, 10.0],
          weight_shape=(10, 1),
          prec=4),
      dict(
          testcase_name='pos_weight_prec_8',
          weight_range=[2.0, 10.0],
          weight_shape=(10, 1),
          prec=8),
      dict(
          testcase_name='neg_weight_prec_8',
          weight_range=[-12.0, 2.0],
          weight_shape=(10, 1),
          prec=8),
      dict(
          testcase_name='neg_weight_2_features_prec_8',
          weight_range=[-12.0, 2.0],
          weight_shape=(10, 2),
          prec=8),
  )
  def test_attributes_create_weights_ops(self, weight_range, weight_shape,
                                         prec):
    weights = jnp.array(
        fp32(
            onp.random.uniform(
                weight_range[0], weight_range[1], size=weight_shape)))
    axis = 0 if weight_shape[1] != 1 else None
    weights_quant = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=prec, axis=axis, half_shift=False))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant._scale), (2**(prec - 1) - 1) / max_weight)
    self.assertEqual(weights_quant._symmetric, True)
    self.assertEqual(weights_quant._prec, prec)
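# A worked example of the power-of-two scale asserted above for
# floating-point weights: scale = 2^(-floor(log2(max|w|))) normalizes the
# largest weight into [1, 2) before downcast_sat_ftz. Plain numpy arithmetic,
# for illustration only.
import numpy as onp

max_weight = 48.0
scale = onp.exp2(-onp.floor(onp.log2(max_weight)))  # floor(log2(48)) = 5
assert scale == 2.0**-5  # = 1/32
assert 1.0 <= max_weight * scale < 2.0  # 48/32 = 1.5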
class ActQuantizationTest(parameterized.TestCase):

  def test_inputs_scale_shape_is_expected(self):
    # Inputs quantization
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    bounds = 6.0
    expected_inputs_scale_shape = ()

    _ = QuantOps.create_inputs_fake_quant(
        inputs=inputs,
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds,
            prec=8.0,
            half_shift=False),
        get_bounds_params=GetBounds.Params(
            update_stats=False,
            update_bounds=False,
            expected_bounds_shape=expected_inputs_scale_shape))

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_positive_activation_quantization_clips_outside_bounds(self, prec):
    # Activation values less than 0 get clipped to 0, and values greater
    # than upper_bound get clipped to upper_bound.
    relu6 = QuantOps.create_positive(bounds=6.0, prec=prec)
    activation = jnp.array(fp32([-0.5, 6.2, 3.141]))
    quantized_activations = relu6.to_quantized(activation, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(quantized_activations[0:2],
                                   [0.0, 2**prec - 1])
    activations = relu6.from_quantized(
        quantized_activations, dtype=jnp.float32)
    max_clipped_val = (2**prec - 1) * (6.0 / 2**prec)
    onp.testing.assert_array_equal(activations[0:2], [0.0, max_clipped_val])

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_per_feature_dim_unsigned_activation_quantization_clips_outside_bounds(
      self, prec):
    # Activation values less than -upper_bound get clipped to -upper_bound,
    # and values greater than upper_bound get clipped to upper_bound.
    act_quant = QuantOps.create_symmetric(
        bounds=jnp.array([[6.0, 8.0]]), prec=prec, half_shift=False)
    activation = jnp.array(fp32([[-7, -8.9], [6.2, 9.4], [0, 0.]]))
    quantized_activations = act_quant.to_quantized(
        activation, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(
        quantized_activations,
        jnp.array([[-2**(prec - 1.0) + 1.0], [2**(prec - 1.0) - 1.0], [0.0]])
        * jnp.array([[1., 1.]]))
    activations = act_quant.from_quantized(
        quantized_activations, dtype=jnp.float32)
    onp.testing.assert_array_equal(activations,
                                   [[-6.0, -8.0], [6.0, 8.], [0, 0.]])

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_scale_invariance_signed_activation_quantization(self, prec):
    # Scaling the activations by a power of 2 and the bounds by the same
    # factor should scale the output by that factor as well.
    activations = random.uniform(random.PRNGKey(0), (10, 1))
    act_scale = 8.
    scaled_activations = activations * act_scale

    bounds = 6.

    activations = QuantOps.create_inputs_fake_quant(
        inputs=activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds,
            prec=prec,
            half_shift=False))

    scaled_activations = QuantOps.create_inputs_fake_quant(
        inputs=scaled_activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds * act_scale,
            prec=prec,
            half_shift=False))
    onp.testing.assert_array_equal(activations * act_scale,
                                   scaled_activations)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_per_feature_dim_scale_invariance_pos_activation_quantization(
      self, prec):
    # Scaling each channel of the activations by a different power of 2,
    # with the upper bound scaled likewise, should scale the respective
    # channel of the output by the same factor.
    activations = random.uniform(random.PRNGKey(0), (3, 4))
    act_scale = 2**jnp.arange(4)
    scaled_activations = activations * act_scale[jnp.newaxis, :]

    upper_bound = 6.0 * jnp.ones((3, 4), jnp.float32)

    act_quant_ops = QuantOps.create_positive(bounds=upper_bound, prec=prec)
    activations = act_quant_ops.fake_quant(
        activations, quantized_type=SCALE_DTYPE)

    scaled_act_quant_ops = QuantOps.create_positive(
        bounds=upper_bound * act_scale[jnp.newaxis, :], prec=prec)
    scaled_activations = scaled_act_quant_ops.fake_quant(
        scaled_activations, quantized_type=SCALE_DTYPE)
    onp.testing.assert_array_equal(activations * act_scale[jnp.newaxis, :],
                                   scaled_activations)

  @parameterized.named_parameters(
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_int_positive_act_quantization(self, prec):
    # Integer activations within upper_bound, where upper_bound == 2^i for
    # some i < prec, quantize exactly.
    upper_bound = 2**(prec - 3)
    activations = random.randint(random.PRNGKey(0), (10, 1), 0, upper_bound)

    rescaled_activations = QuantOps.create_inputs_fake_quant(
        inputs=activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.positive,
            bounds=upper_bound,
            prec=prec,
            half_shift=False))
    onp.testing.assert_array_equal(activations, rescaled_activations)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_int_symmetric_act_quantization(self, prec):
    # Integer activations within bounds, where abs(bounds) == 2^(prec-1) - 1,
    # quantize exactly.
    bounds = 2**(prec - 1) - 1
    activations = random.randint(random.PRNGKey(0), (10, 1), -bounds, bounds)
    rescaled_activations = QuantOps.create_inputs_fake_quant(
        inputs=activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds,
            prec=prec,
            half_shift=False))
    onp.testing.assert_array_equal(activations, rescaled_activations)

  @parameterized.named_parameters(
      dict(
          testcase_name='fp_prec_scaled',
          prec=QuantOps.FloatQuant(
              is_scaled=True,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-11,
                  exp_max=4,
                  sig_bits=3,
              ),
          ),
      ),
      dict(
          testcase_name='fp_prec_unscaled',
          prec=QuantOps.FloatQuant(
              is_scaled=False,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-11,
                  exp_max=4,
                  sig_bits=3,
              ),
          ),
      ),
      dict(
          testcase_name='int_prec',
          prec=4.0,
      ),
  )
  def test_no_quantization(self, prec):
    # If initial_bound == -1 when using GetBounds, then
    # create_inputs_fake_quant should be a no-op.
    inputs = jnp.array([[.3, 1.4], [-5.2, 4.0]])
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=-1,
        stddev_coeff=1,
        absdev_coeff=0,
        mix_coeff=1,
        reset_stats=True,
        ema_coeff=None,
        use_cams=False,
        granularity=quant_config.QuantGranularity.per_tensor)
    hparams = quantization.QuantOps.ActHParams(
        input_distribution='symmetric',
        bounds=bounds,
        prec=prec,
        half_shift=False)

    # The call to create_inputs_fake_quant has to occur from within a Flax
    # module since it calls GetBounds, which is itself a Flax module.
    # Thus we create a wrapper module for testing.
    class TestModule(nn.Module):

      hparams: quantization.QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        return quantization.QuantOps.create_inputs_fake_quant(
            inputs,
            hparams=hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=True, update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    inputs_after_fake_quant, _ = test_module.apply(
        state, inputs=inputs, mutable=True)
    onp.testing.assert_array_equal(inputs, inputs_after_fake_quant)
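# The bound-to-scale relationships the activation tests above verify, as
# plain arithmetic. positive_act_scale and symmetric_act_scale are
# hypothetical helper names; the real logic lives in QuantOps.create_positive
# and QuantOps.create_symmetric.
def positive_act_scale(bounds, prec):
  # Unsigned: the quantized range [0, 2^prec - 1] must cover [0, bounds].
  return 2**prec / bounds


def symmetric_act_scale(bounds, prec):
  # Signed: the quantized range [-(2^(prec-1) - 1), 2^(prec-1) - 1] must
  # cover [-bounds, bounds].
  return (2**(prec - 1) - 1) / bounds


assert positive_act_scale(6.0, 8) == 256 / 6.0
assert symmetric_act_scale(6.0, 8) == 127 / 6.0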
from aqt.jax import quantization
from aqt.jax import shape_utils
from aqt.jax import test_utils
from aqt.jax.quantization import QuantOps
from aqt.jax.quantization import QuantType

FLAGS = flags.FLAGS

# fp-1-4-3
#   1: sign bit
#   4: number of exponent bits (bias = 11), giving exponent range -11, ..., 4
#   3: number of significand bits (excluding the hidden bit)
fp143_scaled = QuantOps.FloatQuant(
    is_scaled=True,
    fp_spec=QuantOps.FloatQuant.FloatPrec(
        exp_min=-11,
        exp_max=4,
        sig_bits=3,
    ),
)
fp143_unscaled = QuantOps.FloatQuant(
    is_scaled=False,
    fp_spec=QuantOps.FloatQuant.FloatPrec(
        exp_min=-11,
        exp_max=4,
        sig_bits=3,
    ),
)


class ConvAqtTest(parameterized.TestCase):
  """Tests for ConvAqt layer."""
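# A quick sanity check on the fp-1-4-3 range defined above (worked
# arithmetic, not library code): with sig_bits=3 the largest significand is
# 2 - 2^-3, so the largest finite magnitude is
# 2^exp_max * (2 - 2^-3) = 2^4 * 1.875 = 30, and the smallest positive
# normal value is 2^exp_min = 2^-11.
assert 2**4 * (2 - 2**-3) == 30.0
assert 2**-11 == 1 / 2048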