Example #1
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision, since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly apply a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)
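
A minimal sketch (not part of the original example) of exercising the same
unscaled-fp quantization recipe standalone. In the excerpt, softmax_hparams,
norm_dims, and dtype are free variables bound by the enclosing module, so the
FloatPrec values below are illustrative assumptions rather than values from
the source:

    import jax.numpy as jnp
    from aqt.jax.quantization import QuantOps

    fp_spec = QuantOps.FloatQuant.FloatPrec(exp_min=-11, exp_max=4, sig_bits=3)
    fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=fp_spec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

    x = jnp.array([[1.0, 2.0, 3.0]])
    # Quantize the input, then follow the same max-subtraction steps as above.
    xq = quant_ops.to_quantized(x, dtype=jnp.float32)
    x_max = jnp.max(xq, axis=-1, keepdims=True)
    x_exp = quant_ops.to_quantized(
        jnp.exp(quant_ops.to_quantized(xq - x_max, dtype=jnp.float32)),
        dtype=jnp.float32)
    # For brevity this sketch uses a plain sum where the original calls
    # quantization.quantized_sum with a dedicated reduction precision.
    probs = x_exp / jnp.sum(x_exp, axis=-1, keepdims=True)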
Example #2
        def quantized_layernorm(x):
            prec = hparams.quant_hparams.prec
            fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
            quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant,
                                                     bounds=None)

            def to_quantized(x):
                return quant_ops.to_quantized(x, dtype=dtype)

            # If epsilon is too small to represent in the quantized format, we set
            # it to the minimal representable non-zero value to avoid the
            # possibility of dividing by zero.
            fp_bounds = quantization.fp_cast.get_bounds(
                prec.exp_min, prec.exp_max, prec.sig_bits)
            epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
            quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

            # If the reciprocal of the quantized number of features is too small to
            # represent in the quantized format, we set it to the minimal
            # representable nonzero value so that the mean and variance are not
            # trivially 0.
            num_features_quantized = to_quantized(
                jnp.array(num_features, dtype=dtype))
            num_features_recip_quantized = to_quantized(
                jnp.reciprocal(num_features_quantized))
            num_features_recip_quantized = jax.lax.cond(
                jax.lax.eq(num_features_recip_quantized, 0.0),
                lambda _: quantized_epsilon,
                lambda _: num_features_recip_quantized,
                None)

            x_quantized = to_quantized(x)
            x_sum_quantized_reduction = quantization.quantized_sum(
                x_quantized,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sum = to_quantized(x_sum_quantized_reduction)
            mean = to_quantized(x_sum * num_features_recip_quantized)
            x_minus_mean = to_quantized(x - mean)
            x_sq = to_quantized(lax.square(x_minus_mean))
            x_sq_sum_quantized_reduction = quantization.quantized_sum(
                x_sq,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
            var = to_quantized(x_sq_sum * num_features_recip_quantized)
            # Prevent division by zero.
            var_plus_epsilon = to_quantized(var + quantized_epsilon)
            mul = to_quantized(lax.rsqrt(var_plus_epsilon))
            if self.use_scale:
                quantized_scale_param = to_quantized(scale_param)
                mul = to_quantized(mul * quantized_scale_param)
            y = to_quantized(x_minus_mean * mul)
            if self.use_bias:
                quantized_bias_param = to_quantized(bias_param)
                y = to_quantized(y + quantized_bias_param)
            return y.astype(self.dtype)
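
The flush-to-zero guard above can be checked in isolation. A small sketch,
assuming an illustrative FloatPrec; the get_bounds call and the
flush_to_zero_bound attribute follow the excerpt above:

    from aqt.jax import quantization
    from aqt.jax.quantization import QuantOps

    prec = QuantOps.FloatQuant.FloatPrec(exp_min=-11, exp_max=4, sig_bits=3)
    fp_bounds = quantization.fp_cast.get_bounds(
        prec.exp_min, prec.exp_max, prec.sig_bits)
    # An epsilon below the flush-to-zero bound would quantize to exactly zero
    # and defeat its purpose as a divide-by-zero guard, so clamp it up front.
    epsilon = max(1e-6, fp_bounds.flush_to_zero_bound)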
Example #3
  def test_attributes_create_acts_op_fp(
      self,
      act_distribution,
      use_hparams_bounds,
  ):
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    fp_quant = QuantOps.FloatQuant(
        is_scaled=True,
        fp_spec=QuantOps.FloatQuant.FloatPrec(
            exp_min=-15,
            exp_max=15,
            sig_bits=2,
        ),
    )
    if use_hparams_bounds:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=1,
          absdev_coeff=0,
          mix_coeff=1,
          reset_stats=True,
          ema_coeff=None,
          use_cams=False,
          granularity=quant_config.QuantGranularity.per_tensor)
    else:
      bounds = 6.0

    hparams = QuantOps.ActHParams(
        input_distribution=act_distribution, bounds=bounds, prec=fp_quant,
        half_shift=False)

    class TestModule(nn.Module):
      hparams: QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        return QuantOps.create_input_ops(
            inputs,
            hparams=hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=False,
                update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    act_quant_op = test_module.apply(state, inputs=inputs)

    act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
    act_quant_expected = fp_cast.downcast_sat_ftz(
        act_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)
Example #4
class QuantOpsTest(parameterized.TestCase):
    def setUp(self):
        super(QuantOpsTest, self).setUp()
        quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = True

    @parameterized.named_parameters(
        dict(testcase_name='prec_2', bounds=6.0, prec=2),
        dict(testcase_name='prec_4', bounds=6.0, prec=4),
        dict(testcase_name='prec_8', bounds=6.0, prec=8),
        dict(testcase_name='2_features_prec_8', bounds=[6., 12.], prec=8),
    )
    def test_attributes_create_positive(self, bounds, prec):
        bounds = jnp.array(bounds)
        relu6 = QuantOps.create_positive(bounds=bounds, prec=prec)
        onp.testing.assert_array_equal(relu6._scale, 2**prec / bounds)
        self.assertEqual(relu6._symmetric, False)
        self.assertEqual(relu6._prec, prec)

    @parameterized.named_parameters(
        dict(testcase_name='prec_2', bounds=6.0, prec=2),
        dict(testcase_name='prec_4', bounds=6.0, prec=4),
        dict(testcase_name='prec_8', bounds=6.0, prec=8),
        dict(testcase_name='2_features_prec_8', bounds=[6., 12.], prec=8),
    )
    def test_attributes_create_symmetric(self, bounds, prec):
        bounds = jnp.array(bounds)
        act_signed = QuantOps.create_symmetric(bounds=bounds, prec=prec)
        onp.testing.assert_array_equal(act_signed._scale,
                                       (2**(prec - 1) - 1) / bounds)
        self.assertEqual(act_signed._symmetric, True)
        self.assertEqual(act_signed._prec, prec)

    @parameterized.named_parameters(
        dict(
            testcase_name='fp8_143',
            weight_range=[2.0, 64.0],
            weight_shape=(10, 1),
            fp_quant=QuantOps.FloatQuant(
                is_scaled=True,
                fp_spec=QuantOps.FloatQuant.FloatPrec(
                    exp_min=-11,
                    exp_max=4,
                    sig_bits=3,
                ),
            ),
        ),
        dict(
            testcase_name='fp8_152',
            weight_range=[2.0, 64.0],
            weight_shape=(10, 1),
            fp_quant=QuantOps.FloatQuant(
                is_scaled=True,
                fp_spec=QuantOps.FloatQuant.FloatPrec(
                    exp_min=-23,
                    exp_max=8,
                    sig_bits=2,
                ),
            ),
        ),
    )
    def test_attributes_create_weights_op_fp(
        self,
        weight_range,
        weight_shape,
        fp_quant,
    ):
        weights = jnp.array(
            fp32(onp.random.uniform(*weight_range, size=weight_shape)))
        axis = None if weight_shape[1] == 1 else 0
        weights_quant_op = QuantOps.create_weights_ops(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=fp_quant, axis=axis))
        max_weight = onp.max(abs(weights), axis=0)
        onp.testing.assert_array_equal(
            jnp.squeeze(weights_quant_op._scale),
            jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
        self.assertEqual(weights_quant_op._symmetric, True)
        self.assertIs(weights_quant_op._prec, fp_quant)
        weights_scaled = (weights * weights_quant_op._scale).astype(
            weights.dtype)
        weights_quant_expected = fp_cast.downcast_sat_ftz(
            weights_scaled,
            fp_quant.fp_spec.exp_min,
            fp_quant.fp_spec.exp_max,
            fp_quant.fp_spec.sig_bits,
        )
        weights_quant_calculated = weights_quant_op.to_quantized(
            weights, dtype=SCALE_DTYPE)
        onp.testing.assert_array_equal(weights_quant_expected,
                                       weights_quant_calculated)
        # Test the lower (23 - fp_quant.fp_spec.sig_bits) bits of the calculated
        # quantized weights are zero.
        sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
        onp.testing.assert_array_equal(
            weights_quant_calculated.view(jnp.int32) & sig_mask,
            jnp.zeros_like(weights))
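
To make the mask concrete (an illustrative calculation, not from the source):
a float32 has 23 explicit mantissa bits, so after downcasting to sig_bits
significand bits, the low 23 - sig_bits bits of every element's bit pattern
must be zero. For the fp8_143 case above:

    sig_bits = 3
    sig_mask = (1 << (23 - sig_bits)) - 1  # 0xfffff: the low 20 mantissa bits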

    @parameterized.named_parameters(
        dict(testcase_name='pos_weight_prec_2',
             weight_range=[2.0, 10.0],
             weight_shape=(10, 1),
             prec=2),
        dict(testcase_name='pos_weight_prec_4',
             weight_range=[2.0, 10.0],
             weight_shape=(10, 1),
             prec=4),
        dict(testcase_name='pos_weight_prec_8',
             weight_range=[2.0, 10.0],
             weight_shape=(10, 1),
             prec=8),
        dict(testcase_name='neg_weight_prec_8',
             weight_range=[-12.0, 2.0],
             weight_shape=(10, 1),
             prec=8),
        dict(testcase_name='neg_weight_2_features_prec_8',
             weight_range=[-12.0, 2.0],
             weight_shape=(10, 2),
             prec=8),
    )
    def test_attributes_create_weights_ops(self, weight_range, weight_shape,
                                           prec):
        weights = jnp.array(
            fp32(
                onp.random.uniform(weight_range[0],
                                   weight_range[1],
                                   size=weight_shape)))
        axis = 0 if weight_shape[1] != 1 else None
        weights_quant = QuantOps.create_weights_ops(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=axis))
        max_weight = onp.max(abs(weights), axis=0)
        onp.testing.assert_array_equal(jnp.squeeze(weights_quant._scale),
                                       (2**(prec - 1) - 1) / max_weight)
        self.assertEqual(weights_quant._symmetric, True)
        self.assertEqual(weights_quant._prec, prec)

    @parameterized.named_parameters(
        dict(testcase_name='per_layer_quant', axis=None),
        dict(testcase_name='per_channel_quant', axis=(0, )))
    def test_weight_scale_shape_is_expected(self, axis):
        # Tests if scale is as expected for weights quantization.

        num_features = 4
        expected_scale_shape = (1, 1) if axis is None else (1, num_features)

        # Weight Quantization
        weights = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, num_features))))
        _ = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(
                prec=8.0, axis=axis,
                expected_scale_shape=expected_scale_shape))

    def test_inputs_scale_shape_is_expected(self):
        # Inputs quantization
        inputs = jnp.array(fp32(2.0 *
                                onp.random.uniform(0, 1.0, size=(10, 4))))
        bounds = 6.0
        expected_inputs_scale_shape = ()

        _ = QuantOps.create_inputs_fake_quant(
            inputs=inputs,
            hparams=QuantOps.ActHParams(
                input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
                bounds=bounds,
                prec=8.0),
            get_bounds_params=GetBounds.Params(
                update_stats=False,
                update_bounds=False,
                expected_bounds_shape=expected_inputs_scale_shape))

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_positive_activation_quantization_clips_outside_bounds(self, prec):
        # Activation values less than 0 get clipped to 0, and values greater than
        # upper_bound get clipped to upper_bound
        relu6 = QuantOps.create_positive(bounds=6.0, prec=prec)
        activation = jnp.array(fp32([-0.5, 6.2, 3.141]))
        quantized_activations = relu6.to_quantized(activation,
                                                   dtype=SCALE_DTYPE)
        onp.testing.assert_array_equal(quantized_activations[0:2],
                                       [0.0, 2**prec - 1])
        activations = relu6.from_quantized(quantized_activations,
                                           dtype=jnp.float32)
        max_clipped_val = (2**prec - 1) * (6.0 / 2**prec)
        onp.testing.assert_array_equal(activations[0:2],
                                       [0.0, max_clipped_val])

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_per_feature_dim_unsigned_activation_quantization_clips_outside_bounds(
            self, prec):
        # Activation values less than -upper_bound get clipped to -upper_bound, and
        # values greater than upper_bound get clipped to upper_bound
        act_quant = QuantOps.create_symmetric(bounds=jnp.array([[6.0, 8.0]]),
                                              prec=prec)
        activation = jnp.array(fp32([[-7, -8.9], [6.2, 9.4], [0, 0.]]))
        quantized_activations = act_quant.to_quantized(activation,
                                                       dtype=SCALE_DTYPE)
        onp.testing.assert_array_equal(
            quantized_activations,
            jnp.array([[-2**(prec - 1.0) + 1.0], [2**(prec - 1.0) - 1.0],
                       [0.0]]) * jnp.array([[1., 1.]]))
        activations = act_quant.from_quantized(quantized_activations,
                                               dtype=jnp.float32)
        onp.testing.assert_array_equal(activations,
                                       [[-6.0, -8.0], [6.0, 8.], [0, 0.]])

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_scale_invariance_signed_activation_quantization(self, prec):
        # Scaling the activations and the bounds by the same power of 2 should
        # scale the output by that same factor.
        activations = random.uniform(random.PRNGKey(0), (10, 1))
        act_scale = 8.
        scaled_activations = activations * act_scale

        bounds = 6.

        activations = QuantOps.create_inputs_fake_quant(
            inputs=activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(
                input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
                bounds=bounds,
                prec=prec))

        scaled_activations = QuantOps.create_inputs_fake_quant(
            inputs=scaled_activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(
                input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
                bounds=bounds * act_scale,
                prec=prec))
        onp.testing.assert_array_equal(activations * act_scale,
                                       scaled_activations)

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_per_feature_dim_scale_invariance_pos_activation_quantization(
            self, prec):
        # Scaling each channel of the activations, and its upper bound, by a
        # per-channel power of 2 should scale the corresponding channel of the
        # output by the same factor.
        activations = random.uniform(random.PRNGKey(0), (3, 4))
        act_scale = 2**jnp.arange(4)
        scaled_activations = activations * act_scale[jnp.newaxis, :]

        upper_bound = 6.0 * jnp.ones((3, 4), jnp.float32)

        act_quant_ops = QuantOps.create_positive(bounds=upper_bound, prec=prec)
        activations = act_quant_ops.fake_quant(activations,
                                               quantized_type=SCALE_DTYPE)

        scaled_act_quant_ops = QuantOps.create_positive(
            bounds=upper_bound * act_scale[jnp.newaxis, :], prec=prec)
        scaled_activations = scaled_act_quant_ops.fake_quant(
            scaled_activations, quantized_type=SCALE_DTYPE)
        onp.testing.assert_array_equal(activations * act_scale[jnp.newaxis, :],
                                       scaled_activations)

    @parameterized.named_parameters(dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_int_positive_act_quantization(self, prec):
        # Integer activations in [0, upper_bound], where upper_bound == 2^i for
        # some i < prec, quantize correctly.
        upper_bound = 2**(prec - 3)
        activations = random.randint(random.PRNGKey(0), (10, 1), 0,
                                     upper_bound)

        rescaled_activations = QuantOps.create_inputs_fake_quant(
            inputs=activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(
                input_distribution=QuantOps.ActHParams.InputDistribution.positive,
                bounds=upper_bound,
                prec=prec))
        onp.testing.assert_array_equal(activations, rescaled_activations)

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_int_symmetric_act_quantization(self, prec):
        # Integer activations within [-bounds, bounds], where bounds ==
        # 2^(prec - 1) - 1, quantize correctly.
        bounds = 2**(prec - 1) - 1
        activations = random.randint(random.PRNGKey(0), (10, 1), -bounds,
                                     bounds)
        rescaled_activations = QuantOps.create_inputs_fake_quant(
            inputs=activations,
            get_bounds_params=GetBounds.Params(update_stats=False,
                                               update_bounds=False),
            hparams=QuantOps.ActHParams(
                input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
                bounds=bounds,
                prec=prec))

        onp.testing.assert_array_equal(activations, rescaled_activations)

    @parameterized.named_parameters(dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_float_weights_quantization(self, prec):
        # Tests that quantized and rescaled float weights are close to original
        # weights.
        weights = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 1))))
        rescaled_weights = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))
        test_utils.assert_all_close_prec(weights, rescaled_weights, prec=prec)

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_full_range_int_weight_quantization(self, prec):
        # Integer weights spanning the full signed range
        # [-(2^(prec - 1) - 1), 2^(prec - 1) - 1] quantize correctly.
        minval = -2**(prec - 1) + 1
        maxval = 2**(prec - 1) - 1
        weights = random.randint(random.PRNGKey(0), (10, 1), minval,
                                 maxval + 1)
        # jax.ops.index_update was removed from JAX; .at[].set() is the
        # modern equivalent.
        weights = weights.at[0, :].set(maxval)
        weight_quant = QuantOps.create_weights_ops(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))
        quantized_weights = weight_quant.to_quantized(weights,
                                                      dtype=SCALE_DTYPE)
        onp.testing.assert_array_equal(quantized_weights[0],
                                       (2**(prec - 1.0) - 1.0))
        rescaled_weights = weight_quant.from_quantized(quantized_weights,
                                                       dtype=jnp.float32)
        onp.testing.assert_array_equal(weights, rescaled_weights)

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_scale_invariance_weight_quantization(self, prec):
        # Scaling the weights by a power of 2 should scale the output by the
        # same factor.
        weights = random.uniform(random.PRNGKey(0), (10, 1))
        weight_scale = 16
        scaled_weights = weights * weight_scale

        weights = QuantOps.create_weights_fake_quant(
            w=weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        scaled_weights = QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=None))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)

    @parameterized.named_parameters(dict(testcase_name='prec_2', prec=2),
                                    dict(testcase_name='prec_4', prec=4),
                                    dict(testcase_name='prec_8', prec=8))
    def test_per_feature_dim_scale_invariance_weight_quantization(self, prec):
        # Scaling each channel of the weights by a different power of 2 should
        # scale the corresponding channel of the output by the same factor.
        weights = random.uniform(random.PRNGKey(0), (3, 4))
        weight_scale = 2**jnp.arange(4)[jnp.newaxis, :]
        scaled_weights = weights * weight_scale

        weights = quantization.QuantOps.create_weights_fake_quant(
            w=weights, weight_params=QuantOps.WeightParams(prec=prec, axis=0))

        scaled_weights = quantization.QuantOps.create_weights_fake_quant(
            w=scaled_weights,
            weight_params=QuantOps.WeightParams(prec=prec, axis=0))

        onp.testing.assert_array_equal(weights * weight_scale, scaled_weights)

    def test_no_quantization(self):
        # If initial_bound==-1 when using GetBounds, then create_inputs_fake_quant
        # should be a no-op.
        inputs = jnp.array([[.3, 1.4], [-5.2, 4.0]])
        bounds = get_bounds.GetBounds.Hyper(
            initial_bound=-1,
            stddev_coeff=1,
            absdev_coeff=0,
            mix_coeff=1,
            reset_stats=True,
            ema_coeff=None,
            use_cams=False,
            granularity=quant_config.QuantGranularity.per_tensor)
        hparams = quantization.QuantOps.ActHParams(
            input_distribution='symmetric', bounds=bounds, prec=4)

        # The call to create_inputs_fake_quant has to occur from within a Flax
        # module since it calls GetBounds, which is itself a Flax module.
        # Thus we create a wrapper module for testing.
        class TestModule(nn.Module):

            hparams: quantization.QuantOps.ActHParams

            @nn.compact
            def __call__(self, inputs):
                return quantization.QuantOps.create_inputs_fake_quant(
                    inputs,
                    hparams=hparams,
                    get_bounds_params=GetBounds.Params(update_stats=True,
                                                       update_bounds=False))

        test_module = TestModule(hparams=hparams)
        state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
        inputs_after_fake_quant, _ = test_module.apply(state,
                                                       inputs=inputs,
                                                       mutable=True)
        onp.testing.assert_array_equal(inputs, inputs_after_fake_quant)
Example #5
class QuantOpsTest(parameterized.TestCase):

  def setUp(self):
    super(QuantOpsTest, self).setUp()
    quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = True

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', bounds=6.0, prec=2),
      dict(testcase_name='prec_4', bounds=6.0, prec=4),
      dict(testcase_name='prec_8', bounds=6.0, prec=8),
      dict(
          testcase_name='2_features_prec_8',
          bounds=[6., 12.],
          prec=8),
  )
  def test_attributes_create_positive(self, bounds, prec):
    bounds = jnp.array(bounds)
    relu6 = QuantOps.create_positive(bounds=bounds, prec=prec)
    onp.testing.assert_array_equal(relu6._scale, 2**prec / bounds)
    self.assertEqual(relu6._symmetric, False)
    self.assertEqual(relu6._prec, prec)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', bounds=6.0, prec=2),
      dict(testcase_name='prec_4', bounds=6.0, prec=4),
      dict(testcase_name='prec_8', bounds=6.0, prec=8),
      dict(
          testcase_name='2_features_prec_8',
          bounds=[6., 12.],
          prec=8),
  )
  def test_attributes_create_symmetric(self, bounds, prec):
    bounds = jnp.array(bounds)
    act_signed = QuantOps.create_symmetric(
        bounds=bounds, prec=prec, half_shift=False)
    onp.testing.assert_array_equal(act_signed._scale,
                                   (2**(prec - 1) - 1) / bounds)
    self.assertEqual(act_signed._symmetric, True)
    self.assertEqual(act_signed._prec, prec)

  @parameterized.named_parameters(
      dict(
          testcase_name='fp8_143',
          weight_range=[2.0, 64.0],
          weight_shape=(10, 1),
          fp_quant=QuantOps.FloatQuant(
              is_scaled=True,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-11,
                  exp_max=4,
                  sig_bits=3,
              ),
          ),
      ),
      dict(
          testcase_name='fp8_152',
          weight_range=[2.0, 64.0],
          weight_shape=(10, 1),
          fp_quant=QuantOps.FloatQuant(
              is_scaled=True,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-23,
                  exp_max=8,
                  sig_bits=2,
              ),
          ),
      ),
  )
  def test_attributes_create_weights_op_fp(
      self,
      weight_range,
      weight_shape,
      fp_quant,
  ):
    weights = jnp.array(
        fp32(onp.random.uniform(*weight_range, size=weight_shape)))
    axis = None if weight_shape[1] == 1 else 0
    weights_quant_op = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=fp_quant, axis=axis, half_shift=False))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant_op._scale),
        jnp.exp2(-jnp.floor(jnp.log2(max_weight))))
    self.assertEqual(weights_quant_op._symmetric, True)
    self.assertIs(weights_quant_op._prec, fp_quant)
    weights_scaled = (weights * weights_quant_op._scale).astype(weights.dtype)
    weights_quant_expected = fp_cast.downcast_sat_ftz(
        weights_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    weights_quant_calculated = weights_quant_op.to_quantized(
        weights, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(weights_quant_expected,
                                   weights_quant_calculated)
    # Test the lower (23 - fp_quant.fp_spec.sig_bits) bits of the calculated
    # quantized weights are zero.
    sig_mask = jnp.int32((1 << (23 - fp_quant.fp_spec.sig_bits)) - 1)
    onp.testing.assert_array_equal(
        weights_quant_calculated.view(jnp.int32) & sig_mask,
        jnp.zeros_like(weights))

  @parameterized.named_parameters(
      dict(
          testcase_name='fp_act_symmetric',
          act_distribution='symmetric',
          use_hparams_bounds=False,
      ),
      # TODO(b/193561347): FP quantization with positive input distribution is
      # not supported yet
      dict(
          testcase_name='fp_act_positive',
          act_distribution='positive',
          use_hparams_bounds=False,
      ),
      dict(
          testcase_name='fp_act_symmetric_hyper_bounds',
          act_distribution='symmetric',
          use_hparams_bounds=True,
      ),
      dict(
          testcase_name='fp_act_positive_hyper_bounds',
          act_distribution='positive',
          use_hparams_bounds=True,
      ),
  )
  def test_attributes_create_acts_op_fp(
      self,
      act_distribution,
      use_hparams_bounds,
  ):
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    fp_quant = QuantOps.FloatQuant(
        is_scaled=True,
        fp_spec=QuantOps.FloatQuant.FloatPrec(
            exp_min=-15,
            exp_max=15,
            sig_bits=2,
        ),
    )
    if use_hparams_bounds:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=1,
          absdev_coeff=0,
          mix_coeff=1,
          reset_stats=True,
          ema_coeff=None,
          use_cams=False,
          granularity=quant_config.QuantGranularity.per_tensor)
    else:
      bounds = 6.0

    hparams = QuantOps.ActHParams(
        input_distribution=act_distribution, bounds=bounds, prec=fp_quant,
        half_shift=False)

    class TestModule(nn.Module):
      hparams: QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        return QuantOps.create_input_ops(
            inputs,
            hparams=hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=False,
                update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    act_quant_op = test_module.apply(state, inputs=inputs)

    act_scaled = (inputs * act_quant_op._scale).astype(inputs.dtype)
    act_quant_expected = fp_cast.downcast_sat_ftz(
        act_scaled,
        fp_quant.fp_spec.exp_min,
        fp_quant.fp_spec.exp_max,
        fp_quant.fp_spec.sig_bits,
    )
    act_quant_calculated = act_quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(act_quant_expected, act_quant_calculated)

  @parameterized.named_parameters(
      dict(
          testcase_name='pos_weight_prec_2',
          weight_range=[2.0, 10.0],
          weight_shape=(10, 1),
          prec=2),
      dict(
          testcase_name='pos_weight_prec_4',
          weight_range=[2.0, 10.0],
          weight_shape=(10, 1),
          prec=4),
      dict(
          testcase_name='pos_weight_prec_8',
          weight_range=[2.0, 10.0],
          weight_shape=(10, 1),
          prec=8),
      dict(
          testcase_name='neg_weight_prec_8',
          weight_range=[-12.0, 2.0],
          weight_shape=(10, 1),
          prec=8),
      dict(
          testcase_name='neg_weight_2_features_prec_8',
          weight_range=[-12.0, 2.0],
          weight_shape=(10, 2),
          prec=8),
  )
  def test_attributes_create_weights_ops(self, weight_range, weight_shape,
                                         prec):
    weights = jnp.array(
        fp32(
            onp.random.uniform(
                weight_range[0], weight_range[1], size=weight_shape)))
    axis = 0 if weight_shape[1] != 1 else None
    weights_quant = QuantOps.create_weights_ops(
        w=weights,
        weight_params=QuantOps.WeightParams(
            prec=prec, axis=axis, half_shift=False))
    max_weight = onp.max(abs(weights), axis=0)
    onp.testing.assert_array_equal(
        jnp.squeeze(weights_quant._scale), (2**(prec - 1) - 1) / max_weight)
    self.assertEqual(weights_quant._symmetric, True)
    self.assertEqual(weights_quant._prec, prec)
Example #6
class ActQuantizationTest(parameterized.TestCase):

  def test_inputs_scale_shape_is_expected(self):
    # Inputs quantization
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))
    bounds = 6.0
    expected_inputs_scale_shape = ()

    _ = QuantOps.create_inputs_fake_quant(
        inputs=inputs,
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds,
            prec=8.0,
            half_shift=False),
        get_bounds_params=GetBounds.Params(
            update_stats=False,
            update_bounds=False,
            expected_bounds_shape=expected_inputs_scale_shape))

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_positive_activation_quantization_clips_outside_bounds(self, prec):
    # Activation values less than 0 get clipped to 0, and values greater than
    # upper_bound get clipped to upper_bound
    relu6 = QuantOps.create_positive(bounds=6.0, prec=prec)
    activation = jnp.array(fp32([-0.5, 6.2, 3.141]))
    quantized_activations = relu6.to_quantized(activation, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(quantized_activations[0:2],
                                   [0.0, 2**prec - 1])
    activations = relu6.from_quantized(quantized_activations, dtype=jnp.float32)
    max_clipped_val = (2**prec - 1) * (6.0 / 2**prec)
    onp.testing.assert_array_equal(activations[0:2], [0.0, max_clipped_val])

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8)
  )
  def test_per_feature_dim_unsigned_activation_quantization_clips_outside_bounds(
      self, prec):
    # Activation values less than -upper_bound get clipped to -upper_bound, and
    # values greater than upper_bound get clipped to upper_bound
    act_quant = QuantOps.create_symmetric(
        bounds=jnp.array([[6.0, 8.0]]), prec=prec, half_shift=False)
    activation = jnp.array(fp32([[-7, -8.9], [6.2, 9.4], [0, 0.]]))
    quantized_activations = act_quant.to_quantized(
        activation, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(
        quantized_activations,
        jnp.array([[-2**(prec - 1.0) + 1.0], [2**(prec - 1.0) - 1.0], [0.0]]) *
        jnp.array([[1., 1.]]))
    activations = act_quant.from_quantized(
        quantized_activations, dtype=jnp.float32)
    onp.testing.assert_array_equal(activations,
                                   [[-6.0, -8.0], [6.0, 8.], [0, 0.]])

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8)
  )
  def test_scale_invariance_signed_activation_quantization(self, prec):
    # Scaling the activations and the bounds by the same power of 2 should
    # scale the output by that same factor.
    activations = random.uniform(random.PRNGKey(0), (10, 1))
    act_scale = 8.
    scaled_activations = activations * act_scale

    bounds = 6.

    activations = QuantOps.create_inputs_fake_quant(
        inputs=activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds,
            prec=prec,
            half_shift=False))

    scaled_activations = QuantOps.create_inputs_fake_quant(
        inputs=scaled_activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds * act_scale,
            prec=prec,
            half_shift=False))
    onp.testing.assert_array_equal(activations * act_scale, scaled_activations)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8)
  )
  def test_per_feature_dim_scale_invariance_pos_activation_quantization(
      self, prec):
    # Scaling each channel of the activations, and its upper bound, by a
    # per-channel power of 2 should scale the corresponding channel of the
    # output by the same factor.
    activations = random.uniform(random.PRNGKey(0), (3, 4))
    act_scale = 2**jnp.arange(4)
    scaled_activations = activations * act_scale[jnp.newaxis, :]

    upper_bound = 6.0 * jnp.ones((3, 4), jnp.float32)

    act_quant_ops = QuantOps.create_positive(bounds=upper_bound, prec=prec)
    activations = act_quant_ops.fake_quant(
        activations, quantized_type=SCALE_DTYPE)

    scaled_act_quant_ops = QuantOps.create_positive(
        bounds=upper_bound * act_scale[jnp.newaxis, :], prec=prec)
    scaled_activations = scaled_act_quant_ops.fake_quant(
        scaled_activations, quantized_type=SCALE_DTYPE)
    onp.testing.assert_array_equal(activations * act_scale[jnp.newaxis, :],
                                   scaled_activations)

  @parameterized.named_parameters(
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8))
  def test_int_positive_act_quantization(self, prec):
    # Integer activations in [0, upper_bound], where upper_bound == 2^i for
    # some i < prec, quantize correctly.
    upper_bound = 2**(prec - 3)
    activations = random.randint(random.PRNGKey(0), (10, 1), 0, upper_bound)

    rescaled_activations = QuantOps.create_inputs_fake_quant(
        inputs=activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.positive,
            bounds=upper_bound,
            prec=prec,
            half_shift=False))
    onp.testing.assert_array_equal(activations, rescaled_activations)

  @parameterized.named_parameters(
      dict(testcase_name='prec_2', prec=2),
      dict(testcase_name='prec_4', prec=4),
      dict(testcase_name='prec_8', prec=8)
  )
  def test_int_symmetric_act_quantization(self, prec):
    # Integer activations within [-bounds, bounds], where bounds ==
    # 2^(prec - 1) - 1, quantize correctly.
    bounds = 2**(prec - 1) - 1
    activations = random.randint(random.PRNGKey(0), (10, 1), -bounds, bounds)
    rescaled_activations = QuantOps.create_inputs_fake_quant(
        inputs=activations,
        get_bounds_params=GetBounds.Params(
            update_stats=False, update_bounds=False),
        hparams=QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=bounds,
            prec=prec,
            half_shift=False))

    onp.testing.assert_array_equal(activations, rescaled_activations)

  @parameterized.named_parameters(
      dict(
          testcase_name='fp_prec_scaled',
          prec=QuantOps.FloatQuant(
              is_scaled=True,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-11,
                  exp_max=4,
                  sig_bits=3,
              ),
          ),
      ),
      dict(
          testcase_name='fp_prec_unscaled',
          prec=QuantOps.FloatQuant(
              is_scaled=False,
              fp_spec=QuantOps.FloatQuant.FloatPrec(
                  exp_min=-11,
                  exp_max=4,
                  sig_bits=3,
              ),
          ),
      ),
      dict(
          testcase_name='int_prec',
          prec=4.0,
      ),
  )
  def test_no_quantization(self, prec):
    # If initial_bound==-1 when using GetBounds, then create_inputs_fake_quant
    # should be a no-op.
    inputs = jnp.array([[.3, 1.4], [-5.2, 4.0]])
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=-1,
        stddev_coeff=1,
        absdev_coeff=0,
        mix_coeff=1,
        reset_stats=True,
        ema_coeff=None,
        use_cams=False,
        granularity=quant_config.QuantGranularity.per_tensor)
    hparams = quantization.QuantOps.ActHParams(
        input_distribution='symmetric',
        bounds=bounds,
        prec=prec,
        half_shift=False)

    # The call to create_inputs_fake_quant has to occur from within a Flax
    # module since it calls GetBounds, which is itself a Flax module.
    # Thus we create a wrapper module for testing.
    class TestModule(nn.Module):

      hparams: quantization.QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        return quantization.QuantOps.create_inputs_fake_quant(
            inputs,
            hparams=hparams,
            get_bounds_params=GetBounds.Params(
                update_stats=True, update_bounds=False))

    test_module = TestModule(hparams=hparams)
    state = test_module.init(jax.random.PRNGKey(0), inputs=inputs)
    inputs_after_fake_quant, _ = test_module.apply(
        state, inputs=inputs, mutable=True)
    onp.testing.assert_array_equal(inputs, inputs_after_fake_quant)
Example #7
from absl import flags
from absl.testing import parameterized

from aqt.jax import quantization
from aqt.jax import shape_utils
from aqt.jax import test_utils
from aqt.jax.quantization import QuantOps
from aqt.jax.quantization import QuantType

FLAGS = flags.FLAGS

# fp-1-4-3
#    1: sign
#    4: number of exponent bits (bias = 11), exponent range: -11, ..., 4
#    3: number of significand bits (excluding the hidden bit)
fp143_scaled = QuantOps.FloatQuant(
    is_scaled=True,
    fp_spec=QuantOps.FloatQuant.FloatPrec(
        exp_min=-11,
        exp_max=4,
        sig_bits=3,
    ),
)
fp143_unscaled = QuantOps.FloatQuant(
    is_scaled=False,
    fp_spec=QuantOps.FloatQuant.FloatPrec(
        exp_min=-11,
        exp_max=4,
        sig_bits=3,
    ),
)
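
As a quick sanity check of the dynamic range this spec implies (a sketch, not
part of the original file): the largest finite magnitude is
2**exp_max * (2 - 2**-sig_bits) and the smallest normal magnitude is
2**exp_min.

    max_fp143 = 2.0**4 * (2.0 - 2.0**-3)  # 30.0
    min_normal_fp143 = 2.0**-11           # ~4.88e-4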


class ConvAqtTest(parameterized.TestCase):
  """Tests for ConvAqt layer."""