def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
  """Checks quantized_dot raises TypeError for bfloat16 act with float32 w."""
  w_params = QuantOps.WeightParams(prec=4, axis=(0,))
  a_params = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=jnp.array([[3.0, 1.5]]),
      prec=4)
  # Deliberately give the two operands different floating-point dtypes.
  bf16_act = self.lhs.astype(jnp.bfloat16)
  f32_w = self.rhs.astype(jnp.float32)
  with self.assertRaises(TypeError):
    quantization.quantized_dot(
        w=f32_w,
        act=bf16_act,
        weight_params=w_params,
        act_hparams=a_params,
        get_bounds_params=None,
        quant_type=quant_type,
        prefer_int8_to_int32_dot=True)
def quantized_matmul(quant_type):
  """Runs quantized_dot on the fixture operands using `quant_type`.

  NOTE(review): this is a closure; `self`, `weight_params` and `act_params`
  come from the enclosing scope, which is not visible here.
  """
  weights, activations = self.rhs, self.lhs
  return quantization.quantized_dot(
      w=weights,
      act=activations,
      weight_params=weight_params,
      act_hparams=act_params,
      get_bounds_params=None,
      quant_type=quant_type,
      prefer_int8_to_int32_dot=True)
def test_lax_dot_has_integer_inputs_in_quantized_dot(
    self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
    prec):
  """Checks the operands AQT quantized_dot hands to the mocked dot_general.

  Both operands are checked with `assert_is_integer_in_range`. Their dtype
  must be int8 when `prefer_int8_to_int32_dot` is set (unless activations are
  'positive' with prec == 8), and float32 otherwise.
  """
  w_params = QuantOps.WeightParams(prec=prec, axis=(0,), half_shift=False)
  a_params = QuantOps.ActHParams(
      input_distribution=act_distribution,
      bounds=jnp.array([[3.0, 1.5]]),
      prec=prec,
      half_shift=False)
  activations = (
      jnp.abs(self.lhs) if act_distribution == 'positive' else self.lhs)

  # `dot_general_aqt` goes through `lax.cond`, and Jax compiles both branches
  # of a `lax.cond` by default, even outside a `jit` context. During that
  # compilation Jax expects 'dot_general' to return a tracer, so the mock
  # would trigger an error. Running under `jax.disable_jit` skips compiling
  # the branches, which lets the mock stand in for 'dot_general'.
  with jax.disable_jit():
    quantization.quantized_dot(
        w=self.rhs,
        act=activations,
        weight_params=w_params,
        act_hparams=a_params,
        get_bounds_params=None,
        quant_type=QuantType.aqt,
        prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)

  act_operand, weight_operand = mock_dot_general.call_args[0]
  self.assert_is_integer_in_range(
      act_operand, prec=prec, distribution=act_distribution)
  self.assert_is_integer_in_range(
      weight_operand, prec=prec, distribution='symmetric')

  positive_8bit = act_distribution == 'positive' and prec == 8
  use_int8 = prefer_int8_to_int32_dot and not positive_8bit
  expected_input_dtype = jnp.int8 if use_int8 else jnp.float32
  self.assertEqual(act_operand.dtype, expected_input_dtype)
  self.assertEqual(weight_operand.dtype, expected_input_dtype)
def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
    self, mock_act_fq, mock_w_fq, weight_prec, act_prec,
    strategy=QuantType.fake_quant):
  """Checks quantized_dot calls the weight and activation quantizers.

  Both quantizers are mocked as identity functions. The weight quantizer must
  be called with the configured `weight_params`. The activation quantizer
  must be called with `act_hparams` when `act_prec` is truthy, and must not
  be called at all otherwise.
  """
  # Mocked quantizers pass their input through unchanged.
  mock_w_fq.side_effect = lambda inputs, **_: inputs
  mock_act_fq.side_effect = lambda inputs, **_: inputs
  weight_params = QuantOps.WeightParams(
      prec=weight_prec, axis=None, half_shift=False)
  # Without an activation precision no activation hparams are built.
  if act_prec:
    act_hparams = QuantOps.ActHParams(
        bounds=6.,
        prec=act_prec,
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        half_shift=False)
  else:
    act_hparams = None
  get_bounds_params = GetBounds.Params(
      update_stats=False, update_bounds=False)

  quantization.quantized_dot(
      w=self.weight,
      act=self.act,
      quant_type=strategy,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=get_bounds_params,
      prefer_int8_to_int32_dot=True)

  quantized_type = strategy.to_jax_type()
  mock_w_fq.assert_called_with(
      mock.ANY,
      weight_params=weight_params,
      quantized_type=quantized_type,
      fake_dependency=mock.ANY)
  # Compare against None explicitly: the question is whether hparams were
  # built, not whether the hparams object happens to be truthy.
  if act_hparams is not None:
    mock_act_fq.assert_called_with(
        mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
  else:
    mock_act_fq.assert_not_called()
def test_quantized_dot_no_quant(self):
  """Checks quantized_dot with `bounds=-1.0` matches the plain float product.

  NOTE(review): a negative bound presumably disables activation quantization
  (per the test name) - confirm against quantization.quantized_dot.
  """
  act = jnp.array([[-5.0]])
  w = jnp.array([[-4.99]])
  weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
  act_hparams = QuantOps.ActHParams(
      input_distribution='symmetric', bounds=-1.0, prec=4)
  res = quantization.quantized_dot(
      w=w,
      act=act,
      quant_type=quantization.QuantType.aqt,
      weight_params=weight_params,
      act_hparams=act_hparams,
      get_bounds_params=None,
      prefer_int8_to_int32_dot=True)
  # With 1x1 operands the matrix product equals the elementwise product.
  onp.testing.assert_allclose(res, act * w)
def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                         quant_type):
  """Checks quantized_dot returns results in the dtype of its inputs."""
  w_params = QuantOps.WeightParams(prec=4, axis=(0,))
  a_params = QuantOps.ActHParams(
      input_distribution='symmetric',
      bounds=jnp.array([[3.0, 1.5]]),
      prec=act_prec)
  lhs_cast = self.lhs.astype(input_dtype)
  rhs_cast = self.rhs.astype(input_dtype)
  result = quantization.quantized_dot(
      w=rhs_cast,
      act=lhs_cast,
      weight_params=w_params,
      act_hparams=a_params,
      get_bounds_params=None,
      quant_type=quant_type,
      prefer_int8_to_int32_dot=True)
  self.assertEqual(result.dtype, input_dtype)
def __call__(
    self,
    inputs,
    *,
    padding_mask,
):
  """Applies a linear transformation to the inputs with optional quantization.

  If weight_prec is not None, scales and quantizes weights to signed int with
  weight_prec bits.

  Args:
    inputs: The rank-2 array of shape (batch_size, channels_in) to be
      transformed.
    padding_mask: Boolean array of shape (batch_size, 1) specifying which rows
      of 'inputs' to use as part of the bounds calculation. 'True' indicates
      the corresponding row of 'inputs' should be used. If None, all values
      are used.

  Returns:
    The transformed input, of shape (batch_size, features).

  Raises:
    NotImplementedError: If `hparams.weight_prec` is an int greater than 8.
    ValueError: If `hparams.weight_quant_granularity` is neither per_channel
      nor per_tensor.
  """
  # The two-element unpack also enforces that `inputs` is rank 2.
  batch_size, _ = inputs.shape
  if padding_mask is not None:
    shape_utils.assert_shapes_equal(padding_mask.shape, (batch_size, 1))

  # TODO(wanglisa): Replace fake quant with AQT.
  if self.quant_context.collect_acts_stats:
    stats_tag.StatsTag(
        channel_axis=-1, name='inputs', update_stats=self.train)(
            inputs, mask=padding_mask)

  hparams = self.hparams
  if (hparams.weight_prec is not None and
      isinstance(hparams.weight_prec, int) and hparams.weight_prec > 8):
    raise NotImplementedError(
        'If you want to use more than 8bits for quantization, please revisit '
        'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
    )
  jax_precision = jax.lax.Precision.DEFAULT

  kernel = self.param('kernel', self.kernel_init,
                      (inputs.shape[-1], self.features))
  inputs = jnp.asarray(inputs, self.dtype)
  kernel = jnp.asarray(kernel, self.dtype)

  get_bounds_params = get_bounds.GetBounds.Params(
      update_bounds=self.quant_context.update_bounds,
      update_stats=self.train,
      paxis_name=self.paxis_name,
      mask=padding_mask)

  weight_quant_granularity = hparams.weight_quant_granularity
  # kernel.shape = (channels_in, channels_out)
  if weight_quant_granularity == quant_config.QuantGranularity.per_channel:
    # Compute scale factors by reducing over the rows of the weight matrix,
    # resulting in one scale factor per column. This results in one scale
    # factor per output channel.
    expected_scale_shape = (1, self.features)
    weight_quant_axis = (0,)
  elif weight_quant_granularity == quant_config.QuantGranularity.per_tensor:
    # Compute a single scale factor for the entire weight matrix.
    expected_scale_shape = (1, 1)
    weight_quant_axis = None
  else:
    raise ValueError(
        f'Invalid quantization granularity {weight_quant_granularity}.')

  weight_params = QuantOps.WeightParams(
      prec=hparams.weight_prec,
      axis=weight_quant_axis,
      expected_scale_shape=expected_scale_shape)

  # TODO(wanglisa): add option to control when scale is being recomputed

  # matmul
  y = quantization.quantized_dot(
      act=inputs,
      w=kernel,
      quant_type=hparams.quant_type,
      weight_params=weight_params,
      act_hparams=hparams.quant_act,
      get_bounds_params=get_bounds_params,
      dot_precision=jax_precision,
      prefer_int8_to_int32_dot=self.quant_context.prefer_int8_to_int32_dot)

  # bias
  if self.use_bias:
    bias = self.param('bias', self.bias_init, (self.features,))
    # (batch_size, features)
    y = y + bias[jnp.newaxis, :]
  shape_utils.assert_shapes_equal(y.shape, (batch_size, self.features))
  return y