コード例 #1
0
    def test_scale_invariance_signed_activation_quantization(self, prec):
        """Checks that symmetric fake quantization commutes with scaling.

        Multiplying both the activations and the clipping bound by the same
        power of two must multiply the fake-quantized output by that factor.
        """
        scale = 8.
        base_bound = 6.
        raw = random.uniform(random.PRNGKey(0), (10, 1))

        def fake_quant(values, bound):
            # Both update_stats and update_bounds are switched off, so the
            # only thing that differs between the two calls is `bound`.
            return QuantOps.create_inputs_fake_quant(
                inputs=values,
                get_bounds_params=GetBounds.Params(
                    update_stats=False, update_bounds=False),
                hparams=QuantOps.ActHParams(
                    input_distribution=(
                        QuantOps.ActHParams.InputDistribution.symmetric),
                    bounds=bound,
                    prec=prec))

        quantized = fake_quant(raw, base_bound)
        quantized_scaled = fake_quant(raw * scale, base_bound * scale)

        # Scaling after quantization must match quantizing the scaled input.
        onp.testing.assert_array_equal(quantized * scale, quantized_scaled)
コード例 #2
0
 def __call__(self, lhs, rhs):
     """Returns the AQT dynamic dot_general of `lhs` and `rhs`.

     Both operands use 8-bit 'symmetric' activation hparams whose bounds
     come from per-tensor `GetBounds.Hyper` configs (initial bound 10.0
     for lhs, 5.0 for rhs) with all mixing coefficients set to zero.
     """

     def make_bounds_hyper(initial_bound):
         # Only the initial bound differs between the two operands.
         return GetBounds.Hyper(
             initial_bound=initial_bound,
             stddev_coeff=0,
             absdev_coeff=0,
             mix_coeff=0,
             granularity=quant_config.QuantGranularity.per_tensor)

     def make_act_hparams(bounds_hyper):
         return QuantOps.ActHParams(
             input_distribution='symmetric', bounds=bounds_hyper, prec=8)

     def make_bounds_params(name):
         # Statistics are updated, bounds are not.
         return get_bounds.GetBounds.Params(
             update_stats=True, update_bounds=False, module_name=name)

     lhs_hparams = make_act_hparams(make_bounds_hyper(10.0))
     rhs_hparams = make_act_hparams(make_bounds_hyper(5.0))

     return quantization.quantized_dynamic_dot_general(
         lhs_act=lhs,
         rhs_act=rhs,
         lhs_act_hparams=lhs_hparams,
         rhs_act_hparams=rhs_hparams,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt,
         lhs_get_bounds_params=make_bounds_params('lhs'),
         rhs_get_bounds_params=make_bounds_params('rhs'))
コード例 #3
0
 def __call__(self, inputs):
   """Returns `QuantOps.create_input_ops` applied to `inputs`.

   `hparams` is a free variable resolved from a scope outside this method.
   """
   # Neither statistics nor bounds are updated by this call.
   frozen_bounds = GetBounds.Params(update_stats=False, update_bounds=False)
   return QuantOps.create_input_ops(
       inputs, hparams=hparams, get_bounds_params=frozen_bounds)
コード例 #4
0
    def test_int_symmetric_act_quantization(self, prec):
        """Checks integer inputs pass through symmetric fake quant unchanged.

        The bound is 2^(prec - 1) - 1 and the inputs are random integers
        drawn with minval=-bound and maxval=bound; the fake-quantized output
        must equal the inputs exactly.
        """
        max_abs = 2**(prec - 1) - 1
        int_inputs = random.randint(
            random.PRNGKey(0), (10, 1), -max_abs, max_abs)

        symmetric = QuantOps.ActHParams.InputDistribution.symmetric
        act_hparams = QuantOps.ActHParams(
            input_distribution=symmetric, bounds=max_abs, prec=prec)
        frozen_bounds = GetBounds.Params(
            update_stats=False, update_bounds=False)

        round_tripped = QuantOps.create_inputs_fake_quant(
            inputs=int_inputs,
            get_bounds_params=frozen_bounds,
            hparams=act_hparams)

        onp.testing.assert_array_equal(int_inputs, round_tripped)
コード例 #5
0
    def test_int_positive_act_quantization(self, prec):
        """Checks integer inputs pass through positive fake quant unchanged.

        The bound is 2^(prec - 3) and the inputs are random integers drawn
        with minval=0 and maxval equal to that bound; the fake-quantized
        output must equal the inputs exactly.
        """
        bound = 2**(prec - 3)
        int_inputs = random.randint(random.PRNGKey(0), (10, 1), 0, bound)

        positive = QuantOps.ActHParams.InputDistribution.positive
        act_hparams = QuantOps.ActHParams(
            input_distribution=positive, bounds=bound, prec=prec)
        frozen_bounds = GetBounds.Params(
            update_stats=False, update_bounds=False)

        round_tripped = QuantOps.create_inputs_fake_quant(
            inputs=int_inputs,
            get_bounds_params=frozen_bounds,
            hparams=act_hparams)

        onp.testing.assert_array_equal(int_inputs, round_tripped)
コード例 #6
0
    def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
            self,
            mock_act_fq,
            lhs_act_prec,
            rhs_act_prec,
            strategy=QuantType.fake_quant):
        """Checks the mocked activation quantizer is called once per operand.

        An operand gets activation hparams only when its precision is truthy;
        the mock must then have been called with matching hparams (any input
        distribution) and the shared get-bounds params.
        """

        def passthrough(inputs, hparams, get_bounds_params):
            # The mock returns its input untouched.
            del hparams, get_bounds_params
            return inputs

        mock_act_fq.side_effect = passthrough

        def make_hparams(prec, distribution):
            return QuantOps.ActHParams(
                bounds=6.,
                prec=prec,
                input_distribution=distribution,
                half_shift=False)

        symmetric = QuantOps.ActHParams.InputDistribution.symmetric
        lhs_act_hparams = None
        if lhs_act_prec:
            lhs_act_hparams = make_hparams(lhs_act_prec, symmetric)
        rhs_act_hparams = None
        if rhs_act_prec:
            rhs_act_hparams = make_hparams(rhs_act_prec, symmetric)

        get_bounds_params = GetBounds.Params(update_stats=False,
                                             update_bounds=False)

        quantization.quantized_dynamic_dot_general(
            lhs_act=self.lhs_act,
            rhs_act=self.rhs_act,
            quant_type=strategy,
            dot_dimension_numbers=self.dimension_numbers,
            lhs_act_hparams=lhs_act_hparams,
            lhs_get_bounds_params=get_bounds_params,
            rhs_act_hparams=rhs_act_hparams,
            rhs_get_bounds_params=get_bounds_params,
        )

        # One expected call per operand whose precision is not None; the
        # input distribution is deliberately left unconstrained.
        expected_calls = [
            mock.call(mock.ANY,
                      hparams=make_hparams(prec, mock.ANY),
                      get_bounds_params=get_bounds_params)
            for prec in (lhs_act_prec, rhs_act_prec)
            if prec is not None
        ]
        self.assertLen(expected_calls, mock_act_fq.call_count)
        mock_act_fq.assert_has_calls(expected_calls, any_order=True)
コード例 #7
0
    def test_inputs_scale_shape_is_expected(self):
        """Runs input fake quant while declaring a `()` expected bounds shape.

        There is no explicit assertion here: the test passes as long as the
        call completes.
        NOTE(review): presumably `create_inputs_fake_quant` validates the
        bounds shape against `expected_bounds_shape` internally -- confirm.
        """
        scalar_shape = ()
        bound = 6.0
        samples = onp.random.uniform(0, 1.0, size=(10, 4))
        inputs = jnp.array(fp32(2.0 * samples))

        symmetric = QuantOps.ActHParams.InputDistribution.symmetric
        act_hparams = QuantOps.ActHParams(
            input_distribution=symmetric, bounds=bound, prec=8.0)
        bounds_params = GetBounds.Params(
            update_stats=False,
            update_bounds=False,
            expected_bounds_shape=scalar_shape)

        QuantOps.create_inputs_fake_quant(
            inputs=inputs,
            hparams=act_hparams,
            get_bounds_params=bounds_params)
コード例 #8
0
  def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
      self,
      mock_act_fq,
      mock_w_fq,
      weight_prec,
      act_prec,
      strategy=QuantType.fake_quant):
    """Checks `quantized_dot` invokes the mocked weight/activation quantizers.

    The weight quantizer must always be called with the weight params and
    the JAX type derived from `strategy`. The activation quantizer must be
    called only when activation hparams were built (truthy `act_prec`).
    """

    def passthrough(inputs, **_):
      # Both mocks hand back their first argument unchanged.
      return inputs

    mock_w_fq.side_effect = passthrough
    mock_act_fq.side_effect = passthrough

    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=None, half_shift=False)

    act_hparams = None
    if act_prec:
      act_hparams = QuantOps.ActHParams(
          bounds=6.,
          prec=act_prec,
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          half_shift=False)

    get_bounds_params = GetBounds.Params(
        update_stats=False, update_bounds=False)

    quantization.quantized_dot(
        w=self.weight,
        act=self.act,
        quant_type=strategy,
        weight_params=weight_params,
        act_hparams=act_hparams,
        get_bounds_params=get_bounds_params,
        prefer_int8_to_int32_dot=True)

    mock_w_fq.assert_called_with(
        mock.ANY,
        weight_params=weight_params,
        quantized_type=strategy.to_jax_type(),
        fake_dependency=mock.ANY)

    if not act_hparams:
      mock_act_fq.assert_not_called()
    else:
      mock_act_fq.assert_called_with(
          mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
コード例 #9
0
 def __call__(self, inputs):
     """Returns `create_inputs_fake_quant` applied to `inputs`.

     `hparams` is a free variable resolved from a scope outside this method.
     """
     # Statistics are updated on this call; bounds are not.
     stats_only = GetBounds.Params(update_stats=True, update_bounds=False)
     return quantization.QuantOps.create_inputs_fake_quant(
         inputs, hparams=hparams, get_bounds_params=stats_only)