def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ),
                                            half_shift=False)
     act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=4,
                                      half_shift=False)
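     # Activations are cast to bfloat16 while the weights stay float32; this
     # dtype mismatch is what should trigger the TypeError below.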
     act = self.lhs.astype(jnp.bfloat16)
     w = self.rhs.astype(jnp.float32)
     with self.assertRaises(TypeError):
         quantization.quantized_dot(w=w,
                                    act=act,
                                    weight_params=weight_params,
                                    act_hparams=act_params,
                                    get_bounds_params=None,
                                    quant_type=quant_type,
                                    prefer_int8_to_int32_dot=True)
 def quantized_matmul(quant_type):
     return quantization.quantized_dot(w=self.rhs,
                                       act=self.lhs,
                                       weight_params=weight_params,
                                       act_hparams=act_params,
                                       get_bounds_params=None,
                                       quant_type=quant_type,
                                       prefer_int8_to_int32_dot=True)
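      # Sketch of how this helper is typically used (the surrounding test body
      # is not shown here, so the exact tolerance is an assumption): run the
      # same matmul under both strategies and check that they agree, e.g.
      #   aqt_result = quantized_matmul(QuantType.aqt)
      #   fakequant_result = quantized_matmul(QuantType.fake_quant)
      #   onp.testing.assert_allclose(aqt_result, fakequant_result, rtol=1e-2)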
 def test_lax_dot_has_integer_inputs_in_quantized_dot(
         self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
         prec):
     weight_params = QuantOps.WeightParams(prec=prec,
                                           axis=(0, ),
                                           half_shift=False)
     act_params = QuantOps.ActHParams(input_distribution=act_distribution,
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=prec,
                                      half_shift=False)
     act = self.lhs
     if act_distribution == 'positive':
         act = jnp.abs(act)
     # We need this context manager to stop Jax from trying to compile the arms
     # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
     # try to compile the functions passed to `lax.cond`, even if outside of a
     # JITed context. JIT compilation is incompatible with using a mock for the
     # call to 'dot_general' because during compilation Jax will expect
     # 'dot_general' to return a tracer and will throw an error if it returns a
      # mock instead. By explicitly using jax.disable_jit, Jax will not try to
      # compile the arms of lax.cond, so using a mock works fine.
     with jax.disable_jit():
         quantization.quantized_dot(
             w=self.rhs,
             act=act,
             weight_params=weight_params,
             act_hparams=act_params,
             get_bounds_params=None,
             quant_type=QuantType.aqt,
             prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
     act_inputs, weight_inputs = mock_dot_general.call_args[0]
     self.assert_is_integer_in_range(act_inputs,
                                     prec=prec,
                                     distribution=act_distribution)
     self.assert_is_integer_in_range(weight_inputs,
                                     prec=prec,
                                     distribution='symmetric')
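      # int8 inputs are expected whenever prefer_int8_to_int32_dot is set,
      # except for 8-bit 'positive' activations: their quantized values span
      # [0, 255], which does not fit in int8, so the inputs stay in float32.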
     if prefer_int8_to_int32_dot and not (act_distribution == 'positive'
                                          and prec == 8):
         expected_input_dtype = jnp.int8
     else:
         expected_input_dtype = jnp.float32
     self.assertEqual(act_inputs.dtype, expected_input_dtype)
     self.assertEqual(weight_inputs.dtype, expected_input_dtype)
  def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
      self,
      mock_act_fq,
      mock_w_fq,
      weight_prec,
      act_prec,
      strategy=QuantType.fake_quant):
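    # Make the mocked quantization ops pass their inputs through unchanged so
    # quantized_dot still runs end to end while we inspect the calls below.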
    mock_w_fq.side_effect = lambda inputs, **_: inputs
    mock_act_fq.side_effect = lambda inputs, **_: inputs

    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=None, half_shift=False)
    act_hparams = QuantOps.ActHParams(  # pylint: disable=g-long-ternary
        bounds=6.,
        prec=act_prec,
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        half_shift=False) if act_prec else None
    get_bounds_params = GetBounds.Params(
        update_stats=False, update_bounds=False)

    quantization.quantized_dot(
        w=self.weight,
        act=self.act,
        quant_type=strategy,
        weight_params=weight_params,
        act_hparams=act_hparams,
        get_bounds_params=get_bounds_params,
        prefer_int8_to_int32_dot=True)

    quantized_type = strategy.to_jax_type()

    mock_w_fq.assert_called_with(
        mock.ANY,
        weight_params=weight_params,
        quantized_type=quantized_type,
        fake_dependency=mock.ANY)
    if act_hparams:
      mock_act_fq.assert_called_with(
          mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
    else:
      mock_act_fq.assert_not_called()
 def test_quantized_dot_no_quant(self):
      # A non-positive bound disables activation quantization here, so the
      # result should exactly match the unquantized product.
      act_hparams = QuantOps.ActHParams(input_distribution='symmetric',
                                        bounds=-1.0,
                                        prec=4,
                                        half_shift=False)
      weight_params = QuantOps.WeightParams(prec=4, axis=(0, ),
                                            half_shift=False)
     act = jnp.array([[-5.0]])
     w = jnp.array([[-4.99]])
     res = quantization.quantized_dot(w=w,
                                      act=act,
                                      quant_type=quantization.QuantType.aqt,
                                      weight_params=weight_params,
                                      act_hparams=act_hparams,
                                      get_bounds_params=None,
                                      prefer_int8_to_int32_dot=True)
     onp.testing.assert_allclose(res, act * w)
 def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                          quant_type):
      weight_params = QuantOps.WeightParams(prec=4, axis=(0, ),
                                            half_shift=False)
      act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                       bounds=jnp.array([[3.0, 1.5]]),
                                       prec=act_prec,
                                       half_shift=False)
     act = self.lhs.astype(input_dtype)
     w = self.rhs.astype(input_dtype)
     output = quantization.quantized_dot(w=w,
                                         act=act,
                                         weight_params=weight_params,
                                         act_hparams=act_params,
                                         get_bounds_params=None,
                                         quant_type=quant_type,
                                         prefer_int8_to_int32_dot=True)
     self.assertEqual(output.dtype, input_dtype)
    def __call__(
        self,
        inputs,
        *,
        padding_mask,
    ):
        """Applies a linear transformation to the inputs with optional quantization.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Args:
      inputs: The nd-array to be transformed.
      padding_mask: boolean tensor of the same shape as 'inputs' specifying
        which values of 'inputs' to use as part of the bounds calculation.
        'True' indicates the corresponding value from 'inputs' should be used.
        If None, all values are used.

    Returns:
      The transformed input.
    """
        batch_size, channel_size = inputs.shape  # pylint: disable=unused-variable
        if padding_mask is not None:
            shape_utils.assert_shapes_equal(padding_mask.shape,
                                            (batch_size, 1))
        # TODO(wanglisa): Replace fake quant with AQT.

        if self.quant_context.collect_acts_stats:
            stats_tag.StatsTag(channel_axis=-1,
                               name='inputs',
                               update_stats=self.train)(inputs,
                                                        mask=padding_mask)
        hparams = self.hparams
        if (hparams.weight_prec is not None
                and isinstance(hparams.weight_prec, int)
                and hparams.weight_prec > 8):
            raise NotImplementedError(
                'If you want to use more than 8 bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )
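        # Rationale: with Precision.DEFAULT, TPU matmuls may round their inputs
        # to bfloat16, whose 8 bits of significand precision can represent
        # 8-bit quantized values exactly but not wider ones.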

        jax_precision = jax.lax.Precision.DEFAULT
        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        get_bounds_params = get_bounds.GetBounds.Params(
            update_bounds=self.quant_context.update_bounds,
            update_stats=self.train,
            paxis_name=self.paxis_name,
            mask=padding_mask)

        weight_quant_granularity = hparams.weight_quant_granularity
        # kernel.shape = (channels_in, channels_out)
        if weight_quant_granularity == quant_config.QuantGranularity.per_channel:
            # Compute scale factors by reducing over the rows of the weight matrix,
            # resulting in one scale factor per column. This results in one scale
            # factor per output channel.
            expected_scale_shape = (1, self.features)
            weight_quant_axis = (0, )
        elif weight_quant_granularity == quant_config.QuantGranularity.per_tensor:
            # Compute a single scale factor for the entire weight matrix.
            expected_scale_shape = (1, 1)
            weight_quant_axis = None
        else:
            raise ValueError(
                f'Invalid quantization granularity {weight_quant_granularity}.'
            )
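        # Illustrative shapes: for a kernel of shape (channels_in=512,
        # features=1024), per_channel reduces over axis 0 and yields a
        # (1, 1024) scale (one per output channel), while per_tensor yields a
        # single (1, 1) scale.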

        weight_params = QuantOps.WeightParams(
            prec=hparams.weight_prec,
            axis=weight_quant_axis,
            expected_scale_shape=expected_scale_shape)

        # TODO(wanglisa): add option to control when scale is being recomputed

        # matmul
        y = quantization.quantized_dot(
            act=inputs,
            w=kernel,
            quant_type=hparams.quant_type,
            weight_params=weight_params,
            act_hparams=hparams.quant_act,
            get_bounds_params=get_bounds_params,
            dot_precision=jax_precision,
            prefer_int8_to_int32_dot=(
                self.quant_context.prefer_int8_to_int32_dot))

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            # (batch_size, features)
            y = y + bias[jnp.newaxis, :]
        shape_utils.assert_shapes_equal(y.shape, (batch_size, self.features))
        return y