def test_scale_invariance_signed_activation_quantization(self, prec):
  # Scaling the activations and the bounds by the same power of 2 should
  # scale the output by that same factor.
  activations = random.uniform(random.PRNGKey(0), (10, 1))
  act_scale = 8.
  scaled_activations = activations * act_scale

  bounds = 6.

  activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=prec))

  scaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=scaled_activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds * act_scale,
          prec=prec))
  onp.testing.assert_array_equal(activations * act_scale, scaled_activations)
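# Illustrative note (not part of the original test suite): why the invariance
# above holds, assuming the usual symmetric fake-quant formula
# scale = (2**(prec - 1) - 1) / bounds. The quantization step is proportional
# to the bound, so scaling inputs and bounds by the same factor s leaves the
# rounded integers unchanged and scales the dequantized result by s:
#
#   round(clip(s * x, -s * b, s * b) * (scale / s)) / (scale / s)
#       == s * round(clip(x, -b, b) * scale) / scale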
def test_lax_dot_has_integer_inputs_in_dynamic_dot_general(
    self, mock_dot_general, lhs_distribution, rhs_distribution):
  lhs_params = QuantOps.ActHParams(
      input_distribution=lhs_distribution, bounds=2.0, prec=4)
  rhs_params = QuantOps.ActHParams(
      input_distribution=rhs_distribution, bounds=1.5, prec=4)
  lhs_act = self.lhs
  if lhs_distribution == 'positive':
    lhs_act = jnp.abs(lhs_act)
  rhs_act = self.rhs
  if rhs_distribution == 'positive':
    rhs_act = jnp.abs(rhs_act)
  quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=QuantType.aqt)
  lhs_inputs, rhs_inputs = mock_dot_general.call_args[0]
  self.assert_is_integer_in_range(
      lhs_inputs, prec=4, distribution=lhs_distribution)
  self.assert_is_integer_in_range(
      rhs_inputs, prec=4, distribution=rhs_distribution)
def test_quantized_dynamic_dot_general(self, lhs_prec, rhs_prec):
  lhs_bounds = 2.0
  rhs_bounds = 1.5
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=lhs_bounds, prec=lhs_prec)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=rhs_bounds, prec=rhs_prec)

  def quantized_matmul(quant_type):
    return quantization.quantized_dynamic_dot_general(
        lhs_act=self.lhs,
        rhs_act=self.rhs,
        lhs_act_hparams=lhs_params,
        rhs_act_hparams=rhs_params,
        lhs_get_bounds_params=None,
        rhs_get_bounds_params=None,
        dot_dimension_numbers=(((1,), (0,)), ((), ())),
        quant_type=quant_type)

  aqt_result = quantized_matmul(QuantType.aqt)
  fakequant_result = quantized_matmul(QuantType.fake_quant)
  onp.testing.assert_allclose(
      aqt_result,
      fakequant_result,
      rtol=1e-2,
      err_msg='AQT and fakequant significantly disagree')
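# Illustrative sketch (not part of the original test suite) of the two
# strategies compared above; the helper names are hypothetical. In exact
# arithmetic the strategies agree: `fake_quant` rounds and immediately
# rescales each operand in the float domain, while `aqt` multiplies the
# rounded integer operands and divides by the combined scale afterwards.
# The rtol above only absorbs floating-point rounding differences.
def _fakequant_dot(lhs, rhs, lhs_scale, rhs_scale):
  lhs_fq = onp.round(lhs * lhs_scale) / lhs_scale  # quantize-dequantize
  rhs_fq = onp.round(rhs * rhs_scale) / rhs_scale
  return lhs_fq @ rhs_fq


def _aqt_dot(lhs, rhs, lhs_scale, rhs_scale):
  lhs_int = onp.round(lhs * lhs_scale)  # integer-valued operands
  rhs_int = onp.round(rhs * rhs_scale)
  return (lhs_int @ rhs_int) / (lhs_scale * rhs_scale)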
def test_quantized_dot_aqt(self, act_bounds, weight_prec, weight_axis):
  # With a high enough precision, we expect results from fakequant and AQT
  # to be very similar.
  weight_params = QuantOps.WeightParams(prec=weight_prec, axis=weight_axis)
  if act_bounds is None:
    act_params = None
  else:
    act_params = QuantOps.ActHParams(
        input_distribution='symmetric', bounds=jnp.array(act_bounds), prec=16)

  def quantized_matmul(quant_type):
    return quantization.quantized_dot(
        w=self.rhs,
        act=self.lhs,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=quant_type,
        prefer_int8_to_int32_dot=True)

  aqt_result = quantized_matmul(QuantType.aqt)
  fakequant_result = quantized_matmul(QuantType.fake_quant)
  onp.testing.assert_allclose(
      aqt_result,
      fakequant_result,
      rtol=1e-2,
      err_msg='AQT and fakequant significantly disagree')
def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
    self,
    mock_act_fq,
    lhs_act_prec,
    rhs_act_prec,
    strategy=QuantType.fake_quant):
  mock_act_fq.side_effect = lambda inputs, hparams, get_bounds_params: inputs
  # pylint: disable=g-long-ternary
  lhs_act_hparams = QuantOps.ActHParams(
      bounds=6.,
      prec=lhs_act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if lhs_act_prec else None
  rhs_act_hparams = QuantOps.ActHParams(
      bounds=6.,
      prec=rhs_act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if rhs_act_prec else None
  # pylint: enable=g-long-ternary
  get_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)

  quantization.quantized_dynamic_dot_general(
      lhs_act=self.lhs_act,
      rhs_act=self.rhs_act,
      quant_type=strategy,
      dot_dimension_numbers=self.dimension_numbers,
      lhs_act_hparams=lhs_act_hparams,
      lhs_get_bounds_params=get_bounds_params,
      rhs_act_hparams=rhs_act_hparams,
      rhs_get_bounds_params=get_bounds_params,
  )

  calls = []
  for prec in [lhs_act_prec, rhs_act_prec]:
    if prec is not None:
      act_hparams = QuantOps.ActHParams(
          bounds=6.,
          prec=prec,
          input_distribution=mock.ANY,
          half_shift=False)
      calls.append(
          mock.call(
              mock.ANY,
              hparams=act_hparams,
              get_bounds_params=get_bounds_params))
  self.assertLen(calls, mock_act_fq.call_count)
  mock_act_fq.assert_has_calls(calls, any_order=True)
def test_dynamic_quantized_dot_general_raises_with_mixed_dtype(self):
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=2.0, prec=4, half_shift=False)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=1.5, prec=4, half_shift=False)
  lhs_act = self.lhs.astype(jnp.bfloat16)
  rhs_act = self.rhs.astype(jnp.float32)
  with self.assertRaises(TypeError):
    quantization.quantized_dynamic_dot_general(
        lhs_act=lhs_act,
        rhs_act=rhs_act,
        lhs_act_hparams=lhs_params,
        rhs_act_hparams=rhs_params,
        lhs_get_bounds_params=None,
        rhs_get_bounds_params=None,
        dot_dimension_numbers=(((1,), (0,)), ((), ())),
        quant_type=QuantType.aqt)
def test_attributes_create_acts_op_fp(
    self,
    act_distribution,
    use_hparams_bounds,
):
  inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
  fp_quant = QuantOps.FloatQuant(
      is_scaled=True,
      fp_spec=QuantOps.FloatQuant.FloatPrec(
          exp_min=-15,
          exp_max=15,
          sig_bits=2,
      ),
  )
  if use_hparams_bounds:
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=6.0,
        stddev_coeff=1,
        absdev_coeff=0,
        mix_coeff=1,
        reset_stats=True,
        ema_coeff=None,
        use_cams=False,
        granularity=quant_config.QuantGranularity.per_tensor)
  else:
    bounds = 6.0

  hparams = QuantOps.ActHParams(
      input_distribution=act_distribution,
      bounds=bounds,
      prec=fp_quant,
      half_shift=False)

  class TestModule(nn.Module):
    hparams: QuantOps.ActHParams

    @nn.compact
    def __call__(self, inputs):
      return QuantOps.create_input_ops(
          inputs,
          hparams=self.hparams,
          get_bounds_params=GetBounds.Params(
              update_stats=False, update_bounds=False))

  test_module = TestModule(hparams=hparams)
  state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
  act_quant_op = test_module.apply(state, inputs=inputs)

  act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
  act_quant_expected = fp_cast.downcast_sat_ftz(
      act_scaled,
      fp_quant.fp_spec.exp_min,
      fp_quant.fp_spec.exp_max,
      fp_quant.fp_spec.sig_bits,
  )
  act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
  onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)
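# Note (added): `downcast_sat_ftz` emulates the low-precision float format
# described by `fp_spec` (here 2 significand bits and exponents in
# [-15, 15]); as the name suggests, out-of-range values saturate (sat) and
# subnormals are flushed to zero (ftz). This reading is inferred from the
# name and arguments, not from the implementation.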
def test_dynamic_quantized_dot_general_has_correct_dtype(
    self, input_dtype, act_prec, quant_type):
  lhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=2.0, prec=act_prec)
  rhs_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=1.5, prec=act_prec)
  lhs_act = self.lhs.astype(input_dtype)
  rhs_act = self.rhs.astype(input_dtype)
  output = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      lhs_act_hparams=lhs_params,
      rhs_act_hparams=rhs_params,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())),
      quant_type=quant_type)
  self.assertEqual(output.dtype, input_dtype)
def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act_params = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=jnp.array([[3.0, 1.5]]), prec=4)
  act = self.lhs.astype(jnp.bfloat16)
  w = self.rhs.astype(jnp.float32)
  with self.assertRaises(TypeError):
    quantization.quantized_dot(
        w=w,
        act=act,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=quant_type,
        prefer_int8_to_int32_dot=True)
def test_quantized_dot_no_quant(self):
  # bounds == -1.0 disables activation quantization, and the single weight
  # is its own per-channel maximum, so weight quantization reproduces it
  # exactly; the result should match the unquantized float product.
  act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=-1.0, prec=4)
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act = jnp.array([[-5.0]])
  w = jnp.array([[-4.99]])
  res = quantization.quantized_dot(
      w=w,
      act=act,
      quant_type=quantization.QuantType.aqt,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=None,
      prefer_int8_to_int32_dot=True)
  onp.testing.assert_allclose(res, act * w)
def test_quantized_dynamic_dot_general_no_quant(self):
  act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=-1.0, prec=4, half_shift=False)
  lhs_act = jnp.array([[-5.0]])
  rhs_act = jnp.array([[-4.99]])
  res = quantization.quantized_dynamic_dot_general(
      lhs_act=lhs_act,
      rhs_act=rhs_act,
      quant_type=quantization.QuantType.aqt,
      lhs_act_hparams=act_hparams,
      rhs_act_hparams=act_hparams,
      lhs_get_bounds_params=None,
      rhs_get_bounds_params=None,
      dot_dimension_numbers=(((1,), (0,)), ((), ())))
  onp.testing.assert_allclose(res, lhs_act * rhs_act)
def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                         quant_type):
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act_params = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=jnp.array([[3.0, 1.5]]),
      prec=act_prec)
  act = self.lhs.astype(input_dtype)
  w = self.rhs.astype(input_dtype)
  output = quantization.quantized_dot(
      w=w,
      act=act,
      weight_params=weight_params,
      act_hparams=act_params,
      get_bounds_params=None,
      quant_type=quant_type,
      prefer_int8_to_int32_dot=True)
  self.assertEqual(output.dtype, input_dtype)
def test_int_symmetric_act_quantization(self, prec):
  # Integer activations within bounds quantize exactly when
  # abs(bounds) == 2**(prec - 1) - 1.
  bounds = 2**(prec - 1) - 1
  activations = random.randint(random.PRNGKey(0), (10, 1), -bounds, bounds)
  rescaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=prec))
  onp.testing.assert_array_equal(activations, rescaled_activations)
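# Illustrative note (not part of the original test suite): with
# bounds == 2**(prec - 1) - 1 the symmetric scale
# (2**(prec - 1) - 1) / bounds is exactly 1, so every in-range integer is a
# fixed point of quantize-dequantize. For prec=4: bounds == 7 and
# round(clip(x, -7, 7) * 1.0) / 1.0 == x for any integer x in [-7, 7].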
def test_int_positive_act_quantization(self, prec):
  # Integer activations within upper_bound quantize exactly when
  # upper_bound == 2**i for some i < prec.
  upper_bound = 2**(prec - 3)
  activations = random.randint(random.PRNGKey(0), (10, 1), 0, upper_bound)
  rescaled_activations = QuantOps.create_inputs_fake_quant(
      inputs=activations,
      get_bounds_params=GetBounds.Params(
          update_stats=False, update_bounds=False),
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.positive,
          bounds=upper_bound,
          prec=prec))
  onp.testing.assert_array_equal(activations, rescaled_activations)
def test_inputs_scale_shape_is_expected(self):
  # Inputs quantization
  inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
  bounds = 6.0
  expected_inputs_scale_shape = ()

  _ = QuantOps.create_inputs_fake_quant(
      inputs=inputs,
      hparams=QuantOps.ActHParams(
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          bounds=bounds,
          prec=8.0),
      get_bounds_params=GetBounds.Params(
          update_stats=False,
          update_bounds=False,
          expected_bounds_shape=expected_inputs_scale_shape))
def test_lax_dot_has_integer_inputs_in_quantized_dot(
    self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot, prec):
  weight_params = QuantOps.WeightParams(
      prec=prec, axis=(0,), half_shift=False)
  act_params = QuantOps.ActHParams(
      input_distribution=act_distribution,
      bounds=jnp.array([[3.0, 1.5]]),
      prec=prec,
      half_shift=False)
  act = self.lhs
  if act_distribution == 'positive':
    act = jnp.abs(act)
  # We need this context manager to stop Jax from trying to compile the arms
  # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
  # try to compile the functions passed to `lax.cond`, even if outside of a
  # JITed context. JIT compilation is incompatible with using a mock for the
  # call to 'dot_general' because during compilation Jax will expect
  # 'dot_general' to return a tracer and will throw an error if it returns a
  # mock instead. By explicitly using jax.disable_jit, Jax will not try to
  # compile the arms of the `lax.cond` and so using a mock will work fine.
  with jax.disable_jit():
    quantization.quantized_dot(
        w=self.rhs,
        act=act,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=QuantType.aqt,
        prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
  act_inputs, weight_inputs = mock_dot_general.call_args[0]
  self.assert_is_integer_in_range(
      act_inputs, prec=prec, distribution=act_distribution)
  self.assert_is_integer_in_range(
      weight_inputs, prec=prec, distribution='symmetric')
  # Positive 8-bit activations span [0, 2**8 - 1], which does not fit in
  # int8, so that combination keeps integer-valued float32 operands.
  if prefer_int8_to_int32_dot and not (act_distribution == 'positive' and
                                       prec == 8):
    expected_input_dtype = jnp.int8
  else:
    expected_input_dtype = jnp.float32
  self.assertEqual(act_inputs.dtype, expected_input_dtype)
  self.assertEqual(weight_inputs.dtype, expected_input_dtype)
def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
    self,
    mock_act_fq,
    mock_w_fq,
    weight_prec,
    act_prec,
    strategy=QuantType.fake_quant):
  mock_w_fq.side_effect = lambda inputs, **_: inputs
  mock_act_fq.side_effect = lambda inputs, **_: inputs

  weight_params = QuantOps.WeightParams(
      prec=weight_prec, axis=None, half_shift=False)
  act_hparams = QuantOps.ActHParams(  # pylint: disable=g-long-ternary
      bounds=6.,
      prec=act_prec,
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      half_shift=False) if act_prec else None
  get_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)

  quantization.quantized_dot(
      w=self.weight,
      act=self.act,
      quant_type=strategy,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=get_bounds_params,
      prefer_int8_to_int32_dot=True)

  quantized_type = strategy.to_jax_type()
  mock_w_fq.assert_called_with(
      mock.ANY,
      weight_params=weight_params,
      quantized_type=quantized_type,
      fake_dependency=mock.ANY)
  if act_hparams:
    mock_act_fq.assert_called_with(
        mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
  else:
    mock_act_fq.assert_not_called()
def test_quantized_dot_general_aqt(self, act_bounds, weight_prec,
                                   weight_axis):
  # With a high enough precision, we expect results from fakequant and AQT
  # to be very similar.
  weight_params = QuantOps.WeightParams(
      prec=weight_prec, axis=weight_axis, half_shift=False)
  if act_bounds is None:
    act_params = None
  else:
    act_params = QuantOps.ActHParams(
        input_distribution='symmetric',
        bounds=jnp.array(act_bounds),
        prec=16,
        half_shift=False)
  lhs_ndims_3 = jnp.array(
      fp32(2.0 * onp.random.uniform(0, 1.0, size=(4, 3, 2))))

  def quantized_matmul(quant_type):
    return quantization.quantized_dot_general(
        w=self.rhs,
        act=lhs_ndims_3,
        weight_params=weight_params,
        act_hparams=act_params,
        get_bounds_params=None,
        quant_type=quant_type,
        dimension_numbers=(((lhs_ndims_3.ndim - 1,), (0,)), ((), ())),
        prefer_int8_to_int32_dot=True)

  aqt_result = quantized_matmul(QuantType.aqt)
  self.assertEqual(aqt_result.shape, (4, 3, 4))
  fakequant_result = quantized_matmul(QuantType.fake_quant)
  onp.testing.assert_allclose(
      aqt_result,
      fakequant_result,
      rtol=1e-2,
      err_msg='AQT and fakequant significantly disagree')
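# Note (added): the dimension numbers above contract the last axis of the
# (4, 3, 2) activation with axis 0 of the weight and declare no batch
# dimensions, so the leading (4, 3) axes survive; the asserted output shape
# (4, 3, 4) implies self.rhs has shape (2, 4).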
class ComputeCostUtilsTest(parameterized.TestCase):

  def setUp(self):
    super(ComputeCostUtilsTest, self).setUp()
    self.rng_key = random.PRNGKey(0)

  def compare_hlo_instructions(self, hlo_no_annotation, hlo_w_annotation):
    """Compares two HLO models to check if they only differ in metadata info."""
    instrs_n = []
    instrs_w = []
    # Gather instructions from both HLO models.
    for computation in hlo_no_annotation.computations:
      for instr in computation.instructions:
        instrs_n.append(instr)
    for computation in hlo_w_annotation.computations:
      for instr in computation.instructions:
        instrs_w.append(instr)

    self.assertEqual(len(instrs_n), len(instrs_w))
    for i, _ in enumerate(instrs_n):
      # For instructions with the opcode 'convolution', the metadata fields
      # of instrs_n and instrs_w should differ.
      if (instrs_n[i].opcode == 'convolution' and
          instrs_w[i].opcode == 'convolution'):
        self.assertNotEqual(instrs_n[i].metadata, instrs_w[i].metadata)

      # Remove metadata op_type and op_name.
      instrs_n[i].metadata.op_type = ''
      instrs_w[i].metadata.op_type = ''
      instrs_n[i].metadata.op_name = ''
      instrs_w[i].metadata.op_name = ''
      # Compare the rest of the instructions.
      self.assertEqual(instrs_n[i], instrs_w[i])

  class TestModelWith1Dense(nn.Module):
    """Test model with a single DenseAqt layer."""

    @nn.compact
    def __call__(self, inputs, hparams, num_classes, dtype=jnp.float32):
      output = aqt_flax_layers.DenseAqt(
          features=num_classes,
          dtype=dtype,
          train=False,
          quant_context=quant_config.QuantContext(
              update_bounds=False, collect_acts_stats=False),
          paxis_name='batch',
          hparams=hparams,
      )(inputs, padding_mask=None)
      return output

  class TestModelWith1Conv(nn.Module):
    """Test model with a single ConvAqt layer."""

    @nn.compact
    def __call__(self,
                 inputs,
                 hparams,
                 kernel_size,
                 num_filters,
                 strides,
                 dtype=jnp.float32):
      output = aqt_flax_layers.ConvAqt(
          features=num_filters,
          kernel_size=kernel_size,
          strides=strides,
          use_bias=False,
          dtype=dtype,
          train=False,
          quant_context=quant_config.QuantContext(update_bounds=False),
          paxis_name='batch',
          hparams=hparams)(inputs)
      return output

  class TestModelWith1DynamicMatmul(nn.Module):
    """Test model with a single dynamic matmul."""

    @nn.compact
    def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
      get_bounds_hyper = get_bounds.GetBounds.Hyper(
          initial_bound=10.0,
          stddev_coeff=0,
          absdev_coeff=0,
          mix_coeff=0,
          granularity=quant_config.QuantGranularity.per_tensor)
      lhs_act_hparams = QuantOps.ActHParams(
          input_distribution='symmetric',
          bounds=get_bounds_hyper,
          prec=lhs_prec,
          half_shift=False)
      rhs_act_hparams = QuantOps.ActHParams(
          input_distribution='symmetric',
          bounds=get_bounds_hyper,
          prec=rhs_prec,
          half_shift=False)
      lhs_get_bounds_params = get_bounds.GetBounds.Params(
          update_stats=False, update_bounds=False, module_name='lhs')
      rhs_get_bounds_params = get_bounds.GetBounds.Params(
          update_stats=False, update_bounds=False, module_name='rhs')
      output = quantization.quantized_dynamic_dot_general(
          lhs_act=lhs_act,
          rhs_act=rhs_act,
          lhs_act_hparams=lhs_act_hparams,
          rhs_act_hparams=rhs_act_hparams,
          dot_dimension_numbers=(((1,), (0,)), ((), ())),
          quant_type=QuantType.aqt,
          lhs_get_bounds_params=lhs_get_bounds_params,
          rhs_get_bounds_params=rhs_get_bounds_params)
      return output

  @parameterized.named_parameters(
      # TestModelWith1Dense
      dict(
          testcase_name='single_dense_layer_bfloat16',
          modelclass=TestModelWith1Dense,
          input_shapes=[(1, 8)],
          model_kwargs={
              'num_classes': 2,
              'hparams':
                  aqt_flax_layers.DenseAqt.HParams(
                      weight_prec=None,
                      quant_type=QuantType.fake_quant,
                      quant_act=None,
                      weight_quant_granularity=quant_config.QuantGranularity
                      .per_channel,
                      weight_half_shift=False),
          },
          expected_compute_cost=8 * 2 * (16 * 16),
          expected_compute_cost_ratio=1.0,
          expected_compute_cost_linear=8 * 2 * (16),
          expected_compute_cost_ratio_linear=1.0,
          expected_memory_cost=8 * 2 * (16),
          expected_memory_cost_ratio=1.0,
      ),
      dict(
          testcase_name='single_dense_layer_w8_a8',
          modelclass=TestModelWith1Dense,
          input_shapes=[(1, 8)],
          model_kwargs={
              'num_classes': 2,
              'hparams':
                  aqt_flax_layers.DenseAqt.HParams(
                      weight_prec=8,
                      quant_type=QuantType.fake_quant,
                      quant_act=QuantOps.ActHParams(
                          input_distribution=QuantOps.ActHParams
                          .InputDistribution.positive,
                          prec=8,
                          bounds=1.0,
                          half_shift=False,
                      ),
                      weight_quant_granularity=quant_config.QuantGranularity
                      .per_channel,
                      weight_half_shift=False),
          },
          expected_compute_cost=8 * 2 * (8 * 8),
          expected_compute_cost_ratio=0.25,
          expected_compute_cost_linear=8 * 2 * (8),
          expected_compute_cost_ratio_linear=0.5,
          expected_memory_cost=8 * 2 * (8),
          expected_memory_cost_ratio=0.5,
      ),
      # TestModelWith1Conv
      dict(
          testcase_name='single_conv_layer_bfloat16',
          modelclass=TestModelWith1Conv,
          input_shapes=[(1, 8, 8, 3)],
          model_kwargs={
              'kernel_size': (3, 3),
              'num_filters': 16,
              'strides': (1, 1),
              'hparams':
                  aqt_flax_layers.ConvAqt.HParams(
                      weight_prec=None,
                      quant_type=QuantType.fake_quant,
                      quant_act=None,
                      weight_half_shift=False,
                  ),
          },
          expected_compute_cost=(3 * 3) * (8 * 8) * 3 * 16 * (16 * 16),
          expected_compute_cost_ratio=1.0,
          expected_compute_cost_linear=(3 * 3) * (8 * 8) * 3 * 16 * (16),
          expected_compute_cost_ratio_linear=1.0,
          expected_memory_cost=(3 * 3) * 3 * 16 * (16),
          expected_memory_cost_ratio=1.0,
      ),
      dict(
          testcase_name='single_conv_layer_bfloat16_strided',
          modelclass=TestModelWith1Conv,
          input_shapes=[(1, 8, 8, 3)],
          model_kwargs={
              'kernel_size': (3, 3),
              'num_filters': 16,
              'strides': (4, 2),
              'hparams':
                  aqt_flax_layers.ConvAqt.HParams(
                      weight_prec=None,
                      quant_type=QuantType.fake_quant,
                      quant_act=None,
                      weight_half_shift=False,
                  ),
          },
          expected_compute_cost=(3 * 3) * ((8 / 4) * (8 / 2)) * 3 * 16 *
          (16 * 16),
          expected_compute_cost_ratio=1.0,
          expected_compute_cost_linear=(3 * 3) * ((8 / 4) * (8 / 2)) * 3 * 16 *
          (16),
          expected_compute_cost_ratio_linear=1.0,
          expected_memory_cost=(3 * 3) * 3 * 16 * (16),
          expected_memory_cost_ratio=1.0,
      ),
      dict(
          testcase_name='single_conv_layer_bfloat16_3d',
          modelclass=TestModelWith1Conv,
          input_shapes=[(1, 8, 8, 8, 3)],
          model_kwargs={
              'kernel_size': (3, 3, 3),
              'num_filters': 16,
              'strides': (1, 1, 1),
              'hparams':
                  aqt_flax_layers.ConvAqt.HParams(
                      weight_prec=None,
                      quant_type=QuantType.fake_quant,
                      quant_act=None,
                      weight_half_shift=False,
                  ),
          },
          expected_compute_cost=(3 * 3 * 3) * (8 * 8 * 8) * 3 * 16 *
          (16 * 16),
          expected_compute_cost_ratio=1.0,
          expected_compute_cost_linear=(3 * 3 * 3) * (8 * 8 * 8) * 3 * 16 *
          (16),
          expected_compute_cost_ratio_linear=1.0,
          expected_memory_cost=(3 * 3 * 3) * 3 * 16 * (16),
          expected_memory_cost_ratio=1.0,
      ),
      dict(
          testcase_name='single_conv_layer_w4_a2',
          modelclass=TestModelWith1Conv,
          input_shapes=[(1, 8, 8, 3)],
          model_kwargs={
              'kernel_size': (3, 3),
              'num_filters': 16,
              'strides': (1, 1),
              'hparams':
                  aqt_flax_layers.ConvAqt.HParams(
                      weight_prec=4,
                      quant_type=QuantType.fake_quant,
                      quant_act=QuantOps.ActHParams(
                          input_distribution=QuantOps.ActHParams
                          .InputDistribution.positive,
                          prec=2,
                          bounds=1.0,
                          half_shift=False,
                      ),
                      weight_half_shift=False,
                  ),
          },
          expected_compute_cost=(3 * 3) * (8 * 8) * 3 * 16 * (4 * 2),
          expected_compute_cost_ratio=0.03125,
          expected_compute_cost_linear=(3 * 3) * (8 * 8) * 3 * 16 * (4),
          expected_compute_cost_ratio_linear=0.25,
          expected_memory_cost=(3 * 3) * 3 * 16 * (4),
          expected_memory_cost_ratio=0.25,
      ),
      # TestModelWith1DynamicMatmul
      dict(
          testcase_name='single_dynamic_matmul_layer_bfloat16',
          modelclass=TestModelWith1DynamicMatmul,
          input_shapes=[(1, 8), (8, 1)],
          model_kwargs={'lhs_prec': None, 'rhs_prec': None},
          expected_compute_cost=8 * (16 * 16),
          expected_compute_cost_ratio=1.0,
          expected_compute_cost_linear=8 * (16),
          expected_compute_cost_ratio_linear=1.0,
          expected_memory_cost=0,
          expected_memory_cost_ratio=1.0,
      ),
      dict(
          testcase_name='single_dynamic_matmul_layer_l8_r8',
          modelclass=TestModelWith1DynamicMatmul,
          input_shapes=[(1, 8), (8, 1)],
          model_kwargs={'lhs_prec': 8, 'rhs_prec': 8},
          expected_compute_cost=8 * (8 * 8),
          expected_compute_cost_ratio=0.25,
          expected_compute_cost_linear=8 * 8,
          expected_compute_cost_ratio_linear=0.5,
          expected_memory_cost=0,
          expected_memory_cost_ratio=1.0,
      ),
      dict(
          testcase_name='single_dynamic_matmul_layer_l8_r4',
          modelclass=TestModelWith1DynamicMatmul,
          input_shapes=[(1, 8), (8, 1)],
          model_kwargs={'lhs_prec': 8, 'rhs_prec': 4},
          expected_compute_cost=8 * (8 * 4),
          expected_compute_cost_ratio=0.125,
          expected_compute_cost_linear=8 * (8),
          expected_compute_cost_ratio_linear=0.5,
          expected_memory_cost=0,
          expected_memory_cost_ratio=1.0,
      ),
  )  # pylint: disable=line-too-long
  def test_estimate_simple_model_cost(
      self, modelclass, input_shapes, model_kwargs, expected_compute_cost,
      expected_compute_cost_ratio, expected_compute_cost_linear,
      expected_compute_cost_ratio_linear, expected_memory_cost,
      expected_memory_cost_ratio):
    module = modelclass()
    input_shapes_with_type = [(sh, jnp.float32) for sh in input_shapes]
    dummy_inputs = [
        jnp.ones(input_shape, dtype=dtype)
        for (input_shape, dtype) in input_shapes_with_type
    ]
    init_state = module.init(random.PRNGKey(0), *dummy_inputs, **model_kwargs)

    hlo_proto = hlo_utils.load_hlo_proto_from_model(module, init_state,
                                                    input_shapes,
                                                    **model_kwargs)
    compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
    memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
    logging.info('compute cost result is %s', compute_result)
    logging.info('memory cost result is %s', memory_result)
    self.assertEqual(compute_result['compute_cost'], expected_compute_cost)
    self.assertEqual(memory_result['memory_cost'], expected_memory_cost)
    self.assertEqual(compute_result['compute_cost_ratio_to_bfloat16'],
                     expected_compute_cost_ratio)
    self.assertEqual(memory_result['memory_cost_ratio_to_bfloat16'],
                     expected_memory_cost_ratio)
    self.assertEqual(compute_result['compute_cost_linear'],
                     expected_compute_cost_linear)
    self.assertEqual(compute_result['compute_cost_ratio_to_bfloat16_linear'],
                     expected_compute_cost_ratio_linear)

  @parameterized.named_parameters(
      # TestModelWith1Dense
      dict(
          testcase_name='single_dense_layer_bfloat16_batch_size',
          modelclass=TestModelWith1Dense,
          input_shape_per_sample=(16,),
          model_kwargs={
              'num_classes': 20,
              'hparams':
                  aqt_flax_layers.DenseAqt.HParams(
                      weight_prec=None,
                      quant_act=None,
                      quant_type=QuantType.fake_quant,
                      weight_quant_granularity=quant_config.QuantGranularity
                      .per_channel,
                      weight_half_shift=False)
          },
      ),
      # TestModelWith1Conv
      dict(
          testcase_name='single_conv_layer_bfloat16_batch_size',
          modelclass=TestModelWith1Conv,
          input_shape_per_sample=(16, 16, 3),
          model_kwargs={
              'kernel_size': (3, 3),
              'num_filters': 16,
              'strides': (2, 2),
              'hparams':
                  aqt_flax_layers.ConvAqt.HParams(
                      weight_prec=None,
                      quant_act=None,
                      quant_type=QuantType.fake_quant,
                      weight_half_shift=False,
                  )
          },
      ),
  )
  def test_batch_size_has_no_effect_on_cost(self, modelclass,
                                            input_shape_per_sample,
                                            model_kwargs):
    expected_compute_cost = None
    expected_memory_cost = None
    batch_size_list = [32, 64, 128, 256, 512, 1024]

    module = modelclass()

    # Sweep over the batch size list.
    for batch_size in batch_size_list:
      input_shape = (batch_size,) + input_shape_per_sample
      init_state = module.init(
          random.PRNGKey(0), jnp.ones(input_shape, jnp.float32),
          **model_kwargs)
      hlo_proto = hlo_utils.load_hlo_proto_from_model(module, init_state,
                                                      [input_shape],
                                                      **model_kwargs)
      del init_state
      compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
      memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
      # Save the first cost and compare it with the rest.
      if expected_compute_cost is None:
        expected_compute_cost = compute_result['compute_cost']
      else:
        self.assertEqual(compute_result['compute_cost'],
                         expected_compute_cost)
      if expected_memory_cost is None:
        expected_memory_cost = memory_result['memory_cost']
      else:
        self.assertEqual(memory_result['memory_cost'], expected_memory_cost)

  @parameterized.named_parameters(
      dict(testcase_name='quant_8bit', weight_prec=8),
      dict(testcase_name='quant_4bit', weight_prec=4),
  )
  def test_check_value_inside_and_outside_of_context_conv_general(
      self, weight_prec):
    original_op_name = 'conv_general_dilated'
    # The primitive's 'name' should change inside the monkey-patch context
    # and revert when the context exits.
    self.assertEqual(original_op_name,
                     lax_convolution.conv_general_dilated_p.name)
    with compute_cost_utils.ConvMetadataMonkeyPatch(
        weight_prec=weight_prec, act_prec=None):
      self.assertNotEqual(original_op_name,
                          lax_convolution.conv_general_dilated_p.name)
    self.assertEqual(original_op_name,
                     lax_convolution.conv_general_dilated_p.name)

  @parameterized.named_parameters(
      dict(testcase_name='quant_8bit', weight_prec=8, acts_prec=8),
      dict(testcase_name='quant_4bit', weight_prec=4, acts_prec=4),
  )
  def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                     acts_prec):
    FLAGS.metadata_enabled = False
    quant_act = quantization.QuantOps.ActHParams(
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        prec=acts_prec,
        bounds=1.0,
        half_shift=False)
    input_shape = (1, 8, 8, 3)
    module_no_annotation = aqt_flax_layers.ConvAqt(
        features=4,
        kernel_size=(3, 3),
        padding='VALID',
        paxis_name='batch',
        quant_context=quant_config.QuantContext(update_bounds=False),
        train=False,
        hparams=aqt_flax_layers.ConvAqt.HParams(
            weight_prec=weight_prec,
            quant_act=quant_act,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        kernel_init=initializers.ones,
        bias_init=initializers.ones,
        dtype=jnp.float32)

    init_state = module_no_annotation.init(self.rng_key,
                                           jnp.ones(input_shape, jnp.float32))
    output_no_annotation = module_no_annotation.apply(init_state,
                                                      jnp.ones(input_shape))
    hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
        module_no_annotation, init_state, [input_shape])
    del init_state

    FLAGS.metadata_enabled = True
    module_w_annotation = aqt_flax_layers.ConvAqt(
        features=4,
        kernel_size=(3, 3),
        padding='VALID',
        paxis_name='batch',
        quant_context=quant_config.QuantContext(update_bounds=False),
        train=False,
        hparams=aqt_flax_layers.ConvAqt.HParams(
            weight_prec=weight_prec,
            quant_act=quant_act,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        kernel_init=initializers.ones,
        bias_init=initializers.ones,
        dtype=jnp.float32)

    init_state = module_w_annotation.init(self.rng_key,
                                          jnp.ones(input_shape, jnp.float32))
    output_w_annotation = module_w_annotation.apply(init_state,
                                                    jnp.ones(input_shape))
    hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
        module_w_annotation, init_state, [input_shape])
    del init_state

    onp.testing.assert_array_equal(output_no_annotation, output_w_annotation)
    self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)

  @parameterized.named_parameters(
      dict(testcase_name='quant_8bit', weight_prec=8),
      dict(testcase_name='quant_4bit', weight_prec=4),
  )
  def test_check_value_inside_and_outside_of_context_dot_general(
      self, weight_prec):
    original_op_name = 'dot_general'
    # The primitive's 'name' should change inside the monkey-patch context
    # and revert when the context exits.
    self.assertEqual(original_op_name, lax.dot_general_p.name)
    with compute_cost_utils.DotMetadataMonkeyPatch(
        lhs_prec=None, rhs_prec=weight_prec, rhs_is_weight=True):
      self.assertNotEqual(original_op_name, lax.dot_general_p.name)
    self.assertEqual(original_op_name, lax.dot_general_p.name)

  @parameterized.named_parameters(
      dict(
          testcase_name='quant_8bit',
          weight_prec=8,
          acts_prec=8,
      ),
  )
  def test_annotation_only_changes_hlo_metadata_dense(self, weight_prec,
                                                      acts_prec):
    FLAGS.metadata_enabled = False
    quant_act = quantization.QuantOps.ActHParams(
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        prec=acts_prec,
        bounds=1.0,
        half_shift=False)
    input_shape = (1, 16)
    module_no_annotation = aqt_flax_layers.DenseAqt(
        features=4,
        use_bias=False,
        quant_context=quant_config.QuantContext(
            update_bounds=False, collect_acts_stats=False),
        paxis_name='batch',
        train=False,
        hparams=aqt_flax_layers.DenseAqt.HParams(
            weight_prec=weight_prec,
            quant_act=quant_act,
            quant_type=QuantType.fake_quant,
            weight_quant_granularity=quant_config.QuantGranularity
            .per_channel,
            weight_half_shift=False),
        dtype=jnp.float32)

    init_state = module_no_annotation.init(
        self.rng_key, jnp.ones(input_shape, jnp.float32), padding_mask=None)
    output_no_annotation = module_no_annotation.apply(
        init_state, jnp.ones(input_shape), padding_mask=None)
    hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
        module_no_annotation, init_state, [input_shape], padding_mask=None)
    del init_state

    FLAGS.metadata_enabled = True
    module_w_annotation = aqt_flax_layers.DenseAqt(
        features=4,
        use_bias=False,
        paxis_name='batch',
        train=False,
        quant_context=quant_config.QuantContext(
            update_bounds=False, collect_acts_stats=False),
        dtype=jnp.float32,
        hparams=aqt_flax_layers.DenseAqt.HParams(
            weight_prec=weight_prec,
            quant_act=quant_act,
            quant_type=QuantType.fake_quant,
            weight_quant_granularity=quant_config.QuantGranularity
            .per_channel,
            weight_half_shift=False),
    )

    init_state = module_w_annotation.init(
        self.rng_key, jnp.ones(input_shape, jnp.float32), padding_mask=None)
    output_w_annotation = module_w_annotation.apply(
        init_state, jnp.ones(input_shape), padding_mask=None)
    hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
        module_w_annotation, init_state, [input_shape], padding_mask=None)
    del init_state

    onp.testing.assert_array_equal(output_no_annotation, output_w_annotation)
    self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
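# Illustrative sketch (not part of the original test suite) of the cost model
# implied by the expected values in test_estimate_simple_model_cost above;
# the helper below is hypothetical and is not the compute_cost_utils API.
# Each MAC between a b1-bit and a b2-bit operand is charged b1 * b2 (or
# max(b1, b2) in the 'linear' variant), and unquantized operands count as
# 16-bit bfloat16.
def _expected_dense_compute_cost(in_features, out_features, act_bits,
                                 weight_bits, linear=False):
  macs = in_features * out_features
  if linear:
    return macs * max(act_bits, weight_bits)
  return macs * act_bits * weight_bits

# E.g. 'single_dense_layer_w8_a8':
# _expected_dense_compute_cost(8, 2, 8, 8) == 8 * 2 * (8 * 8).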
class AttnActsMatmulQuantTest(parameterized.TestCase):

  def construct_hparams(self, attn_act_q, attn_act_k, attn_act_probs,
                        attn_act_v):
    dense = flax_layers.DenseAqt.HParams(
        weight_prec=None,
        quant_act=None,
        quant_type=QuantType.fake_quant,
        weight_quant_granularity=quant_config.QuantGranularity.per_channel,
        weight_half_shift=False)
    return flax_attention.MultiHeadDotProductAttentionAqt.HParams(
        dense_kqv=dense,
        dense_out=dense,
        attn_acts=flax_attention.DotProductAttnHParams(
            attn_act_q=attn_act_q,
            attn_act_k=attn_act_k,
            attn_act_probs=attn_act_probs,
            attn_act_v=attn_act_v,
            quant_type=QuantType.fake_quant,
            softmax=SoftmaxHParams(None, None, None)))

  @parameterized.named_parameters(
      dict(
          testcase_name='float',
          attn_act_q=None,
          attn_act_k=None,
          attn_act_probs=None,
          attn_act_v=None,
          update_bounds=False,
          paxis_name=None,
          train=False),
      dict(
          testcase_name='quant_q',
          attn_act_q=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=8,
              bounds=1,
              half_shift=False),
          attn_act_k=None,
          attn_act_probs=None,
          attn_act_v=None,
          update_bounds=False,
          paxis_name='batch',
          train=True),
      dict(
          testcase_name='quant_qk',
          attn_act_q=None,
          attn_act_k=None,
          attn_act_probs=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=8,
              bounds=1.0,
              half_shift=False),
          attn_act_v=None,
          update_bounds=False,
          paxis_name='batch',
          train=True),
      dict(
          testcase_name='quant_k',
          attn_act_q=None,
          attn_act_k=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=4,
              bounds=2,
              half_shift=False),
          attn_act_probs=None,
          attn_act_v=None,
          update_bounds=False,
          paxis_name=None,
          train=True),
      dict(
          testcase_name='quant_v',
          attn_act_q=None,
          attn_act_k=None,
          attn_act_probs=None,
          attn_act_v=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=2,
              bounds=3,
              half_shift=False),
          update_bounds=True,
          paxis_name='batch',
          train=False),
      dict(
          testcase_name='quant_all_aa',
          attn_act_q=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=8,
              bounds=1,
              half_shift=False),
          attn_act_k=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=4,
              bounds=2,
              half_shift=False),
          attn_act_probs=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=8,
              bounds=1.0,
              half_shift=False),
          attn_act_v=QuantOps.ActHParams(
              input_distribution=QuantOps.ActHParams.InputDistribution
              .symmetric,
              prec=2,
              bounds=3,
              half_shift=False),
          update_bounds=True,
          paxis_name=None,
          train=True),
  )
  @unittest.mock.patch.object(QuantOps, 'create_inputs_fake_quant')
  def test_self_attention_act_quant_should_call_quant_ops(
      self, mock_inputs_fake_quant, attn_act_q, attn_act_k, attn_act_probs,
      attn_act_v, update_bounds, paxis_name, train):
    mock_inputs_fake_quant.side_effect = (
        lambda inputs, hparams, get_bounds_params: inputs)

    rng = random.PRNGKey(0)
    x = jnp.ones((4, 3, 7))
    hparams = self.construct_hparams(attn_act_q, attn_act_k, attn_act_probs,
                                     attn_act_v)
    sa_module = flax_attention.SelfAttentionAqt(
        hparams=hparams,
        num_heads=4,
        quant_context=quant_config.QuantContext(
            update_bounds=update_bounds, collect_acts_stats=False),
        train=train,
        paxis_name=paxis_name,
        attention_axis=None,
        qkv_features=8,
        kernel_init=initializers.ones,
        bias_init=initializers.zeros,
        causal_mask=False,
        dtype=jnp.float32,
        dropout_rate=0.0,
        deterministic=False,
        decode=False)
    sa_module.init(rng, x, padding_mask=None)
    calls = []
    for hparam in [attn_act_q, attn_act_k, attn_act_probs, attn_act_v]:
      if hparam is not None:
        calls.append(
            unittest.mock.call(
                unittest.mock.ANY,
                hparams=hparam,
                get_bounds_params=get_bounds.GetBounds.Params(
                    update_stats=train,
                    update_bounds=update_bounds,
                    paxis_name=paxis_name,
                    mask=unittest.mock.ANY,
                    module_name=unittest.mock.ANY)))
    mock_inputs_fake_quant.assert_has_calls(calls, any_order=True)
    self.assertLen(calls, mock_inputs_fake_quant.call_count)