def test_quantized_dynamic_dot_general(self, lhs_prec, rhs_prec):
        lhs_bounds = 2.0
        rhs_bounds = 1.5
        lhs_params = QuantOps.ActHParams(input_distribution='symmetric',
                                         bounds=lhs_bounds,
                                         prec=lhs_prec)
        rhs_params = QuantOps.ActHParams(input_distribution='symmetric',
                                         bounds=rhs_bounds,
                                         prec=rhs_prec)

        def quantized_matmul(quant_type):
            return quantization.quantized_dynamic_dot_general(
                lhs_act=self.lhs,
                rhs_act=self.rhs,
                lhs_act_hparams=lhs_params,
                rhs_act_hparams=rhs_params,
                lhs_get_bounds_params=None,
                rhs_get_bounds_params=None,
                dot_dimension_numbers=(((1, ), (0, )), ((), ())),
                quant_type=quant_type)

        aqt_result = quantized_matmul(QuantType.aqt)
        fakequant_result = quantized_matmul(QuantType.fake_quant)
        onp.testing.assert_allclose(
            aqt_result,
            fakequant_result,
            rtol=1e-2,
            err_msg='AQT and fakequant significantly disagree')
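The lhs_prec and rhs_prec arguments above are supplied by a test parameterization that this listing omits. A minimal sketch of one way to wire them up with absl.testing.parameterized; the class name and precision pairs are illustrative, not taken from the original test:

from absl.testing import parameterized


class QuantizedDynamicDotGeneralTest(parameterized.TestCase):
    # Illustrative precision pairs; the real test may sweep other values.
    @parameterized.parameters((8, 8), (8, 4), (4, 4))
    def test_quantized_dynamic_dot_general(self, lhs_prec, rhs_prec):
        ...  # body as listed above (self.lhs / self.rhs set up elsewhere)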
Example #2
 def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
     get_bounds_hyper = get_bounds.GetBounds.Hyper(
         initial_bound=10.0,
         stddev_coeff=0,
         absdev_coeff=0,
         mix_coeff=0,
         granularity=quant_config.QuantGranularity.per_tensor)
     lhs_act_hparams = QuantOps.ActHParams(
         input_distribution='symmetric',
         bounds=get_bounds_hyper,
         prec=lhs_prec,
         half_shift=False)
     rhs_act_hparams = QuantOps.ActHParams(
         input_distribution='symmetric',
         bounds=get_bounds_hyper,
         prec=rhs_prec,
         half_shift=False)
     lhs_get_bounds_params = get_bounds.GetBounds.Params(
         update_stats=False, update_bounds=False, module_name='lhs')
     rhs_get_bounds_params = get_bounds.GetBounds.Params(
         update_stats=False, update_bounds=False, module_name='rhs')
     output = quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_act_hparams,
         rhs_act_hparams=rhs_act_hparams,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt,
         lhs_get_bounds_params=lhs_get_bounds_params,
         rhs_get_bounds_params=rhs_get_bounds_params)
     return output
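The __call__ above presumably belongs to a flax nn.Module, since GetBounds tracks bounds statistics in module state. A hypothetical usage sketch; DynamicDotGeneralModule and the input shapes are assumptions made for illustration:

import jax
import jax.numpy as jnp

# DynamicDotGeneralModule is an assumed nn.Module wrapping the __call__ above.
module = DynamicDotGeneralModule()
lhs = jnp.ones((2, 3), jnp.float32)
rhs = jnp.ones((3, 4), jnp.float32)
variables = module.init(
    jax.random.PRNGKey(0), lhs, rhs, lhs_prec=8, rhs_prec=8)
output = module.apply(variables, lhs, rhs, lhs_prec=8, rhs_prec=8)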
    def test_quantized_dot_aqt(self, act_bounds, weight_prec, weight_axis):
        # With a high enough precision, we expect results from fakequant and AQT to
        # be very similar.
        weight_params = QuantOps.WeightParams(prec=weight_prec,
                                              axis=weight_axis)

        if act_bounds is None:
            act_params = None
        else:
            act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                             bounds=jnp.array(act_bounds),
                                             prec=16)

        def quantized_matmul(quant_type):
            return quantization.quantized_dot(w=self.rhs,
                                              act=self.lhs,
                                              weight_params=weight_params,
                                              act_hparams=act_params,
                                              get_bounds_params=None,
                                              quant_type=quant_type,
                                              prefer_int8_to_int32_dot=True)

        aqt_result = quantized_matmul(QuantType.aqt)
        fakequant_result = quantized_matmul(QuantType.fake_quant)
        onp.testing.assert_allclose(
            aqt_result,
            fakequant_result,
            rtol=1e-2,
            err_msg='AQT and fakequant significantly disagree')
 def test_lax_dot_has_integer_inputs_in_dynamic_dot_general(
         self, mock_dot_general, lhs_distribution, rhs_distribution):
     lhs_params = QuantOps.ActHParams(input_distribution=lhs_distribution,
                                      bounds=2.0,
                                      prec=4)
     rhs_params = QuantOps.ActHParams(input_distribution=rhs_distribution,
                                      bounds=1.5,
                                      prec=4)
     lhs_act = self.lhs
     if lhs_distribution == 'positive':
         lhs_act = jnp.abs(lhs_act)
     rhs_act = self.rhs
     if rhs_distribution == 'positive':
         rhs_act = jnp.abs(rhs_act)
     quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt)
     lhs_inputs, rhs_inputs = mock_dot_general.call_args[0]
     self.assert_is_integer_in_range(lhs_inputs,
                                     prec=4,
                                     distribution=lhs_distribution)
     self.assert_is_integer_in_range(rhs_inputs,
                                     prec=4,
                                     distribution=rhs_distribution)
    def test_scale_invariance_signed_activation_quantization(self, prec):
        # Scaling the activations by a power of 2, and the bounds by the same
        # factor, should scale the output by that same factor.
        activations = random.uniform(random.PRNGKey(0), (10, 1))
        act_scale = 8.
        scaled_activations = activations * act_scale

        bounds = 6.

        activations = QuantOps.create_inputs_fake_quant(
            inputs=activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(input_distribution=QuantOps.ActHParams.
                                        InputDistribution.symmetric,
                                        bounds=bounds,
                                        prec=prec))

        scaled_activations = QuantOps.create_inputs_fake_quant(
            inputs=scaled_activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(input_distribution=QuantOps.ActHParams.
                                        InputDistribution.symmetric,
                                        bounds=bounds * act_scale,
                                        prec=prec))
        onp.testing.assert_array_equal(activations * act_scale,
                                       scaled_activations)
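A quick numeric illustration of why this invariance holds, assuming the symmetric scale formula (2**(prec - 1) - 1) / bounds that other examples in this listing assert:

import numpy as np

x, bounds, act_scale, prec = 0.7, 6.0, 8.0, 4
scale = (2**(prec - 1) - 1) / bounds
fake_quant = np.round(x * scale) / scale  # fake-quant of x with bounds
scaled_scale = (2**(prec - 1) - 1) / (bounds * act_scale)
fake_quant_scaled = np.round(x * act_scale * scaled_scale) / scaled_scale
# Scaling the input and the bounds by act_scale scales the output by act_scale.
assert np.isclose(fake_quant * act_scale, fake_quant_scaled)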
Example #6
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly apply a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)
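For comparison, the unquantized computation that quantized_softmax mirrors step by step, written with plain jax.numpy (a sketch; norm_dims has the same meaning as in the surrounding scope):

import jax.numpy as jnp


def reference_softmax(a, norm_dims):
    # Numerically stable softmax: exp(x - max(x)) / sum_i exp(x_i - max(x)).
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_exp = jnp.exp(a - a_max)
    return a_exp / jnp.sum(a_exp, axis=norm_dims, keepdims=True)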
Example #7
        def quantized_layernorm(x):
            prec = hparams.quant_hparams.prec
            fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
            quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant,
                                                     bounds=None)

            def to_quantized(x):
                return quant_ops.to_quantized(x, dtype=dtype)

            # If epsilon is too small to represent in the quantized format, we set it
            # to the minimal representable non-zero value to avoid the possibility of
            # dividing by zero.
            fp_bounds = quantization.fp_cast.get_bounds(
                prec.exp_min, prec.exp_max, prec.sig_bits)
            epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
            quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

            # If the reciprocal of the quantized number of features is too small to
            # represent in the quantized format, we set it to the minimal
            # representable nonzero value so that the mean and variance are not
            # trivially 0.
            num_features_quantized = to_quantized(
                jnp.array(num_features, dtype=dtype))
            num_features_recip_quantized = to_quantized(
                jnp.reciprocal(num_features_quantized))
            num_features_recip_quantized = jax.lax.cond(
                jax.lax.eq(num_features_recip_quantized, 0.0),
                lambda _: quantized_epsilon,
                lambda _: num_features_recip_quantized, None)

            x_quantized = to_quantized(x)
            x_sum_quantized_reduction = quantization.quantized_sum(
                x_quantized,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sum = to_quantized(x_sum_quantized_reduction)
            mean = to_quantized(x_sum * num_features_recip_quantized)
            x_minus_mean = to_quantized(x - mean)
            x_sq = to_quantized(lax.square(x_minus_mean))
            x_sq_sum_quantized_reduction = quantization.quantized_sum(
                x_sq,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
            var = to_quantized(x_sq_sum * num_features_recip_quantized)
            # Prevent division by zero.
            var_plus_epsilon = to_quantized(var + quantized_epsilon)
            mul = to_quantized(lax.rsqrt(var_plus_epsilon))
            if self.use_scale:
                quantized_scale_param = to_quantized(scale_param)
                mul = to_quantized(mul * quantized_scale_param)
            y = to_quantized(x_minus_mean * mul)
            if self.use_bias:
                quantized_bias_param = to_quantized(bias_param)
                y = to_quantized(y + quantized_bias_param)
            return y.astype(self.dtype)
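For reference, the unquantized layer norm that quantized_layernorm follows line by line; scale_param, bias_param and epsilon play the same roles as in the surrounding module (a sketch, not the library's implementation):

import jax.numpy as jnp
from jax import lax


def reference_layernorm(x, scale_param, bias_param, epsilon):
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(lax.square(x - mean), axis=-1, keepdims=True)
    y = (x - mean) * lax.rsqrt(var + epsilon)
    return y * scale_param + bias_param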
 def test_float_weights_quantization(self, prec):
     # Tests that quantized and rescaled float weights are close to original
     # weights.
     weights = jnp.array(
         fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
     rescaled_weights = QuantOps.create_weights_fake_quant(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=prec, axis=None))
     test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)
Example #9
  def test_attributes_create_acts_op_fp(
      self,
      act_distribution,
      use_hparams_bounds,
  ):
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    fp_quant = QuantOps.FloatQuant(
        is_scaled=True,
        fp_spec=QuantOps.FloatQuant.FloatPrec(
            exp_min=-15,
            exp_max=15,
            sig_bits=2,
        ),
    )
    if use_hparams_bounds:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=1,
          absdev_coeff=0,
          mix_coeff=1,
          reset_stats=True,
          ema_coeff=None,
          use_cams=False,
          granularity=quant_config.QuantGranularity.per_tensor)
    else:
      bounds = 6.0

    hparams = QuantOps.ActHParams(
        input_distribution=act_distribution, bounds=bounds, prec=fp_quant,
        half_shift=False)

    class TestModule(nn.Module):
      hparams: QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        return QuantOps.create_input_ops(
            inputs,
            hparams=hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=False,
                update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    act_quant_op = test_module.apply(state, inputs=inputs)

    act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
    act_quant_expected = fp_cast.downcast_sat_ftz(
        act_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)
    def test_weight_scale_shape_is_expected(self, axis):
        # Tests that the scale has the expected shape for weight quantization.

        num_features = 4
        expected_scale_shape = (1, 1) if axis is None else (1, num_features)

        # Weight Quantization
        weights = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, num_features))))
        _ = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(
                prec=8.0, axis=axis,
                expected_scale_shape=expected_scale_shape))
    def test_per_feature_dim_scale_invariance_weight_quantization(self, prec):
        # Scaling each channel of the weights by a different power of 2 should
        # scale the respective output channel by the same factor.
        weights = random.uniform(random.PRNGKey(0), (3, 4))
        weight_scale = 2**jnp.arange(4)[jnp.newaxis, :]
        scaled_weights = weights * weight_scale

        weights = quantization.QuantOps.create_weights_fake_quant(
            w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=0))

        scaled_weights = quantization.QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=0))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
 def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ))
     act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=4)
     act = self.lhs.astype(jnp.bfloat16)
     w = self.rhs.astype(jnp.float32)
     with self.assertRaises(TypeError):
         quantization.quantized_dot(w=w,
                                    act=act,
                                    weight_params=weight_params,
                                    act_hparams=act_params,
                                    get_bounds_params=None,
                                    quant_type=quant_type,
                                    prefer_int8_to_int32_dot=True)
    def test_scale_invariance_weight_quantization(self, prec):
        # Scaling the weights by a power of 2 should scale the output by that factor.
        weights = random.uniform(random.PRNGKey(0), (10, 1))
        weight_scale = 16
        scaled_weights = weights * weight_scale

        weights = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        scaled_weights = QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)
 def test_quantized_dot_no_quant(self):
     act_hparams = QuantOps.ActHParams(input_distribution='symmetric',
                                       bounds=-1.0,
                                       prec=4)
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ))
     act = jnp.array([[-5.0]])
     w = jnp.array([[-4.99]])
     res = quantization.quantized_dot(w=w,
                                      act=act,
                                      quant_type=quantization.QuantType.aqt,
                                      weight_params=weight_params,
                                      act_hparams=act_hparams,
                                      get_bounds_params=None,
                                      prefer_int8_to_int32_dot=True)
     onp.testing.assert_allclose(res, act * w)
    def test_int_positive_act_quantization(self, prec):
        # Integer activations within upper_bound, where upper_bound == 2^i for
        # some i < prec, quantize exactly.
        upper_bound = 2**(prec - 3)
        activations = random.randint(random.PRNGKey(0), (10, 1), 0,
                                     upper_bound)

        rescaled_activations = QuantOps.create_inputs_fake_quant(
            inputs=activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(input_distribution=QuantOps.ActHParams.
                                        InputDistribution.positive,
                                        bounds=upper_bound,
                                        prec=prec))
        onp.testing.assert_array_equal(activations, rescaled_activations)
 def test_attributes_create_weights_ops(self, weight_range, weight_shape,
                                        prec):
     weights = jnp.array(
         fp32(
             onp.random.uniform(weight_range[0],
                                weight_range[1],
                                size=weight_shape)))
     axis = 0 if weight_shape[1] != 1 else None
     weights_quant = QuantOps.create_weights_ops(
         w=weights,
         weight_params=QuantOps.WeightParams(prec=prec, axis=axis))
     max_weight = onp.max(abs(weights), axis=0)
     onp.testing.assert_array_equal(jnp.squeeze(weights_quant._scale),
                                    (2**(prec - 1) - 1) / max_weight)
     self.assertEqual(weights_quant._symmetric, True)
     self.assertEqual(weights_quant._prec, prec)
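A quick numeric check of the scale formula asserted above (the values are illustrative):

# With prec=8 and a column whose largest absolute weight is 0.5, the expected
# per-column scale is (2**(8 - 1) - 1) / 0.5 = 127 / 0.5 = 254.0.
prec, max_weight = 8, 0.5
assert (2**(prec - 1) - 1) / max_weight == 254.0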
Example #17
 def __call__(self, inputs):
   return QuantOps.create_input_ops(
       inputs,
       hparams=hparams,
       get_bounds_params=GetBounds.Params(
           update_stats=False,
           update_bounds=False))
    def test_int_symmetric_act_quantization(self, prec):
        # Integer activations within bounds, where abs(bounds) == 2^(prec - 1) - 1,
        # quantize exactly.
        bounds = 2**(prec - 1) - 1
        activations = random.randint(random.PRNGKey(0), (10, 1), -bounds,
                                     bounds)
        rescaled_activations = QuantOps.create_inputs_fake_quant(
            inputs=activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(input_distribution=QuantOps.ActHParams.
                                        InputDistribution.symmetric,
                                        bounds=bounds,
                                        prec=prec))

        onp.testing.assert_array_equal(activations, rescaled_activations)
 def test_attributes_create_symmetric(self, bounds, prec):
     bounds = jnp.array(bounds)
     act_signed = QuantOps.create_symmetric(bounds=bounds, prec=prec)
     onp.testing.assert_array_equal(act_signed._scale,
                                    (2**(prec - 1) - 1) / bounds)
     self.assertEqual(act_signed._symmetric, True)
     self.assertEqual(act_signed._prec, prec)
 def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                          quant_type):
     weight_params = QuantOps.WeightParams(prec=4, axis=(0, ))
     act_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=act_prec)
     act = self.lhs.astype(input_dtype)
     w = self.rhs.astype(input_dtype)
     output = quantization.quantized_dot(w=w,
                                         act=act,
                                         weight_params=weight_params,
                                         act_hparams=act_params,
                                         get_bounds_params=None,
                                         quant_type=quant_type,
                                         prefer_int8_to_int32_dot=True)
     self.assertEqual(output.dtype, input_dtype)
    def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
            self,
            mock_act_fq,
            lhs_act_prec,
            rhs_act_prec,
            strategy=QuantType.fake_quant):
        mock_act_fq.side_effect = lambda inputs, hparams, get_bounds_params: inputs

        # pylint: disable=g-long-ternary
        lhs_act_hparams = QuantOps.ActHParams(
            bounds=6.,
            prec=lhs_act_prec,
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            half_shift=False) if lhs_act_prec else None
        rhs_act_hparams = QuantOps.ActHParams(
            bounds=6.,
            prec=rhs_act_prec,
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            half_shift=False) if rhs_act_prec else None
        # pylint: enable=g-long-ternary

        get_bounds_params = GetBounds.Params(update_stats=False,
                                             update_bounds=False)

        quantization.quantized_dynamic_dot_general(
            lhs_act=self.lhs_act,
            rhs_act=self.rhs_act,
            quant_type=strategy,
            dot_dimension_numbers=self.dimension_numbers,
            lhs_act_hparams=lhs_act_hparams,
            lhs_get_bounds_params=get_bounds_params,
            rhs_act_hparams=rhs_act_hparams,
            rhs_get_bounds_params=get_bounds_params,
        )
        calls = []
        for prec in [lhs_act_prec, rhs_act_prec]:
            if prec is not None:
                act_hparams = QuantOps.ActHParams(bounds=6.,
                                                  prec=prec,
                                                  input_distribution=mock.ANY,
                                                  half_shift=False)
                calls.append(
                    mock.call(mock.ANY,
                              hparams=act_hparams,
                              get_bounds_params=get_bounds_params))
        self.assertLen(calls, mock_act_fq.call_count)
        mock_act_fq.assert_has_calls(calls, any_order=True)
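The mock_act_fq argument and the two precisions above are injected by decorators that this listing omits. One plausible wiring, assuming mock.patch.object on QuantOps.create_inputs_fake_quant and absl.testing.parameterized; the decorator order and parameter values are illustrative:

from unittest import mock

from absl.testing import parameterized


@parameterized.named_parameters(
    dict(testcase_name='lhs8_rhs8', lhs_act_prec=8, rhs_act_prec=8),
    dict(testcase_name='lhs8_rhs_none', lhs_act_prec=8, rhs_act_prec=None))
@mock.patch.object(QuantOps, 'create_inputs_fake_quant')
def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
        self, mock_act_fq, lhs_act_prec, rhs_act_prec,
        strategy=QuantType.fake_quant):
    ...  # body as listed above

With this arrangement, mock.patch.object passes the created mock as an extra positional argument, so it lands right after self while the parameterized values arrive as keywords.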
Example #22
 def test_full_range_int_weight_quantization(self, prec):
   # Integer weights spanning the full range [-(2^(prec - 1) - 1), 2^(prec - 1) - 1]
   # quantize exactly.
   minval = -2**(prec - 1) + 1
   maxval = 2**(prec - 1) - 1
   weights = random.randint(random.PRNGKey(0), (10, 1), minval, maxval + 1)
   weights = weights.at[0, :].set(maxval)
   weight_quant = QuantOps.create_weights_ops(
       w=weights,
       weight_params=QuantOps.WeightParams(
           prec=prec, axis=None, half_shift=False))
   quantized_weights = weight_quant.to_quantized(weights, dtype=SCALE_DTYPE)
   onp.testing.assert_array_equal(quantized_weights[0],
                                  (2**(prec - 1.0) - 1.0))
   rescaled_weights = weight_quant.from_quantized(
       quantized_weights, dtype=jnp.float32)
   onp.testing.assert_array_equal(weights, rescaled_weights)
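A short note on why this round-trips exactly, using the symmetric scale formula (2**(prec - 1) - 1) / max|w| asserted in the create_weights_ops example above:

# Setting row 0 to maxval guarantees max|w| == 2**(prec - 1) - 1, so the
# symmetric scale is exactly 1.0 and to_quantized / from_quantized reduce to
# rounding values that are already integers.
prec = 4  # illustrative
maxval = 2**(prec - 1) - 1
assert (2**(prec - 1) - 1) / maxval == 1.0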
Example #23
 def test_dynamic_quantized_dot_general_raises_with_mixed_dtype(self):
   lhs_params = QuantOps.ActHParams(
       input_distribution='symmetric', bounds=2.0, prec=4, half_shift=False)
   rhs_params = QuantOps.ActHParams(
       input_distribution='symmetric', bounds=1.5, prec=4, half_shift=False)
   lhs_act = self.lhs.astype(jnp.bfloat16)
   rhs_act = self.rhs.astype(jnp.float32)
   with self.assertRaises(TypeError):
     quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1,), (0,)), ((), ())),
         quant_type=QuantType.aqt)
    def test_inputs_scale_shape_is_expected(self):
        # Inputs quantization
        inputs = jnp.array(fp32(2.0 *
                                onp.random.uniform(0, 1.0, size=(10, 4))))
        bounds = 6.0
        expected_inputs_scale_shape = ()

        _ = QuantOps.create_inputs_fake_quant(
            inputs=inputs,
            hparams=QuantOps.ActHParams(input_distribution=QuantOps.ActHParams.
                                        InputDistribution.symmetric,
                                        bounds=bounds,
                                        prec=8.0),
            get_bounds_params=GetBounds.Params(
                update_stats=False,
                update_bounds=False,
                expected_bounds_shape=expected_inputs_scale_shape))
 def test_lax_dot_has_integer_inputs_in_quantized_dot(
         self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
         prec):
     weight_params = QuantOps.WeightParams(prec=prec,
                                           axis=(0, ),
                                           half_shift=False)
     act_params = QuantOps.ActHParams(input_distribution=act_distribution,
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=prec,
                                      half_shift=False)
     act = self.lhs
     if act_distribution == 'positive':
         act = jnp.abs(act)
     # We need this context manager to stop Jax from trying to compile the arms
     # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
     # try to compile the functions passed to `lax.cond`, even if outside of a
     # JITed context. JIT compilation is incompatible with using a mock for the
     # call to 'dot_general' because during compilation Jax will expect
     # 'dot_general' to return a tracer and will throw an error if it returns a
      # mock instead. By explicitly using jax.disable_jit, Jax will not try to
      # compile the arms of lax.cond and so using a mock will work fine.
     with jax.disable_jit():
         quantization.quantized_dot(
             w=self.rhs,
             act=act,
             weight_params=weight_params,
             act_hparams=act_params,
             get_bounds_params=None,
             quant_type=QuantType.aqt,
             prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
     act_inputs, weight_inputs = mock_dot_general.call_args[0]
     self.assert_is_integer_in_range(act_inputs,
                                     prec=prec,
                                     distribution=act_distribution)
     self.assert_is_integer_in_range(weight_inputs,
                                     prec=prec,
                                     distribution='symmetric')
     if prefer_int8_to_int32_dot and not (act_distribution == 'positive'
                                          and prec == 8):
         expected_input_dtype = jnp.int8
     else:
         expected_input_dtype = jnp.float32
     self.assertEqual(act_inputs.dtype, expected_input_dtype)
     self.assertEqual(weight_inputs.dtype, expected_input_dtype)
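The mock_dot_general fixture above is likewise supplied by a decorator omitted from this listing. A hedged sketch of one way to patch lax.dot_general while keeping the real computation flowing; the side_effect choice is an assumption, and a parameterized decorator (not shown) would supply the remaining arguments as keywords:

from unittest import mock

import jax.lax


@mock.patch.object(jax.lax, 'dot_general', side_effect=jax.lax.dot_general)
def test_lax_dot_has_integer_inputs_in_quantized_dot(
        self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
        prec):
    ...  # body as listed above; runs under jax.disable_jit() as explained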
Example #26
 def test_positive_activation_quantization_clips_outside_bounds(self, prec):
   # Activation values less than 0 get clipped to 0, and values greater than
   # upper_bound get clipped to upper_bound.
   relu6 = QuantOps.create_positive(bounds=6.0, prec=prec)
   activation = jnp.array(fp32([-0.5, 6.2, 3.141]))
   quantized_activations = relu6.to_quantized(activation, dtype=SCALE_DTYPE)
   onp.testing.assert_array_equal(quantized_activations[0:2],
                                  [0.0, 2**prec - 1])
   activations = relu6.from_quantized(quantized_activations, dtype=jnp.float32)
   max_clipped_val = (2**prec - 1) * (6.0 / 2**prec)
   onp.testing.assert_array_equal(activations[0:2], [0.0, max_clipped_val])
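A concrete instance of the clipping arithmetic checked above (illustrative precision):

# With prec=4 the clipped code 2**4 - 1 = 15 dequantizes to
# 15 * (6.0 / 2**4) = 5.625, just below the bound of 6.0.
prec = 4
assert (2**prec - 1) * (6.0 / 2**prec) == 5.625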
 def test_dynamic_quantized_dot_general_has_correct_dtype(
         self, input_dtype, act_prec, quant_type):
     lhs_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=2.0,
                                      prec=act_prec)
     rhs_params = QuantOps.ActHParams(input_distribution='symmetric',
                                      bounds=1.5,
                                      prec=act_prec)
     lhs_act = self.lhs.astype(input_dtype)
     rhs_act = self.rhs.astype(input_dtype)
     output = quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=lhs_params,
         rhs_act_hparams=rhs_params,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=quant_type)
     self.assertEqual(output.dtype, input_dtype)
Example #28
  def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
      self,
      mock_act_fq,
      mock_w_fq,
      weight_prec,
      act_prec,
      strategy=QuantType.fake_quant):
    mock_w_fq.side_effect = lambda inputs, **_: inputs
    mock_act_fq.side_effect = lambda inputs, **_: inputs

    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=None, half_shift=False)
    act_hparams = QuantOps.ActHParams(  # pylint: disable=g-long-ternary
        bounds=6.,
        prec=act_prec,
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        half_shift=False) if act_prec else None
    get_bounds_params = GetBounds.Params(
        update_stats=False, update_bounds=False)

    quantization.quantized_dot(
        w=self.weight,
        act=self.act,
        quant_type=strategy,
        weight_params=weight_params,
        act_hparams=act_hparams,
        get_bounds_params=get_bounds_params,
        prefer_int8_to_int32_dot=True)

    quantized_type = strategy.to_jax_type()

    mock_w_fq.assert_called_with(
        mock.ANY,
        weight_params=weight_params,
        quantized_type=quantized_type,
        fake_dependency=mock.ANY)
    if act_hparams:
      mock_act_fq.assert_called_with(
          mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
    else:
      mock_act_fq.assert_not_called()
    def test_per_feature_dim_scale_invariance_pos_activation_quantization(
            self, prec):
        # Scaling each channel of the activations by a different power of 2, and
        # the upper bound by the same per-channel factors, should scale the
        # respective output channel by the same factor.
        activations = random.uniform(random.PRNGKey(0), (3, 4))
        act_scale = 2**jnp.arange(4)
        scaled_activations = activations * act_scale[jnp.newaxis, :]

        upper_bound = 6.0 * jnp.ones((3, 4), jnp.float32)

        act_quant_ops = QuantOps.create_positive(bounds=upper_bound, prec=prec)
        activations = act_quant_ops.fake_quant(activations,
                                               quantized_type=SCALE_DTYPE)

        scaled_act_quant_ops = QuantOps.create_positive(
            bounds=upper_bound * act_scale[jnp.newaxis, :], prec=prec)
        scaled_activations = scaled_act_quant_ops.fake_quant(
            scaled_activations, quantized_type=SCALE_DTYPE)
        onp.testing.assert_array_equal(activations * act_scale[jnp.newaxis, :],
                                       scaled_activations)
Example #30
  def test_quantized_dot_general_aqt(self, act_bounds, weight_prec,
                                     weight_axis):
    # With a high enough precision, we expect results from fakequant and AQT to
    # be very similar.
    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=weight_axis, half_shift=False)

    if act_bounds is None:
      act_params = None
    else:
      act_params = QuantOps.ActHParams(
          input_distribution='symmetric',
          bounds=jnp.array(act_bounds),
          prec=16,
          half_shift=False)

    lhs_ndims_3 = jnp.array(
        fp32(2.0 * onp.random.uniform(0, 1.0, size=(4, 3, 2))))

    def quantized_matmul(quant_type):
      return quantization.quantized_dot_general(
          w=self.rhs,
          act=lhs_ndims_3,
          weight_params=weight_params,
          act_hparams=act_params,
          get_bounds_params=None,
          quant_type=quant_type,
          dimension_numbers=(((lhs_ndims_3.ndim - 1,), (0,)), ((), ())),
          prefer_int8_to_int32_dot=True)

    aqt_result = quantized_matmul(QuantType.aqt)
    self.assertEqual(aqt_result.shape, (4, 3, 4))

    fakequant_result = quantized_matmul(QuantType.fake_quant)
    onp.testing.assert_allclose(
        aqt_result,
        fakequant_result,
        rtol=1e-2,
        err_msg='AQT and fakequant significantly disagree')