def test_per_feature_dim_scale_invariance_weight_quantization(self, prec):
        # Scaling each channel of the weights by a different power of 2 should
        # scale the corresponding channel of the output by the same factor.
        weights = random.uniform(random.PRNGKey(0), (3, 4))
        weight_scale = 2**jnp.arange(4)[jnp.newaxis, :]
        scaled_weights = weights * weight_scale

        weights = quantization.QuantOps.create_weights_fake_quant(
            w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=0))

        scaled_weights = quantization.QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=0))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
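
# Illustrative sketch (not the aqt library implementation) of the property the
# test above checks: symmetric per-channel fake quantization derives each
# channel's scale from that channel's max, so scaling a channel by a power of 2
# rescales its quantization grid by exactly the same factor. The helper name
# below is hypothetical.
import numpy as onp

def fake_quant_per_channel(w, prec):
    # One scale per column: (2**(prec - 1) - 1) / max|w| along axis 0.
    max_abs = onp.max(onp.abs(w), axis=0, keepdims=True)
    scale = (2 ** (prec - 1) - 1) / max_abs
    return onp.round(w * scale) / scale

w = onp.random.RandomState(0).uniform(size=(3, 4))
channel_scale = 2.0 ** onp.arange(4)[onp.newaxis, :]
onp.testing.assert_allclose(fake_quant_per_channel(w, prec=8) * channel_scale,
                            fake_quant_per_channel(w * channel_scale, prec=8))
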
    def test_scale_invariance_weight_quantization(self, prec):
        # Scaling the weights by a power of 2 should scale the output by the same factor.
        weights = random.uniform(random.PRNGKey(0), (10, 1))
        weight_scale = 16
        scaled_weights = weights * weight_scale

        weights = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        scaled_weights = QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
    def test_quantized_dot_aqt(self, act_bounds, weight_prec, weight_axis):
        # With a high enough precision, we expect results from fakequant and AQT to
        # be very similar.
        weight_params = QuantOps.WeightParams(prec=weight_prec,
                                              axis=weight_axis)

        if act_bounds is None:
            act_params = None
        else:
            act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                             bounds=jnp.array(act_bounds),
                                             prec=16)

        def quantized_matmul(quant_type):
            return quantization.quantized_dot(w=self.rhs,
                                              act=self.lhs,
                                              weight_params=weight_params,
                                              act_hparams=act_params,
                                              get_bounds_params=None,
                                              quant_type=quant_type,
                                              prefer_int8_to_int32_dot=True)

        aqt_result = quantized_matmul(QuantType.aqt)
        fakequant_result = quantized_matmul(QuantType.fake_quant)
        onp.testing.assert_allclose(
            aqt_result,
            fakequant_result,
            rtol=1e-2,
            err_msg='AQT and fakequant significantly disagree')
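
# Hedged sketch of the relationship the comparison above relies on: with
# per-tensor scales, "fake quant" (quantize-dequantize in float, then matmul)
# and an AQT-style integer matmul with the scales factored out compute the same
# product up to rounding error, so at high enough precision the two paths agree
# closely. The helper below is illustrative only, not the aqt library's API.
import numpy as onp

def per_tensor_scale(x, prec):
    return (2 ** (prec - 1) - 1) / onp.max(onp.abs(x))

rng = onp.random.RandomState(0)
act, w = rng.randn(2, 3), rng.randn(3, 4)
s_act, s_w = per_tensor_scale(act, prec=16), per_tensor_scale(w, prec=8)

fake_quant = onp.dot(onp.round(act * s_act) / s_act, onp.round(w * s_w) / s_w)
aqt_style = onp.dot(onp.round(act * s_act), onp.round(w * s_w)) / (s_act * s_w)
onp.testing.assert_allclose(fake_quant, aqt_style, rtol=1e-6, atol=1e-6)
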
 def test_float_weights_quantization(self, prec):
     # Tests that quantized and rescaled float weights are close to original
     # weights.
     weights = jnp.array(
         fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
     rescaled_weights = QuantOps.create_weights_fake_quant(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=prec, axis=None))
     test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)
    def test_weight_scale_shape_is_expected(self, axis):
        # Tests that the weight-quantization scale has the expected shape.

        num_features = 4
        expected_scale_shape = (1, 1) if axis is None else (1, num_features)

        # Weight Quantization
        weights = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, num_features))))
        _ = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(
                prec=8.0, axis=axis,
                expected_scale_shape=expected_scale_shape))
 def test_quantized_dot_no_quant(self):
     act_hparams = QuantOps.ActHParams(input_distribution='symmetric',
                                       bounds=-1.0,
                                       prec=4)
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ))
     act = jnp.array([[-5.0]])
     w = jnp.array([[-4.99]])
     res = quantization.quantized_dot(w=w,
                                      act=act,
                                      quant_type=quantization.QuantType.aqt,
                                      weight_params=weight_params,
                                      act_hparams=act_hparams,
                                      get_bounds_params=None,
                                      prefer_int8_to_int32_dot=True)
     onp.testing.assert_allclose(res, act * w)
 def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ))
     act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=4)
     act = self.lhs.astype(jnp.bfloat16)
     w = self.rhs.astype(jnp.float32)
     with self.assertRaises(TypeError):
         quantization.quantized_dot(w=w,
                                    act=act,
                                    weight_params=weight_params,
                                    act_hparams=act_params,
                                    get_bounds_params=None,
                                    quant_type=quant_type,
                                    prefer_int8_to_int32_dot=True)
 def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                          quant_type):
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ))
     act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=act_prec)
     act = self.lhs.astype(input_dtype)
     w = self.rhs.astype(input_dtype)
     output = quantization.quantized_dot(w=w,
                                         act=act,
                                         weight_params=weight_params,
                                         act_hparams=act_params,
                                         get_bounds_params=None,
                                         quant_type=quant_type,
                                         prefer_int8_to_int32_dot=True)
     self.assertEqual(output.dtype, input_dtype)
 def test_attributes_create_weights_ops(self, weight_range, weight_shape,
                                        prec):
     weights = jnp.array(
         fp32(
             onp.random.uniform(weight_range[0],
                                weight_range[1],
                                size=weight_shape)))
     axis = 0 if weight_shape[1] != 1 else None
     weights_quant = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=prec, axis=axis))
     max_weight = onp.max(abs(weights), axis=0)
     onp.testing.assert_array_equal(jnp.squeeze(weights_quant._scale),
                                    (2**(prec - 1) - 1) / max_weight)
     self.assertEqual(weights_quant._symmetric, True)
     self.assertEqual(weights_quant._prec, prec)
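
# Worked instance of the scale formula verified above (illustrative numbers):
# with prec=8 the symmetric integer range is [-127, 127], so a column whose
# largest absolute weight is 0.5 gets scale = 127 / 0.5 = 254.0, mapping that
# weight to the integer endpoint 127.
prec, max_abs_weight = 8, 0.5
scale = (2 ** (prec - 1) - 1) / max_abs_weight
assert scale == 254.0 and round(max_abs_weight * scale) == 127
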
Example #10
 def test_full_range_int_weight_quantization(self, prec):
    # Integer weights spanning the full signed range
    # [-(2**(prec - 1) - 1), 2**(prec - 1) - 1] quantize correctly.
   minval = -2**(prec - 1) + 1
   maxval = 2**(prec - 1) - 1
   weights = random.randint(random.PRNGKey(0), (10, 1), minval, maxval + 1)
   weights = weights.at[0, :].set(maxval)
   weight_quant = QuantOps.create_weights_ops(
       w=weights,
       weight_params=QuantOps.WeightParams(
           prec=prec, axis=None, half_shift=False))
   quantized_weights = weight_quant.to_quantized(weights, dtype=SCALE_DTYPE)
   onp.testing.assert_array_equal(quantized_weights[0],
                                  (2**(prec - 1.0) - 1.0))
   rescaled_weights = weight_quant.from_quantized(
       quantized_weights, dtype=jnp.float32)
   onp.testing.assert_array_equal(weights, rescaled_weights)
 def test_lax_dot_has_integer_inputs_in_quantized_dot(
         self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
         prec):
     weight_params = QuantOps.WeightParams(prec=prec,
                                           axis=(0, ),
                                           half_shift=False)
     act_params = QuantOps.ActHParams(input_distribution=act_distribution,
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=prec,
                                      half_shift=False)
     act = self.lhs
     if act_distribution == 'positive':
         act = jnp.abs(act)
     # We need this context manager to stop Jax from trying to compile the arms
     # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
     # try to compile the functions passed to `lax.cond`, even if outside of a
     # JITed context. JIT compilation is incompatible with using a mock for the
     # call to 'dot_general' because during compilation Jax will expect
     # 'dot_general' to return a tracer and will throw an error if it returns a
      # mock instead. By explicitly using jax.disable_jit, Jax will not try to
      # compile the arms of lax.cond, and so using a mock will work fine.
     with jax.disable_jit():
         quantization.quantized_dot(
             w=self.rhs,
             act=act,
             weight_params=weight_params,
             act_hparams=act_params,
             get_bounds_params=None,
             quant_type=QuantType.aqt,
             prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
     act_inputs, weight_inputs = mock_dot_general.call_args[0]
     self.assert_is_integer_in_range(act_inputs,
                                     prec=prec,
                                     distribution=act_distribution)
     self.assert_is_integer_in_range(weight_inputs,
                                     prec=prec,
                                     distribution='symmetric')
     if prefer_int8_to_int32_dot and not (act_distribution == 'positive'
                                          and prec == 8):
         expected_input_dtype = jnp.int8
     else:
         expected_input_dtype = jnp.float32
     self.assertEqual(act_inputs.dtype, expected_input_dtype)
     self.assertEqual(weight_inputs.dtype, expected_input_dtype)
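
# A plausible reading of the dtype expectation above (inferred from the test's
# own condition, not from the library internals): a 'positive' input
# distribution quantizes to an unsigned range [0, 2**prec - 1], and at prec=8
# the top value 255 does not fit in int8, so the int8 fast path is skipped and
# the float32 inputs are used instead.
import numpy as onp
assert (2 ** 8 - 1) > onp.iinfo(onp.int8).max  # 255 exceeds int8's maximum of 127
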
Example #12
 def setup(self):
     self.embedding = self.param(
         'embedding',
         self.embedding_init,  # pylint: disable=missing-from-attributes
         (self.num_embeddings, self.features))
     hparams = self.hparams
     if hparams.quant_act is not None and isinstance(
             hparams.quant_act.bounds, get_bounds.GetBounds.Hyper):
         self.get_bounds_logits = get_bounds.GetBounds(  # pylint: disable=missing-from-attributes
             hyper=self.hparams.quant_act.bounds)
     self.quantized_dot = quantization.QuantizedDot(  # pylint: disable=missing-from-attributes
         act_hparams=hparams.quant_act,
         quant_type=hparams.quant_type,
         dot_precision=None,
         prefer_int8_to_int32_dot=self.quant_context.
         prefer_int8_to_int32_dot,
         weight_params=QuantOps.WeightParams(
             prec=hparams.weight_prec,
             axis=(0, ),
             expected_scale_shape=(1, self.embedding.shape[0])))
Example #13
  def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
      self,
      mock_act_fq,
      mock_w_fq,
      weight_prec,
      act_prec,
      strategy=QuantType.fake_quant):
    mock_w_fq.side_effect = lambda inputs, **_: inputs
    mock_act_fq.side_effect = lambda inputs, **_: inputs

    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=None, half_shift=False)
    act_hparams = QuantOps.ActHParams(  # pylint: disable=g-long-ternary
        bounds=6.,
        prec=act_prec,
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        half_shift=False) if act_prec else None
    get_bounds_params = GetBounds.Params(
        update_stats=False, update_bounds=False)

    quantization.quantized_dot(
        w=self.weight,
        act=self.act,
        quant_type=strategy,
        weight_params=weight_params,
        act_hparams=act_hparams,
        get_bounds_params=get_bounds_params,
        prefer_int8_to_int32_dot=True)

    quantized_type = strategy.to_jax_type()

    mock_w_fq.assert_called_with(
        mock.ANY,
        weight_params=weight_params,
        quantized_type=quantized_type,
        fake_dependency=mock.ANY)
    if act_hparams:
      mock_act_fq.assert_called_with(
          mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
    else:
      mock_act_fq.assert_not_called()
Example #14
  def test_quantized_dot_general_aqt(self, act_bounds, weight_prec,
                                     weight_axis):
    # With a high enough precision, we expect results from fakequant and AQT to
    # be very similar.
    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=weight_axis, half_shift=False)

    if act_bounds is None:
      act_params = None
    else:
      act_params = QuantOps.ActHParams(
          input_distribution='symmetric',
          bounds=jnp.array(act_bounds),
          prec=16,
          half_shift=False)

    lhs_ndims_3 = jnp.array(
        fp32(2.0 * onp.random.uniform(0, 1.0, size=(4, 3, 2))))

    def quantized_matmul(quant_type):
      return quantization.quantized_dot_general(
          w=self.rhs,
          act=lhs_ndims_3,
          weight_params=weight_params,
          act_hparams=act_params,
          get_bounds_params=None,
          quant_type=quant_type,
          dimension_numbers=(((lhs_ndims_3.ndim - 1,), (0,)), ((), ())),
          prefer_int8_to_int32_dot=True)

    aqt_result = quantized_matmul(QuantType.aqt)
    self.assertEqual(aqt_result.shape, (4, 3, 4))

    fakequant_result = quantized_matmul(QuantType.fake_quant)
    onp.testing.assert_allclose(
        aqt_result,
        fakequant_result,
        rtol=1e-2,
        err_msg='AQT and fakequant significantly disagree')
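
# Minimal sketch of the dimension_numbers used above: contract the last axis of
# the 3-D activation with axis 0 of the weight matrix, with no batch dimensions.
# This is the einsum contraction 'abc,cd->abd'; the weight shape (2, 4) is
# assumed here from the asserted (4, 3, 4) output.
import jax.numpy as jnp
from jax import lax

act = jnp.ones((4, 3, 2))
w = jnp.ones((2, 4))
out = lax.dot_general(act, w,
                      dimension_numbers=(((act.ndim - 1,), (0,)), ((), ())))
assert out.shape == (4, 3, 4)
assert jnp.allclose(out, jnp.einsum('abc,cd->abd', act, w))
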
 def test_attributes_create_weights_op_fp(
     self,
     weight_range,
     weight_shape,
     fp_quant,
 ):
     weights = jnp.array(
         fp32(onp.random.uniform(*weight_range, size=weight_shape)))
     axis = None if weight_shape[1] == 1 else 0
     weights_quant_op = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=fp_quant,
                                             axis=axis,
                                             half_shift=False))
     max_weight = onp.max(abs(weights), axis=0)
     onp.testing.assert_array_equal(
         jnp.squeeze(weights_quant_op._scale),
         jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
     self.assertEqual(weights_quant_op._symmetric, True)
     self.assertIs(weights_quant_op._prec, fp_quant)
     weights_scaled = (weights * weights_quant_op._scale).astype(
         weights.dtype)
     weights_quant_expected = fp_cast.downcast_sat_ftz(
         weights_scaled,
         fp_quant.fp_spec.exp_min,
         fp_quant.fp_spec.exp_max,
         fp_quant.fp_spec.sig_bits,
     )
     weights_quant_calculated = weights_quant_op.to_quantized(
         weights, dtype=SCALE_DTYPE)
     onp.testing.assert_array_equal(weights_quant_expected,
                                    weights_quant_calculated)
      # Test that the lower (23 - fp_quant.fp_spec.sig_bits) bits of the
      # calculated quantized weights are zero.
     sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
     onp.testing.assert_array_equal(
         weights_quant_calculated.view(jnp.int32) & sig_mask,
         jnp.zeros_like(weights))
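
# Sketch of why the mantissa check above holds (illustrative numbers): the scale
# exp2(-floor(log2(max|w|))) is a power of two, so multiplying by it only shifts
# the float32 exponent; a downcast value that keeps sig_bits significand bits
# therefore has its low (23 - sig_bits) stored mantissa bits equal to zero.
import numpy as onp

sig_bits = 2
mask = onp.int32((1 << (23 - sig_bits)) - 1)
assert (onp.array(1.25, onp.float32).view(onp.int32) & mask) == 0  # 1.25 = 1.01b
assert (onp.array(1.3, onp.float32).view(onp.int32) & mask) != 0   # needs more bits
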
Example #16
    def __call__(self, inputs):
        """Applies a convolution to the inputs with optional quantization.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
      The convolved data.
    """
        hparams = self.hparams
        if hparams.weight_prec is not None and hparams.weight_prec > 8:
            raise NotImplementedError(
                'If you want to use more than 8bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )
        jax_precision = jax.lax.Precision.DEFAULT

        if self.strides is None:
            strides = (1, ) * (inputs.ndim - 2)
        else:
            strides = self.strides

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count, self.features)
        kernel = self.param('kernel', self.kernel_init, kernel_shape)

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        # Activation quantization
        if hparams.quant_act is not None:
            inputs = QuantOps.create_inputs_fake_quant(
                inputs=inputs,
                hparams=hparams.quant_act,
                get_bounds_params=get_bounds.GetBounds.Params(
                    update_bounds=self.quant_context.update_bounds,
                    update_stats=self.train,
                    paxis_name=self.paxis_name))

        # Weight quantization
        if hparams.weight_prec is not None:
            kernel_reduction_axis = tuple(range(kernel.ndim - 1))
            expected_scale_shape = (1, ) * (kernel.ndim - 1) + (
                self.features, )
            assert hparams.quant_type == QuantType.fake_quant, (
                'we only support fake_quant style of aqt for ConvAqt.')
            quantized_type = hparams.quant_type.to_jax_type()
            kernel = QuantOps.create_weights_fake_quant(
                kernel,
                weight_params=QuantOps.WeightParams(
                    prec=hparams.weight_prec,
                    half_shift=hparams.weight_half_shift,
                    axis=kernel_reduction_axis,
                    expected_scale_shape=expected_scale_shape),
                quantized_type=quantized_type)

        # Convolution
        dimension_numbers = flax.nn.linear._conv_dimension_numbers(
            inputs.shape)  # pylint: disable=protected-access
        metadata_context = contextlib.suppress()
        # Use metadata context to annotate op metadata with quantization info
        act_prec = None if hparams.quant_act is None else hparams.quant_act.prec

        if flags.FLAGS.metadata_enabled:
            metadata_context = compute_cost_utils.ConvMetadataMonkeyPatch(
                weight_prec=hparams.weight_prec, act_prec=act_prec)
        with metadata_context:
            y = lax.conv_general_dilated(
                inputs,
                kernel,
                strides,
                self.padding,
                lhs_dilation=self.input_dilation,
                rhs_dilation=self.kernel_dilation,
                dimension_numbers=dimension_numbers,
                feature_group_count=self.feature_group_count,
                precision=jax_precision)
        # TODO(shivaniagrawal): create quantized conv general dilated.

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            bias = jnp.asarray(bias, self.dtype)
            # The inputs can have an arbitrary number of spatial dims, so we broadcast
            # the bias to match: (batch_size, spatial_dim,... features)
            # TODO(shivaniagrawal): Consider making ConvAqt rank static (e.g. 2D)
            # or maybe add error checking (e.g. expect inputs to have rank N, but this
            # may already be checked by lax.conv_general_dilated).
            bias = utils.broadcast_rank(bias, inputs)
            y = y + bias
        return y
Example #17
    def __call__(
        self,
        inputs,
    ):
        """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.

    Returns:
      Output which is embedded input data.  The output shape follows the input,
      with an additional `features` dimension appended.
    """
        batch_size, sequence_length = inputs.shape
        if inputs.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]:
            raise ValueError(
                'Input type must be an integer or unsigned integer.')
        embedding = self.embedding

        embedding = jnp.asarray(embedding, self.dtype)

        hparams = self.hparams
        # Initialize state for stats and bounds; this is required for the logits
        # in the 'attend' method that follows.
        if hparams.quant_act is not None and isinstance(
                hparams.quant_act.bounds, get_bounds.GetBounds.Hyper):
            self.get_bounds_logits(
                inputs,
                bounds_params=get_bounds.GetBounds.Params(update_stats=False,
                                                          update_bounds=False,
                                                          paxis_name=None),
            )

        weight_prec = hparams.weight_prec
        weight_half_shift = hparams.weight_half_shift
        if weight_prec is not None:
            quantized_type = hparams.quant_type.to_jax_type()
            # In contrast to all other scale factor calculations in this module, we
            # compute per-row instead of per-column (i.e., per-output-channel) scale
            # factors here. This is because the embedding matrix might be shared with
            # the output (logit) layer of the transformer, in which case the
            # *transpose* of the embedding matrix will be used as the weight matrix in
            # a matmul. The per-row scale factors used here would thus correspond to
            # using per-column (because of the transpose) scale factors used by the
            # weight matrix in the logits layer, which is what we need for AQT.
            embedding_quant_ops = QuantOps.create_weights_ops(
                embedding,
                weight_params=QuantOps.WeightParams(
                    prec=weight_prec, axis=(1, ),
                    half_shift=weight_half_shift))
            embedding_quant_ops.assert_scale_shape_is(
                shape=(self.num_embeddings, 1))

            quantized_embedding = embedding_quant_ops.to_quantized(
                embedding, dtype=quantized_type)
            quantized_embedded_inputs = quantized_embedding[inputs]
            # Since the embedding matrix 'quantized_embedding' is gathered based on
            # 'inputs' to produce the embedding tensor, we apply the same gathering to
            # the per-row scale factors of the embedding matrix so the scale factors
            # will broadcast appropriately in the subsequent call to 'to_quantized'.
            # TODO(malmaud): As part of quantization.py refactor, change
            # 'get_scale_for_aqt' to cleanly support this and hence avoid the need to
            # directly access a protected member of QuantOps.
            scale = embedding_quant_ops._scale[inputs]  # pylint: disable=protected-access
            shape_utils.assert_shapes_equal(scale.shape,
                                            (batch_size, sequence_length, 1))
            shape_utils.assert_shapes_equal(
                quantized_embedded_inputs.shape,
                (batch_size, sequence_length, self.features))
            embedded_inputs = (quantized_embedded_inputs / scale).astype(
                self.dtype)
        else:
            embedded_inputs = embedding[inputs]
        shape_utils.assert_shapes_equal(
            embedded_inputs.shape,
            (batch_size, sequence_length, self.features))
        return embedded_inputs
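
# Minimal sketch (shapes only, illustrative) of the scale gathering used above:
# the per-row scales have shape (num_embeddings, 1); indexing them with the
# token ids yields (batch, seq_len, 1), which broadcasts against the gathered
# (batch, seq_len, features) quantized embeddings during dequantization.
import numpy as onp

num_embeddings = 6
per_row_scale = onp.arange(1.0, num_embeddings + 1)[:, onp.newaxis]  # (6, 1)
token_ids = onp.array([[0, 2, 5], [5, 1, 4]])                        # (batch=2, seq=3)
gathered_scale = per_row_scale[token_ids]                            # (2, 3, 1)
assert gathered_scale.shape == (2, 3, 1)
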
Example #18
    def __call__(
        self,
        inputs,
        *,
        padding_mask,
    ):
        """Applies a linear transformation to the inputs with optional quantization.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Args:
      inputs: The nd-array to be transformed.
      padding_mask: boolean tensor of the same shape as 'inputs' specifying
        which values of 'inputs' to use as part of the bounds calculation.
        'True' indicates the corresponding value from 'inputs' should be used.
        If None, all values are used.

    Returns:
      The transformed input.
    """
        batch_size = inputs.shape[0]
        if padding_mask is not None:
            shape_utils.assert_shapes_equal(padding_mask.shape,
                                            (batch_size, 1))
        # TODO(wanglisa): Replace fake quant with AQT.

        if self.quant_context.collect_acts_stats:
            stats_tag.StatsTag(channel_axis=-1,
                               name='inputs',
                               update_stats=self.train)(inputs,
                                                        mask=padding_mask)
        hparams = self.hparams
        if (hparams.weight_prec is not None
                and isinstance(hparams.weight_prec, int)
                and hparams.weight_prec > 8):
            raise NotImplementedError(
                'If you want to use more than 8bits for quantization, please revisit '
                'jax.lax.Precision.DEFAULT to determine whether it is still sufficient.'
            )

        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))

        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)

        get_bounds_params = get_bounds.GetBounds.Params(
            update_bounds=self.quant_context.update_bounds,
            update_stats=self.train,
            paxis_name=self.paxis_name,
            mask=padding_mask)

        weight_quant_granularity = hparams.weight_quant_granularity
        # kernel.shape = (channels_in, channels_out)
        if weight_quant_granularity == quant_config.QuantGranularity.per_channel:
            # Compute scale factors by reducing over the rows of the weight matrix,
            # resulting in one scale factor per column. This results in one scale
            # factor per output channel.
            expected_scale_shape = (1, self.features)
            weight_quant_axis = (0, )
        elif weight_quant_granularity == quant_config.QuantGranularity.per_tensor:
            # Compute a single scale factor for the entire weight matrix.
            expected_scale_shape = (1, 1)
            weight_quant_axis = None
        else:
            raise ValueError(
                f'Invalid quantization granularity {weight_quant_granularity}.'
            )

        weight_params = QuantOps.WeightParams(
            prec=hparams.weight_prec,
            half_shift=hparams.weight_half_shift,
            axis=weight_quant_axis,
            expected_scale_shape=expected_scale_shape)

        # TODO(wanglisa): add option to control when scale is being recomputed

        # matmul
        contracting_dims = ((inputs.ndim - 1, ), (0, ))
        # `((lhs_contracting_dims, rhs_contracting_dims),
        batch_dims = ((), ())  # (lhs_batch_dims, rhs_batch_dims))`
        y = quantization.quantized_dot_general(
            act=inputs,
            w=kernel,
            quant_type=hparams.quant_type,
            weight_params=weight_params,
            act_hparams=hparams.quant_act,
            get_bounds_params=get_bounds_params,
            dimension_numbers=(contracting_dims, batch_dims),
            dot_precision=self.precision,
            prefer_int8_to_int32_dot=self.quant_context.
            prefer_int8_to_int32_dot)

        # bias
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features, ))
            # (batch_size, features)
            y = y + bias[jnp.newaxis, :]
        return y
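
# Hedged sketch of the two granularities branched on above (shapes only, not
# library code): per-channel reduces over the rows of the
# (channels_in, channels_out) kernel, yielding one scale per output channel,
# while per-tensor reduces over the entire kernel.
import numpy as onp

kernel = onp.random.RandomState(0).uniform(size=(16, 4))
per_channel = onp.max(onp.abs(kernel), axis=0, keepdims=True)    # shape (1, features)
per_tensor = onp.max(onp.abs(kernel), axis=None, keepdims=True)  # shape (1, 1)
assert per_channel.shape == (1, 4) and per_tensor.shape == (1, 1)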