def test_scale_invariance_signed_activation_quantization(self, prec):
        """Symmetric fake quantization must commute with power-of-2 scaling.

        The same random activations are fake-quantized twice: once as-is with
        bounds 6, and once with both the activations and the bounds multiplied
        by 8. The second result has to equal the first result times 8 exactly.

        Args:
          prec: Precision handed to `QuantOps.ActHParams` for both runs.
        """
        base_bounds = 6.
        factor = 8.

        def fake_quant(inputs, bounds):
            # Bounds are fixed hyperparameters; no statistics or bounds updates.
            frozen_params = GetBounds.Params(update_stats=False,
                                             update_bounds=False)
            symmetric = QuantOps.ActHParams.InputDistribution.symmetric
            return QuantOps.create_inputs_fake_quant(
                inputs=inputs,
                get_bounds_params=frozen_params,
                hparams=QuantOps.ActHParams(input_distribution=symmetric,
                                            bounds=bounds,
                                            prec=prec))

        raw = random.uniform(random.PRNGKey(0), (10, 1))
        quantized = fake_quant(raw, base_bounds)
        quantized_scaled = fake_quant(raw * factor, base_bounds * factor)

        onp.testing.assert_array_equal(quantized * factor, quantized_scaled)
Beispiel #2
0
 def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
     """Multiplies two activations with an AQT dynamic dot_general.

     Both operands use symmetric activation quantization whose bounds come
     from one shared `GetBounds.Hyper` (initial bound 10.0, all coefficients
     zero, per-tensor granularity); neither statistics nor bounds are updated.

     Args:
       lhs_act: Left-hand activation.
       rhs_act: Right-hand activation.
       lhs_prec: Precision for the left-hand activation.
       rhs_prec: Precision for the right-hand activation.

     Returns:
       Whatever `quantization.quantized_dynamic_dot_general` returns.
     """
     bounds_hyper = get_bounds.GetBounds.Hyper(
         initial_bound=10.0,
         stddev_coeff=0,
         absdev_coeff=0,
         mix_coeff=0,
         granularity=quant_config.QuantGranularity.per_tensor)

     def symmetric_hparams(prec):
         return QuantOps.ActHParams(input_distribution='symmetric',
                                    bounds=bounds_hyper,
                                    prec=prec,
                                    half_shift=False)

     def frozen_bounds_params(module_name):
         return get_bounds.GetBounds.Params(update_stats=False,
                                            update_bounds=False,
                                            module_name=module_name)

     return quantization.quantized_dynamic_dot_general(
         lhs_act=lhs_act,
         rhs_act=rhs_act,
         lhs_act_hparams=symmetric_hparams(lhs_prec),
         rhs_act_hparams=symmetric_hparams(rhs_prec),
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt,
         lhs_get_bounds_params=frozen_bounds_params('lhs'),
         rhs_get_bounds_params=frozen_bounds_params('rhs'))
 def test_lax_dot_has_integer_inputs_in_dynamic_dot_general(
         self, mock_dot_general, lhs_distribution, rhs_distribution):
     """The mocked dot_general must receive integer-valued 4-bit operands.

     Runs an AQT dynamic dot_general on `self.lhs` / `self.rhs` (taking the
     absolute value of an operand whose distribution is 'positive') and then
     inspects the positional arguments recorded by `mock_dot_general`.

     Args:
       mock_dot_general: Mock standing in for the underlying dot_general.
       lhs_distribution: Input distribution of the left operand.
       rhs_distribution: Input distribution of the right operand.
     """
     precision = 4

     def prepare(act, distribution):
         # A 'positive' distribution is fed non-negative activations only.
         return jnp.abs(act) if distribution == 'positive' else act

     lhs_hparams = QuantOps.ActHParams(input_distribution=lhs_distribution,
                                       bounds=2.0,
                                       prec=precision)
     rhs_hparams = QuantOps.ActHParams(input_distribution=rhs_distribution,
                                       bounds=1.5,
                                       prec=precision)

     quantization.quantized_dynamic_dot_general(
         lhs_act=prepare(self.lhs, lhs_distribution),
         rhs_act=prepare(self.rhs, rhs_distribution),
         lhs_act_hparams=lhs_hparams,
         rhs_act_hparams=rhs_hparams,
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=QuantType.aqt)

     lhs_seen, rhs_seen = mock_dot_general.call_args[0]
     for seen, distribution in ((lhs_seen, lhs_distribution),
                                (rhs_seen, rhs_distribution)):
         self.assert_is_integer_in_range(seen,
                                         prec=precision,
                                         distribution=distribution)
    def test_quantized_dynamic_dot_general(self, lhs_prec, rhs_prec):
        """AQT and fake-quant dynamic dot_general must agree within rtol=1e-2.

        Args:
          lhs_prec: Precision of the left-hand activation (bounds 2.0).
          rhs_prec: Precision of the right-hand activation (bounds 1.5).
        """
        lhs_hparams = QuantOps.ActHParams(input_distribution='symmetric',
                                          bounds=2.0,
                                          prec=lhs_prec)
        rhs_hparams = QuantOps.ActHParams(input_distribution='symmetric',
                                          bounds=1.5,
                                          prec=rhs_prec)
        # Everything except quant_type is identical between the two runs.
        shared_kwargs = dict(
            lhs_act=self.lhs,
            rhs_act=self.rhs,
            lhs_act_hparams=lhs_hparams,
            rhs_act_hparams=rhs_hparams,
            lhs_get_bounds_params=None,
            rhs_get_bounds_params=None,
            dot_dimension_numbers=(((1, ), (0, )), ((), ())))

        aqt_out = quantization.quantized_dynamic_dot_general(
            quant_type=QuantType.aqt, **shared_kwargs)
        fakequant_out = quantization.quantized_dynamic_dot_general(
            quant_type=QuantType.fake_quant, **shared_kwargs)

        onp.testing.assert_allclose(
            aqt_out,
            fakequant_out,
            rtol=1e-2,
            err_msg='AQT and fakequant significantly disagree')
    def test_quantized_dot_aqt(self, act_bounds, weight_prec, weight_axis):
        """AQT and fake-quant `quantized_dot` must agree within rtol=1e-2.

        Activations, when quantized at all, use 16-bit precision; with a high
        enough precision the two quantization strategies are expected to give
        very similar results.

        Args:
          act_bounds: Activation bounds, or None to leave activations
            unquantized.
          weight_prec: Weight precision.
          weight_axis: Axis argument for the weight quantization parameters.
        """
        act_hparams = None
        if act_bounds is not None:
            act_hparams = QuantOps.ActHParams(input_distribution='symmetric',
                                              bounds=jnp.array(act_bounds),
                                              prec=16)
        shared_kwargs = dict(
            w=self.rhs,
            act=self.lhs,
            weight_params=QuantOps.WeightParams(prec=weight_prec,
                                                axis=weight_axis),
            act_hparams=act_hparams,
            get_bounds_params=None,
            prefer_int8_to_int32_dot=True)

        aqt_out = quantization.quantized_dot(quant_type=QuantType.aqt,
                                             **shared_kwargs)
        fakequant_out = quantization.quantized_dot(
            quant_type=QuantType.fake_quant, **shared_kwargs)

        onp.testing.assert_allclose(
            aqt_out,
            fakequant_out,
            rtol=1e-2,
            err_msg='AQT and fakequant significantly disagree')
    def test_quantized_dynamic_dot_general_should_call_inputs_quantization(
            self,
            mock_act_fq,
            lhs_act_prec,
            rhs_act_prec,
            strategy=QuantType.fake_quant):
        """Activation quantization is invoked once per operand that has hparams.

        The activation fake-quant function is mocked to return its input
        unchanged. After running the dynamic dot_general, the recorded calls
        must match, in any order, one expected call per precision that is not
        None.

        Args:
          mock_act_fq: Mock replacing the activation quantization function.
          lhs_act_prec: Left-hand precision; a falsy value disables its hparams.
          rhs_act_prec: Right-hand precision; a falsy value disables its
            hparams.
          strategy: Quantization type passed to the dot_general.
        """
        def passthrough(inputs, hparams, get_bounds_params):
            return inputs

        mock_act_fq.side_effect = passthrough

        def hparams_or_none(prec):
            if not prec:
                return None
            return QuantOps.ActHParams(
                bounds=6.,
                prec=prec,
                input_distribution=QuantOps.ActHParams.InputDistribution.
                symmetric,
                half_shift=False)

        bounds_params = GetBounds.Params(update_stats=False,
                                         update_bounds=False)

        quantization.quantized_dynamic_dot_general(
            lhs_act=self.lhs_act,
            rhs_act=self.rhs_act,
            quant_type=strategy,
            dot_dimension_numbers=self.dimension_numbers,
            lhs_act_hparams=hparams_or_none(lhs_act_prec),
            lhs_get_bounds_params=bounds_params,
            rhs_act_hparams=hparams_or_none(rhs_act_prec),
            rhs_get_bounds_params=bounds_params,
        )

        # The input distribution is matched loosely via mock.ANY.
        expected_calls = [
            mock.call(mock.ANY,
                      hparams=QuantOps.ActHParams(bounds=6.,
                                                  prec=prec,
                                                  input_distribution=mock.ANY,
                                                  half_shift=False),
                      get_bounds_params=bounds_params)
            for prec in (lhs_act_prec, rhs_act_prec)
            if prec is not None
        ]
        self.assertLen(expected_calls, mock_act_fq.call_count)
        mock_act_fq.assert_has_calls(expected_calls, any_order=True)
 def test_dynamic_quantized_dot_general_raises_with_mixed_dtype(self):
   """A bfloat16 lhs combined with a float32 rhs must raise TypeError."""

   def symmetric_hparams(bounds):
     return QuantOps.ActHParams(
         input_distribution='symmetric',
         bounds=bounds,
         prec=4,
         half_shift=False)

   call_kwargs = dict(
       lhs_act=self.lhs.astype(jnp.bfloat16),
       rhs_act=self.rhs.astype(jnp.float32),
       lhs_act_hparams=symmetric_hparams(2.0),
       rhs_act_hparams=symmetric_hparams(1.5),
       lhs_get_bounds_params=None,
       rhs_get_bounds_params=None,
       dot_dimension_numbers=(((1,), (0,)), ((), ())),
       quant_type=QuantType.aqt)

   # Only the dot_general call itself is expected to raise.
   with self.assertRaises(TypeError):
     quantization.quantized_dynamic_dot_general(**call_kwargs)
  def test_attributes_create_acts_op_fp(
      self,
      act_distribution,
      use_hparams_bounds,
  ):
    """Floating-point activation quantization matches `downcast_sat_ftz`.

    Builds an activation quant op with a scaled `FloatQuant` precision
    (exp_min=-15, exp_max=15, sig_bits=2) inside a small Flax module, then
    compares `to_quantized` against manually scaling the inputs by the op's
    `_scale` and downcasting them with `fp_cast.downcast_sat_ftz`.

    Args:
      act_distribution: Input distribution for the activation hparams.
      use_hparams_bounds: If true, bounds are a `GetBounds.Hyper`; otherwise
        the constant 6.0.
    """
    inputs = jnp.array(fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))

    float_prec = QuantOps.FloatQuant.FloatPrec(
        exp_min=-15,
        exp_max=15,
        sig_bits=2,
    )
    fp_quant = QuantOps.FloatQuant(is_scaled=True, fp_spec=float_prec)

    bounds = 6.0
    if use_hparams_bounds:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=1,
          absdev_coeff=0,
          mix_coeff=1,
          reset_stats=True,
          ema_coeff=None,
          use_cams=False,
          granularity=quant_config.QuantGranularity.per_tensor)

    hparams = QuantOps.ActHParams(
        input_distribution=act_distribution,
        bounds=bounds,
        prec=fp_quant,
        half_shift=False)

    class TestModule(nn.Module):
      hparams: QuantOps.ActHParams

      @nn.compact
      def __call__(self, inputs):
        # NOTE(review): this reads the enclosing test's `hparams`, not
        # `self.hparams`; the same object is passed to the constructor below.
        frozen_params = GetBounds.Params(
            update_stats=False, update_bounds=False)
        return QuantOps.create_input_ops(
            inputs, hparams=hparams, get_bounds_params=frozen_params)

    module = TestModule(hparams=hparams)
    variables = module.init(jax.random.PRNGKey(0), inputs=inputs)
    quant_op = module.apply(variables, inputs=inputs)

    # Reference path: scale by hand, then downcast with the same float spec.
    spec = fp_quant.fp_spec
    scaled_inputs = (inputs * quant_op._scale).astype(inputs.dtype)
    expected = fp_cast.downcast_sat_ftz(
        scaled_inputs,
        spec.exp_min,
        spec.exp_max,
        spec.sig_bits,
    )
    actual = quant_op.to_quantized(inputs, dtype=SCALE_DTYPE)
    onp.testing.assert_array_equal(expected, actual)
 def test_dynamic_quantized_dot_general_has_correct_dtype(
         self, input_dtype, act_prec, quant_type):
     """The dynamic dot_general output has the same dtype as its inputs.

     Args:
       input_dtype: Dtype both operands are cast to before the call.
       act_prec: Precision used for both activation hparams.
       quant_type: Quantization type under test.
     """
     def symmetric_hparams(bounds):
         return QuantOps.ActHParams(input_distribution='symmetric',
                                    bounds=bounds,
                                    prec=act_prec)

     result = quantization.quantized_dynamic_dot_general(
         lhs_act=self.lhs.astype(input_dtype),
         rhs_act=self.rhs.astype(input_dtype),
         lhs_act_hparams=symmetric_hparams(2.0),
         rhs_act_hparams=symmetric_hparams(1.5),
         lhs_get_bounds_params=None,
         rhs_get_bounds_params=None,
         dot_dimension_numbers=(((1, ), (0, )), ((), ())),
         quant_type=quant_type)

     self.assertEqual(result.dtype, input_dtype)
 def test_quantized_dot_raises_with_mixed_dtype(self, quant_type):
     """bfloat16 activations with float32 weights must raise TypeError.

     Args:
       quant_type: Quantization type under test.
     """
     call_kwargs = dict(
         w=self.rhs.astype(jnp.float32),
         act=self.lhs.astype(jnp.bfloat16),
         weight_params=QuantOps.WeightParams(prec=4, axis=(0, )),
         act_hparams=QuantOps.ActHParams(input_distribution='symmetric',
                                         bounds=jnp.array([[3.0, 1.5]]),
                                         prec=4),
         get_bounds_params=None,
         quant_type=quant_type,
         prefer_int8_to_int32_dot=True)

     # Only the quantized_dot call itself is expected to raise.
     with self.assertRaises(TypeError):
         quantization.quantized_dot(**call_kwargs)
 def test_quantized_dot_no_quant(self):
     """With activation bounds of -1.0, `quantized_dot` returns act * w.

     NOTE(review): presumably a negative bound disables quantization; the
     test only asserts that the result is close to the plain product.
     """
     activation = jnp.array([[-5.0]])
     weight = jnp.array([[-4.99]])

     product = quantization.quantized_dot(
         w=weight,
         act=activation,
         quant_type=quantization.QuantType.aqt,
         weight_params=QuantOps.WeightParams(prec=4, axis=(0, )),
         act_hparams=QuantOps.ActHParams(input_distribution='symmetric',
                                         bounds=-1.0,
                                         prec=4),
         get_bounds_params=None,
         prefer_int8_to_int32_dot=True)

     onp.testing.assert_allclose(product, activation * weight)
 def test_quantized_dynamic_dot_general_no_quant(self):
   """With bounds of -1.0 on both sides, the result equals lhs * rhs.

   NOTE(review): presumably a negative bound disables quantization; the test
   only asserts that the result is close to the plain product.
   """
   shared_hparams = QuantOps.ActHParams(
       input_distribution='symmetric', bounds=-1.0, prec=4, half_shift=False)
   left = jnp.array([[-5.0]])
   right = jnp.array([[-4.99]])

   product = quantization.quantized_dynamic_dot_general(
       lhs_act=left,
       rhs_act=right,
       quant_type=quantization.QuantType.aqt,
       lhs_act_hparams=shared_hparams,
       rhs_act_hparams=shared_hparams,
       lhs_get_bounds_params=None,
       rhs_get_bounds_params=None,
       dot_dimension_numbers=(((1,), (0,)), ((), ())))

   onp.testing.assert_allclose(product, left * right)
 def test_quantized_dot_has_correct_dtype(self, input_dtype, act_prec,
                                          quant_type):
     """The `quantized_dot` output has the same dtype as its inputs.

     Args:
       input_dtype: Dtype that both activations and weights are cast to.
       act_prec: Activation precision.
       quant_type: Quantization type under test.
     """
     result = quantization.quantized_dot(
         w=self.rhs.astype(input_dtype),
         act=self.lhs.astype(input_dtype),
         weight_params=QuantOps.WeightParams(prec=4, axis=(0, )),
         act_hparams=QuantOps.ActHParams(input_distribution='symmetric',
                                         bounds=jnp.array([[3.0, 1.5]]),
                                         prec=act_prec),
         get_bounds_params=None,
         quant_type=quant_type,
         prefer_int8_to_int32_dot=True)

     self.assertEqual(result.dtype, input_dtype)
    def test_int_symmetric_act_quantization(self, prec):
        """Integer activations within the bounds pass through unchanged.

        The bounds are set to 2^(prec - 1) - 1 and the activations are random
        integers drawn between -bounds and bounds, so symmetric fake
        quantization is expected to reproduce them exactly.

        Args:
          prec: Precision of the activation quantization.
        """
        max_abs = 2**(prec - 1) - 1
        int_acts = random.randint(random.PRNGKey(0), (10, 1), -max_abs,
                                  max_abs)

        frozen_params = GetBounds.Params(update_stats=False,
                                         update_bounds=False)
        symmetric = QuantOps.ActHParams.InputDistribution.symmetric
        round_tripped = QuantOps.create_inputs_fake_quant(
            inputs=int_acts,
            get_bounds_params=frozen_params,
            hparams=QuantOps.ActHParams(input_distribution=symmetric,
                                        bounds=max_abs,
                                        prec=prec))

        onp.testing.assert_array_equal(int_acts, round_tripped)
    def test_int_positive_act_quantization(self, prec):
        """Non-negative integers below a power-of-2 bound pass through unchanged.

        The upper bound is 2^(prec - 3) and the activations are random
        integers drawn between 0 and that bound, so positive-distribution fake
        quantization is expected to reproduce them exactly.

        Args:
          prec: Precision of the activation quantization.
        """
        ceiling = 2**(prec - 3)
        int_acts = random.randint(random.PRNGKey(0), (10, 1), 0, ceiling)

        frozen_params = GetBounds.Params(update_stats=False,
                                         update_bounds=False)
        positive = QuantOps.ActHParams.InputDistribution.positive
        round_tripped = QuantOps.create_inputs_fake_quant(
            inputs=int_acts,
            get_bounds_params=frozen_params,
            hparams=QuantOps.ActHParams(input_distribution=positive,
                                        bounds=ceiling,
                                        prec=prec))

        onp.testing.assert_array_equal(int_acts, round_tripped)
    def test_inputs_scale_shape_is_expected(self):
        """Fake-quantizes inputs while declaring a scalar expected bounds shape.

        The test makes no assertion of its own: it passes
        `expected_bounds_shape=()` to `GetBounds.Params` and only requires the
        call to complete. NOTE(review): presumably the quantization code
        validates the shape internally; confirm against `QuantOps`.
        """
        scalar_shape = ()
        random_inputs = jnp.array(
            fp32(2.0 * onp.random.uniform(0, 1.0, size=(10, 4))))

        act_hparams = QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            bounds=6.0,
            prec=8.0)
        bounds_params = GetBounds.Params(update_stats=False,
                                         update_bounds=False,
                                         expected_bounds_shape=scalar_shape)

        QuantOps.create_inputs_fake_quant(inputs=random_inputs,
                                          hparams=act_hparams,
                                          get_bounds_params=bounds_params)
 def test_lax_dot_has_integer_inputs_in_quantized_dot(
         self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
         prec):
     """The mocked dot_general receives integer operands of the expected dtype.

     Both operands must be integer-valued and in range for `prec`. Their dtype
     is expected to be int8 when `prefer_int8_to_int32_dot` is set, except for
     the combination of a 'positive' activation distribution with prec == 8;
     in every other case float32 is expected.

     Args:
       mock_dot_general: Mock standing in for the underlying dot_general.
       act_distribution: Input distribution of the activations.
       prefer_int8_to_int32_dot: Flag forwarded to `quantized_dot`.
       prec: Precision used for both weights and activations.
     """
     weight_params = QuantOps.WeightParams(prec=prec,
                                           axis=(0, ),
                                           half_shift=False)
     act_hparams = QuantOps.ActHParams(input_distribution=act_distribution,
                                       bounds=jnp.array([[3.0, 1.5]]),
                                       prec=prec,
                                       half_shift=False)
     # A 'positive' distribution is fed non-negative activations only.
     act = jnp.abs(self.lhs) if act_distribution == 'positive' else self.lhs

     # Jax always tries to compile the functions passed to `lax.cond` (used in
     # `dot_general_aqt`), even outside of a JITed context. During compilation
     # Jax expects 'dot_general' to return a tracer and throws an error when it
     # gets a mock back instead. Explicitly disabling JIT stops Jax from
     # compiling the arms of `lax.cond`, so the mock works fine.
     with jax.disable_jit():
         quantization.quantized_dot(
             w=self.rhs,
             act=act,
             weight_params=weight_params,
             act_hparams=act_hparams,
             get_bounds_params=None,
             quant_type=QuantType.aqt,
             prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)

     act_seen, weight_seen = mock_dot_general.call_args[0]
     for seen, distribution in ((act_seen, act_distribution),
                                (weight_seen, 'symmetric')):
         self.assert_is_integer_in_range(seen,
                                         prec=prec,
                                         distribution=distribution)

     positive_8bit = act_distribution == 'positive' and prec == 8
     if positive_8bit or not prefer_int8_to_int32_dot:
         wanted_dtype = jnp.float32
     else:
         wanted_dtype = jnp.int8
     self.assertEqual(act_seen.dtype, wanted_dtype)
     self.assertEqual(weight_seen.dtype, wanted_dtype)
  def test_quantized_dot_general_should_call_weights_and_inputs_quantization(
      self,
      mock_act_fq,
      mock_w_fq,
      weight_prec,
      act_prec,
      strategy=QuantType.fake_quant):
    """`quantized_dot` calls weight and activation quantization as expected.

    Both quantization functions are mocked to return their input unchanged.
    Weight quantization must be called with the weight params and the jax
    type of `strategy`; activation quantization must be called only when
    activation hparams were built (i.e. `act_prec` is truthy).

    Args:
      mock_act_fq: Mock replacing the activation quantization function.
      mock_w_fq: Mock replacing the weight quantization function.
      weight_prec: Weight precision.
      act_prec: Activation precision; a falsy value disables activation
        hparams.
      strategy: Quantization type passed to `quantized_dot`.
    """
    def passthrough(inputs, **_):
      return inputs

    mock_w_fq.side_effect = passthrough
    mock_act_fq.side_effect = passthrough

    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=None, half_shift=False)
    act_hparams = None
    if act_prec:
      act_hparams = QuantOps.ActHParams(
          bounds=6.,
          prec=act_prec,
          input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
          half_shift=False)
    get_bounds_params = GetBounds.Params(
        update_stats=False, update_bounds=False)

    quantization.quantized_dot(
        w=self.weight,
        act=self.act,
        quant_type=strategy,
        weight_params=weight_params,
        act_hparams=act_hparams,
        get_bounds_params=get_bounds_params,
        prefer_int8_to_int32_dot=True)

    mock_w_fq.assert_called_with(
        mock.ANY,
        weight_params=weight_params,
        quantized_type=strategy.to_jax_type(),
        fake_dependency=mock.ANY)
    if not act_hparams:
      mock_act_fq.assert_not_called()
    else:
      mock_act_fq.assert_called_with(
          mock.ANY, hparams=act_hparams, get_bounds_params=get_bounds_params)
  def test_quantized_dot_general_aqt(self, act_bounds, weight_prec,
                                     weight_axis):
    """AQT and fake-quant `quantized_dot_general` agree on a rank-3 input.

    A random (4, 3, 2) activation is contracted over its last axis with axis 0
    of `self.rhs`. The AQT result must have shape (4, 3, 4) and match the
    fake-quant result within rtol=1e-2; with a high enough precision the two
    strategies are expected to give very similar results.

    Args:
      act_bounds: Activation bounds, or None to leave activations
        unquantized.
      weight_prec: Weight precision.
      weight_axis: Axis argument for the weight quantization parameters.
    """
    weight_params = QuantOps.WeightParams(
        prec=weight_prec, axis=weight_axis, half_shift=False)

    act_hparams = None
    if act_bounds is not None:
      act_hparams = QuantOps.ActHParams(
          input_distribution='symmetric',
          bounds=jnp.array(act_bounds),
          prec=16,
          half_shift=False)

    rank3_act = jnp.array(
        fp32(2.0 * onp.random.uniform(0, 1.0, size=(4, 3, 2))))

    # Everything except quant_type is identical between the two runs.
    shared_kwargs = dict(
        w=self.rhs,
        act=rank3_act,
        weight_params=weight_params,
        act_hparams=act_hparams,
        get_bounds_params=None,
        dimension_numbers=(((rank3_act.ndim - 1,), (0,)), ((), ())),
        prefer_int8_to_int32_dot=True)

    aqt_out = quantization.quantized_dot_general(
        quant_type=QuantType.aqt, **shared_kwargs)
    self.assertEqual(aqt_out.shape, (4, 3, 4))

    fakequant_out = quantization.quantized_dot_general(
        quant_type=QuantType.fake_quant, **shared_kwargs)
    onp.testing.assert_allclose(
        aqt_out,
        fakequant_out,
        rtol=1e-2,
        err_msg='AQT and fakequant significantly disagree')
Beispiel #20
0
class ComputeCostUtilsTest(parameterized.TestCase):
    def setUp(self):
        """Runs the base-class setUp and stores a PRNG key built from seed 0."""
        super().setUp()
        self.rng_key = random.PRNGKey(0)

    def compare_hlo_instructions(self, hlo_no_annotation, hlo_w_annotation):
        """Asserts that two HLO models differ only in instruction metadata.

        Instructions are gathered from every computation of each model and
        compared pairwise in order. For a pair where both opcodes are
        'convolution', the metadata must differ. The `op_type` and `op_name`
        metadata fields are then blanked and the pair must compare equal.

        Note that blanking happens in place, so both models are mutated.

        Args:
          hlo_no_annotation: HLO model without annotations.
          hlo_w_annotation: HLO model with annotations.
        """
        def flatten(hlo):
            return [
                instr for computation in hlo.computations
                for instr in computation.instructions
            ]

        plain_instrs = flatten(hlo_no_annotation)
        annotated_instrs = flatten(hlo_w_annotation)
        self.assertEqual(len(plain_instrs), len(annotated_instrs))

        for plain, annotated in zip(plain_instrs, annotated_instrs):
            both_conv = (plain.opcode == 'convolution'
                         and annotated.opcode == 'convolution')
            if both_conv:
                # Convolutions are where the two models' metadata must differ.
                self.assertNotEqual(plain.metadata, annotated.metadata)

            # Blank the metadata fields that are allowed to differ.
            for instr in (plain, annotated):
                instr.metadata.op_type = ''
                instr.metadata.op_name = ''
            # Whatever remains must be identical.
            self.assertEqual(plain, annotated)

    class TestModelWith1Dense(nn.Module):
        """Minimal test model consisting of exactly one DenseAqt layer."""
        @nn.compact
        def __call__(self, inputs, hparams, num_classes, dtype=jnp.float32):
            """Applies one DenseAqt layer with `num_classes` features.

            The layer is built with train=False and a quant context that
            disables bounds updates and activation-statistics collection; it
            is called without a padding mask.
            """
            frozen_context = quant_config.QuantContext(
                update_bounds=False, collect_acts_stats=False)
            dense_layer = aqt_flax_layers.DenseAqt(
                features=num_classes,
                dtype=dtype,
                train=False,
                quant_context=frozen_context,
                paxis_name='batch',
                hparams=hparams,
            )
            return dense_layer(inputs, padding_mask=None)

    class TestModelWith1Conv(nn.Module):
        """Minimal test model consisting of exactly one ConvAqt layer."""
        @nn.compact
        def __call__(self,
                     inputs,
                     hparams,
                     kernel_size,
                     num_filters,
                     strides,
                     dtype=jnp.float32):
            """Applies one bias-free ConvAqt layer with `num_filters` features.

            The layer is built with train=False and a quant context that
            disables bounds updates.
            """
            frozen_context = quant_config.QuantContext(update_bounds=False)
            conv_layer = aqt_flax_layers.ConvAqt(
                features=num_filters,
                kernel_size=kernel_size,
                strides=strides,
                use_bias=False,
                dtype=dtype,
                train=False,
                quant_context=frozen_context,
                paxis_name='batch',
                hparams=hparams)
            return conv_layer(inputs)

    class TestModelWith1DynamicMatmul(nn.Module):
        """Minimal test model consisting of exactly one dynamic matmul."""
        @nn.compact
        def __call__(self, lhs_act, rhs_act, lhs_prec, rhs_prec):
            """Multiplies two activations with an AQT dynamic dot_general.

            Both operands use symmetric quantization with bounds taken from a
            shared `GetBounds.Hyper` (initial bound 10.0, all coefficients
            zero, per-tensor granularity); statistics and bounds stay frozen.
            """
            shared_bounds = get_bounds.GetBounds.Hyper(
                initial_bound=10.0,
                stddev_coeff=0,
                absdev_coeff=0,
                mix_coeff=0,
                granularity=quant_config.QuantGranularity.per_tensor)

            def make_hparams(prec):
                return QuantOps.ActHParams(input_distribution='symmetric',
                                           bounds=shared_bounds,
                                           prec=prec,
                                           half_shift=False)

            def make_bounds_params(module_name):
                return get_bounds.GetBounds.Params(update_stats=False,
                                                   update_bounds=False,
                                                   module_name=module_name)

            return quantization.quantized_dynamic_dot_general(
                lhs_act=lhs_act,
                rhs_act=rhs_act,
                lhs_act_hparams=make_hparams(lhs_prec),
                rhs_act_hparams=make_hparams(rhs_prec),
                dot_dimension_numbers=(((1, ), (0, )), ((), ())),
                quant_type=QuantType.aqt,
                lhs_get_bounds_params=make_bounds_params('lhs'),
                rhs_get_bounds_params=make_bounds_params('rhs'))

    @parameterized.named_parameters(
        # TestModelWith1Dense
        dict(
            testcase_name='single_dense_layer_bfloat16',
            modelclass=TestModelWith1Dense,
            input_shapes=[(1, 8)],
            model_kwargs={
                'num_classes': 2,
                'hparams': aqt_flax_layers.DenseAqt.HParams(
                    weight_prec=None,
                    quant_type=QuantType.fake_quant,
                    quant_act=None,
                    weight_quant_granularity=quant_config.QuantGranularity.per_channel,
                    weight_half_shift=False
                ),
            },
            expected_compute_cost=8 * 2 * (16 * 16),
            expected_compute_cost_ratio=1.0,
            expected_compute_cost_linear=8 * 2 * (16),
            expected_compute_cost_ratio_linear=1.0,
            expected_memory_cost=8 * 2 * (16),
            expected_memory_cost_ratio=1.0,
        ),
        dict(
            testcase_name='single_dense_layer_w8_a8',
            modelclass=TestModelWith1Dense,
            input_shapes=[(1, 8)],
            model_kwargs={
                'num_classes': 2,
                'hparams': aqt_flax_layers.DenseAqt.HParams(
                    weight_prec=8,
                    quant_type=QuantType.fake_quant,
                    quant_act=QuantOps.ActHParams(
                        input_distribution=QuantOps.ActHParams.InputDistribution.positive,
                        prec=8,
                        bounds=1.0,
                        half_shift=False,
                    ),
                    weight_quant_granularity=quant_config.QuantGranularity.per_channel,
                    weight_half_shift=False
                ),
            },
            expected_compute_cost=8 * 2 * (8 * 8),
            expected_compute_cost_ratio=0.25,
            expected_compute_cost_linear=8 * 2 * (8),
            expected_compute_cost_ratio_linear=0.5,
            expected_memory_cost=8 * 2 * (8),
            expected_memory_cost_ratio=0.5,
        ),

        # TestModelWith1Conv
        dict(
            testcase_name='single_conv_layer_bfloat16',
            modelclass=TestModelWith1Conv,
            input_shapes=[(1, 8, 8, 3)],
            model_kwargs={
                'kernel_size': (3, 3),
                'num_filters': 16,
                'strides': (1, 1),
                'hparams': aqt_flax_layers.ConvAqt.HParams(
                    weight_prec=None,
                    quant_type=QuantType.fake_quant,
                    quant_act=None,
                    weight_half_shift=False,
                ),
            },
            expected_compute_cost=(3 * 3) * (8 * 8) * 3 * 16 * (16 * 16),
            expected_compute_cost_ratio=1.0,
            expected_compute_cost_linear=(3 * 3) * (8 * 8) * 3 * 16 * (16),
            expected_compute_cost_ratio_linear=1.0,
            expected_memory_cost=(3 * 3) * 3 * 16 * (16),
            expected_memory_cost_ratio=1.0,
        ),
        dict(
            testcase_name='single_conv_layer_bfloat16_strided',
            modelclass=TestModelWith1Conv,
            input_shapes=[(1, 8, 8, 3)],
            model_kwargs={
                'kernel_size': (3, 3),
                'num_filters': 16,
                'strides': (4, 2),
                'hparams': aqt_flax_layers.ConvAqt.HParams(
                    weight_prec=None,
                    quant_type=QuantType.fake_quant,
                    quant_act=None,
                    weight_half_shift=False,
                ),
            },
            expected_compute_cost=(3 * 3) * ((8 / 4) * (8 / 2)) * 3 * 16 * (16 * 16),
            expected_compute_cost_ratio=1.0,
            expected_compute_cost_linear=(3 * 3) * ((8 / 4) * (8 / 2)) * 3 * 16 * (16),
            expected_compute_cost_ratio_linear=1.0,
            expected_memory_cost=(3 * 3) * 3 * 16 * (16),
            expected_memory_cost_ratio=1.0,
        ),
        dict(
            testcase_name='single_conv_layer_bfloat16_3d',
            modelclass=TestModelWith1Conv,
            input_shapes=[(1, 8, 8, 8, 3)],
            model_kwargs={
                'kernel_size': (3, 3, 3),
                'num_filters': 16,
                'strides': (1, 1, 1),
                'hparams': aqt_flax_layers.ConvAqt.HParams(
                    weight_prec=None,
                    quant_type=QuantType.fake_quant,
                    quant_act=None,
                    weight_half_shift=False,
                ),
            },
            expected_compute_cost=(3 * 3 * 3) * (8 * 8 * 8) * 3 * 16 * (16 * 16),
            expected_compute_cost_ratio=1.0,
            expected_compute_cost_linear=(3 * 3 * 3) * (8 * 8 * 8) * 3 * 16 * (16),
            expected_compute_cost_ratio_linear=1.0,
            expected_memory_cost=(3 * 3 * 3) * 3 * 16 * (16),
            expected_memory_cost_ratio=1.0,
        ),
        dict(
            testcase_name='single_conv_layer_w4_a2',
            modelclass=TestModelWith1Conv,
            input_shapes=[(1, 8, 8, 3)],
            model_kwargs={
                'kernel_size': (3, 3),
                'num_filters': 16,
                'strides': (1, 1),
                'hparams': aqt_flax_layers.ConvAqt.HParams(
                    weight_prec=4,
                    quant_type=QuantType.fake_quant,
                    quant_act=QuantOps.ActHParams(
                        input_distribution=QuantOps.ActHParams.InputDistribution.positive,
                        prec=2,
                        bounds=1.0,
                        half_shift=False,
                    ),
                    weight_half_shift=False,
                ),
            },
            expected_compute_cost=(3 * 3) * (8 * 8) * 3 * 16 * (4 * 2),
            expected_compute_cost_ratio=0.03125,
            expected_compute_cost_linear=(3 * 3) * (8 * 8) * 3 * 16 * (4),
            expected_compute_cost_ratio_linear=0.25,
            expected_memory_cost=(3 * 3) * 3 * 16 * (4),
            expected_memory_cost_ratio=0.25,
        ),
        # TestModelWith1DynamicMatmul
        dict(
            testcase_name='single_dynamic_matmul_layer_bfloat16',
            modelclass=TestModelWith1DynamicMatmul,
            input_shapes=[(1, 8), (8, 1)],
            model_kwargs={'lhs_prec': None,
                          'rhs_prec': None},
            expected_compute_cost=8 * (16 * 16),
            expected_compute_cost_ratio=1.0,
            expected_compute_cost_linear=8 * (16),
            expected_compute_cost_ratio_linear=1.0,
            expected_memory_cost=0,
            expected_memory_cost_ratio=1.0,
        ),
        dict(
            testcase_name='single_dynamic_matmul_layer_l8_r8',
            modelclass=TestModelWith1DynamicMatmul,
            input_shapes=[(1, 8), (8, 1)],
            model_kwargs={'lhs_prec': 8,
                          'rhs_prec': 8},
            expected_compute_cost=8 * (8 * 8),
            expected_compute_cost_ratio=0.25,
            expected_compute_cost_linear=8 * 8,
            expected_compute_cost_ratio_linear=0.5,
            expected_memory_cost=0,
            expected_memory_cost_ratio=1.0,
        ),
        dict(
            testcase_name='single_dynamic_matmul_layer_l8_r4',
            modelclass=TestModelWith1DynamicMatmul,
            input_shapes=[(1, 8), (8, 1)],
            model_kwargs={'lhs_prec': 8,
                          'rhs_prec': 4},
            expected_compute_cost=8 * (8 * 4),
            expected_compute_cost_ratio=0.125,
            expected_compute_cost_linear=8 * (8),
            expected_compute_cost_ratio_linear=0.5,
            expected_memory_cost=0,
            expected_memory_cost_ratio=1.0,
        ),
    )  # pylint: disable=line-too-long
    def test_estimate_simple_model_cost(
            self, modelclass, input_shapes, model_kwargs,
            expected_compute_cost, expected_compute_cost_ratio,
            expected_compute_cost_linear, expected_compute_cost_ratio_linear,
            expected_memory_cost, expected_memory_cost_ratio):
        """Checks the estimated compute and memory costs of a small model.

        The model is initialised on all-ones float32 inputs, an HLO proto is
        obtained through `hlo_utils.load_hlo_proto_from_model`, and the
        estimates computed from that proto are compared with the expected
        absolute costs and their ratios to bfloat16.
        """
        model = modelclass()
        # One all-ones float32 array per requested input shape.
        example_inputs = [
            jnp.ones(shape, dtype=jnp.float32) for shape in input_shapes
        ]
        variables = model.init(
            random.PRNGKey(0), *example_inputs, **model_kwargs)

        hlo = hlo_utils.load_hlo_proto_from_model(
            model, variables, input_shapes, **model_kwargs)
        compute_stats = compute_cost_utils.estimate_compute_cost(hlo)
        memory_stats = compute_cost_utils.estimate_memory_cost(hlo)
        logging.info('compute cost result is %s', compute_stats)
        logging.info('memory cost result is %s', memory_stats)

        # Absolute costs.
        self.assertEqual(compute_stats['compute_cost'], expected_compute_cost)
        self.assertEqual(memory_stats['memory_cost'], expected_memory_cost)
        # Costs relative to bfloat16.
        self.assertEqual(
            compute_stats['compute_cost_ratio_to_bfloat16'],
            expected_compute_cost_ratio)
        self.assertEqual(
            memory_stats['memory_cost_ratio_to_bfloat16'],
            expected_memory_cost_ratio)
        # "Linear" variants of the compute cost.
        self.assertEqual(
            compute_stats['compute_cost_linear'],
            expected_compute_cost_linear)
        self.assertEqual(
            compute_stats['compute_cost_ratio_to_bfloat16_linear'],
            expected_compute_cost_ratio_linear)

    @parameterized.named_parameters(
        # TestModelWith1Dense
        dict(
            testcase_name='single_dense_layer_bfloat16_batch_size',
            modelclass=TestModelWith1Dense,
            input_shape_per_sample=(16, ),
            model_kwargs={
                'num_classes':
                20,
                'hparams':
                aqt_flax_layers.DenseAqt.HParams(
                    weight_prec=None,
                    quant_act=None,
                    quant_type=QuantType.fake_quant,
                    weight_quant_granularity=quant_config.QuantGranularity.
                    per_channel,
                    weight_half_shift=False)
            },
        ),
        # TestModelWith1Conv
        dict(
            testcase_name='single_conv_layer_bfloat16_batch_size',
            modelclass=TestModelWith1Conv,
            input_shape_per_sample=(16, 16, 3),
            model_kwargs={
                'kernel_size': (3, 3),
                'num_filters':
                16,
                'strides': (2, 2),
                'hparams':
                aqt_flax_layers.ConvAqt.HParams(
                    weight_prec=None,
                    quant_act=None,
                    quant_type=QuantType.fake_quant,
                    weight_half_shift=False,
                )
            },
        ),
    )
    def test_batch_size_has_no_effect_on_cost(self, modelclass,
                                              input_shape_per_sample,
                                              model_kwargs):
        """Checks that the cost estimates are identical across batch sizes.

        The costs obtained for the first batch size act as the baseline that
        every later batch size has to reproduce exactly.
        """
        baseline_compute = None
        baseline_memory = None
        model = modelclass()

        for batch_size in (32, 64, 128, 256, 512, 1024):
            batched_shape = (batch_size, ) + input_shape_per_sample
            variables = model.init(random.PRNGKey(0),
                                   jnp.ones(batched_shape, jnp.float32),
                                   **model_kwargs)
            hlo = hlo_utils.load_hlo_proto_from_model(
                model, variables, [batched_shape], **model_kwargs)
            # The variables are no longer needed once the HLO proto exists.
            del variables
            compute_stats = compute_cost_utils.estimate_compute_cost(hlo)
            memory_stats = compute_cost_utils.estimate_memory_cost(hlo)

            # First iteration records the baseline; later ones compare to it.
            if baseline_compute is not None:
                self.assertEqual(compute_stats['compute_cost'],
                                 baseline_compute)
            else:
                baseline_compute = compute_stats['compute_cost']

            if baseline_memory is not None:
                self.assertEqual(memory_stats['memory_cost'],
                                 baseline_memory)
            else:
                baseline_memory = memory_stats['memory_cost']

    @parameterized.named_parameters(
        dict(testcase_name='quant_8bit', weight_prec=8),
        dict(testcase_name='quant_4bit', weight_prec=4),
    )
    def test_check_value_inside_and_outside_of_context_conv_general(
            self, weight_prec):
        """Checks the conv primitive name differs only inside the patch context."""
        default_name = 'conv_general_dilated'

        def primitive_name():
            # Looked up afresh on every call so the current value is observed.
            return lax_convolution.conv_general_dilated_p.name

        # Outside the context the primitive carries its default name.
        self.assertEqual(default_name, primitive_name())

        metadata_patch = compute_cost_utils.ConvMetadataMonkeyPatch(
            weight_prec=weight_prec, act_prec=None)
        with metadata_patch:
            # Inside the context the name is expected to be different.
            self.assertNotEqual(default_name, primitive_name())

        # Leaving the context must bring the default name back.
        self.assertEqual(default_name, primitive_name())

    @parameterized.named_parameters(
        dict(testcase_name='quant_8bit', weight_prec=8, acts_prec=8),
        dict(testcase_name='quant_4bit', weight_prec=4, acts_prec=4),
    )
    def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                       acts_prec):
        """Checks that `FLAGS.metadata_enabled` does not alter ConvAqt results.

        An identically configured `ConvAqt` module is built and run twice,
        first with `FLAGS.metadata_enabled` off and then with it on. The two
        outputs must be equal and the two HLO protos must pass
        `self.compare_hlo_instructions`.

        Args:
          weight_prec: weight precision passed to `ConvAqt.HParams`.
          acts_prec: activation precision passed to `QuantOps.ActHParams`.
        """
        # The flag is process-wide state. Restore the value it had on entry,
        # even when the test fails part-way, so that later tests never see a
        # stale setting.
        self.addCleanup(setattr, FLAGS, 'metadata_enabled',
                        FLAGS.metadata_enabled)

        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0,
            half_shift=False)
        input_shape = (1, 8, 8, 3)

        def build_and_run(metadata_enabled):
            """Runs a freshly built ConvAqt with the flag set as requested.

            Returns:
              A `(module output, HLO proto)` pair.
            """
            FLAGS.metadata_enabled = metadata_enabled
            module = aqt_flax_layers.ConvAqt(
                features=4,
                kernel_size=(3, 3),
                padding='VALID',
                paxis_name='batch',
                quant_context=quant_config.QuantContext(update_bounds=False),
                train=False,
                hparams=aqt_flax_layers.ConvAqt.HParams(
                    weight_prec=weight_prec,
                    quant_act=quant_act,
                    quant_type=QuantType.fake_quant,
                    weight_half_shift=False),
                kernel_init=initializers.ones,
                bias_init=initializers.ones,
                dtype=jnp.float32)
            init_state = module.init(
                self.rng_key, jnp.ones(input_shape, jnp.float32))
            output = module.apply(init_state, jnp.ones(input_shape))
            hlo = hlo_utils.load_hlo_proto_from_model(
                module, init_state, [input_shape])
            return output, hlo

        output_no_annotation, hlo_no_annotation = build_and_run(False)
        output_w_annotation, hlo_w_annotation = build_and_run(True)

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)

    @parameterized.named_parameters(
        dict(testcase_name='quant_8bit', weight_prec=8),
        dict(testcase_name='quant_4bit', weight_prec=4),
    )
    def test_check_value_inside_and_outside_of_context_dot_general(
            self, weight_prec):
        """Checks the dot primitive name differs only inside the patch context."""
        default_name = 'dot_general'

        def primitive_name():
            # Looked up afresh on every call so the current value is observed.
            return lax.dot_general_p.name

        # Outside the context the primitive carries its default name.
        self.assertEqual(default_name, primitive_name())

        metadata_patch = compute_cost_utils.DotMetadataMonkeyPatch(
            lhs_prec=None, rhs_prec=weight_prec, rhs_is_weight=True)
        with metadata_patch:
            # Inside the context the name is expected to be different.
            self.assertNotEqual(default_name, primitive_name())

        # Leaving the context must bring the default name back.
        self.assertEqual(default_name, primitive_name())

    @parameterized.named_parameters(
        dict(
            testcase_name='quant_8bit',
            weight_prec=8,
            acts_prec=8,
        ), )
    def test_annotation_only_changes_hlo_metadata_dense(
            self, weight_prec, acts_prec):
        """Checks that `FLAGS.metadata_enabled` does not alter DenseAqt results.

        An identically configured `DenseAqt` module is built and run twice,
        first with `FLAGS.metadata_enabled` off and then with it on. The two
        outputs must be equal and the two HLO protos must pass
        `self.compare_hlo_instructions`.

        Args:
          weight_prec: weight precision passed to `DenseAqt.HParams`.
          acts_prec: activation precision passed to `QuantOps.ActHParams`.
        """
        # The flag is process-wide state. Restore the value it had on entry,
        # even when the test fails part-way, so that later tests never see a
        # stale setting.
        self.addCleanup(setattr, FLAGS, 'metadata_enabled',
                        FLAGS.metadata_enabled)

        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0,
            half_shift=False)
        input_shape = (1, 16)

        def build_and_run(metadata_enabled):
            """Runs a freshly built DenseAqt with the flag set as requested.

            Returns:
              A `(module output, HLO proto)` pair.
            """
            FLAGS.metadata_enabled = metadata_enabled
            module = aqt_flax_layers.DenseAqt(
                features=4,
                use_bias=False,
                quant_context=quant_config.QuantContext(
                    update_bounds=False, collect_acts_stats=False),
                paxis_name='batch',
                train=False,
                hparams=aqt_flax_layers.DenseAqt.HParams(
                    weight_prec=weight_prec,
                    quant_act=quant_act,
                    quant_type=QuantType.fake_quant,
                    weight_quant_granularity=quant_config.QuantGranularity.
                    per_channel,
                    weight_half_shift=False),
                dtype=jnp.float32)
            init_state = module.init(self.rng_key,
                                     jnp.ones(input_shape, jnp.float32),
                                     padding_mask=None)
            output = module.apply(
                init_state, jnp.ones(input_shape), padding_mask=None)
            hlo = hlo_utils.load_hlo_proto_from_model(
                module, init_state, [input_shape], padding_mask=None)
            return output, hlo

        output_no_annotation, hlo_no_annotation = build_and_run(False)
        output_w_annotation, hlo_w_annotation = build_and_run(True)

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
# Example #21
# 0
class AttnActsMatmulQuantTest(parameterized.TestCase):
    def construct_hparams(self, attn_act_q, attn_act_k, attn_act_probs,
                          attn_act_v):
        """Builds attention hparams around the given activation hparams.

        Both dense sub-layers share one `DenseAqt.HParams` instance created
        with `weight_prec=None` and `quant_act=None`; the four arguments are
        forwarded unchanged into `DotProductAttnHParams`.

        Returns:
          A `MultiHeadDotProductAttentionAqt.HParams` instance.
        """
        dense_hparams = flax_layers.DenseAqt.HParams(
            weight_prec=None,
            quant_act=None,
            quant_type=QuantType.fake_quant,
            weight_quant_granularity=quant_config.QuantGranularity.per_channel,
            weight_half_shift=False)
        attn_acts_hparams = flax_attention.DotProductAttnHParams(
            attn_act_q=attn_act_q,
            attn_act_k=attn_act_k,
            attn_act_probs=attn_act_probs,
            attn_act_v=attn_act_v,
            quant_type=QuantType.fake_quant,
            softmax=SoftmaxHParams(None, None, None))
        return flax_attention.MultiHeadDotProductAttentionAqt.HParams(
            dense_kqv=dense_hparams,
            dense_out=dense_hparams,
            attn_acts=attn_acts_hparams)

    @parameterized.named_parameters(
        dict(testcase_name='float',
             attn_act_q=None,
             attn_act_k=None,
             attn_act_probs=None,
             attn_act_v=None,
             update_bounds=False,
             paxis_name=None,
             train=False),
        dict(testcase_name='quant_q',
             attn_act_q=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=8,
                 bounds=1,
                 half_shift=False),
             attn_act_k=None,
             attn_act_probs=None,
             attn_act_v=None,
             update_bounds=False,
             paxis_name='batch',
             train=True),
        dict(testcase_name='quant_qk',
             attn_act_q=None,
             attn_act_k=None,
             attn_act_probs=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=8,
                 bounds=1.0,
                 half_shift=False),
             attn_act_v=None,
             update_bounds=False,
             paxis_name='batch',
             train=True),
        dict(testcase_name='quant_k',
             attn_act_q=None,
             attn_act_k=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=4,
                 bounds=2,
                 half_shift=False),
             attn_act_probs=None,
             attn_act_v=None,
             update_bounds=False,
             paxis_name=None,
             train=True),
        dict(testcase_name='quant_v',
             attn_act_q=None,
             attn_act_k=None,
             attn_act_probs=None,
             attn_act_v=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=2,
                 bounds=3,
                 half_shift=False),
             update_bounds=True,
             paxis_name='batch',
             train=False),
        dict(testcase_name='quant_all_aa',
             attn_act_q=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=8,
                 bounds=1,
                 half_shift=False),
             attn_act_k=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=4,
                 bounds=2,
                 half_shift=False),
             attn_act_probs=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=8,
                 bounds=1.0,
                 half_shift=False),
             attn_act_v=QuantOps.ActHParams(
                 input_distribution=QuantOps.ActHParams.InputDistribution.
                 symmetric,
                 prec=2,
                 bounds=3,
                 half_shift=False),
             update_bounds=True,
             paxis_name=None,
             train=True),
    )
    @unittest.mock.patch.object(QuantOps, 'create_inputs_fake_quant')
    def test_self_attention_act_quant_should_call_quant_ops(
            self, mock_inputs_fake_quant, attn_act_q, attn_act_k,
            attn_act_probs, attn_act_v, update_bounds, paxis_name, train):
        """Checks `create_inputs_fake_quant` is called per configured hparams.

        `QuantOps.create_inputs_fake_quant` is mocked with a pass-through.
        After initialising `SelfAttentionAqt`, the mock must have received one
        matching call for every activation hparams argument that is not None,
        and its total call count must equal the number of those arguments.
        """

        def passthrough(inputs, hparams, get_bounds_params):
            # Stand-in for the fake-quant op: hands the inputs back untouched.
            del hparams, get_bounds_params  # Unused.
            return inputs

        mock_inputs_fake_quant.side_effect = passthrough

        init_rng = random.PRNGKey(0)
        inputs = jnp.ones((4, 3, 7))
        quant_context = quant_config.QuantContext(
            update_bounds=update_bounds, collect_acts_stats=False)
        attention = flax_attention.SelfAttentionAqt(
            hparams=self.construct_hparams(attn_act_q, attn_act_k,
                                           attn_act_probs, attn_act_v),
            num_heads=4,
            quant_context=quant_context,
            train=train,
            paxis_name=paxis_name,
            attention_axis=None,
            qkv_features=8,
            kernel_init=initializers.ones,
            bias_init=initializers.zeros,
            causal_mask=False,
            dtype=jnp.float32,
            dropout_rate=0.0,
            deterministic=False,
            decode=False)
        attention.init(init_rng, inputs, padding_mask=None)

        # One expected call per activation hparams that was actually supplied.
        expected_calls = [
            unittest.mock.call(
                unittest.mock.ANY,
                hparams=act_hparams,
                get_bounds_params=get_bounds.GetBounds.Params(
                    update_stats=train,
                    update_bounds=update_bounds,
                    paxis_name=paxis_name,
                    mask=unittest.mock.ANY,
                    module_name=unittest.mock.ANY))
            for act_hparams in (attn_act_q, attn_act_k, attn_act_probs,
                                attn_act_v)
            if act_hparams is not None
        ]
        mock_inputs_fake_quant.assert_has_calls(expected_calls, any_order=True)

        self.assertLen(expected_calls, mock_inputs_fake_quant.call_count)