def test_load_hlo_proto_from_jax_fn_and_count_ops(
    self, fn, fn_args, ops_regex, exp_count):
  # Lowers fn to an HLO proto and counts the ops matching ops_regex.
  hlo_proto = hlo_utils.load_hlo_proto_from_jax_fn(fn, *fn_args)
  count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=ops_regex)
  self.assertEqual(count, exp_count)
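As a rough standalone sketch of the same pattern, counting the 'dot' ops produced by a function with two matrix multiplies could look like the snippet below. The import path for hlo_utils is an assumption; use whichever module these tests import.

import jax.numpy as jnp
from aqt.jax import hlo_utils  # import path assumed; match the tests' own import

def two_matmuls(x, w1, w2):
  return jnp.dot(jnp.dot(x, w1), w2)

# Lower the function to HLO, then count instructions whose opcode matches 'dot'.
hlo_proto = hlo_utils.load_hlo_proto_from_jax_fn(
    two_matmuls, jnp.ones((1, 4)), jnp.ones((4, 4)), jnp.ones((4, 2)))
assert hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=r'dot') == 2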
Example #2
    def test_number_of_floor_ops_embedding(self, num_layers,
                                           embedding_weight_prec,
                                           logits_inputs_prec,
                                           logits_inputs_hyper_is_float,
                                           logits_via_embeddings):
        # Counts number of floor ops as a proxy for quantization ops.
        if logits_inputs_hyper_is_float:
            logits_inputs_hyper = 6.0
        else:
            logits_inputs_hyper = get_bounds.GetBounds.Hyper(
                initial_bound=6.0,
                stddev_coeff=3.0,
                absdev_coeff=2.0,
                mix_coeff=0.5,
                granularity=quant_config.QuantGranularity.per_tensor)

        hparams = training_hparams_generator_lib.create_base_transformer_hparams(
            mlp_weight_prec=None,
            embedding_weight_prec=embedding_weight_prec,
            attention_weight_prec=None,
            mlp_pos_inputs_prec=None,
            mlp_pos_inputs_hyper=None,
            mlp_signed_inputs_prec=None,
            mlp_signed_inputs_hyper=None,
            attention_kqv_inputs_prec=None,
            attention_kqv_inputs_hyper=None,
            attention_out_inputs_prec=None,
            attention_out_inputs_hyper=None,
            logits_inputs_prec=logits_inputs_prec,
            logits_inputs_hyper=logits_inputs_hyper,
            logits_via_embeddings=logits_via_embeddings,
            attention_act_q_inputs_prec=None,
            attention_act_q_inputs_hyper=None,
            attention_act_k_inputs_prec=None,
            attention_act_k_inputs_hyper=None,
            attention_act_probs_inputs_prec=None,
            attention_act_v_inputs_prec=None,
            attention_act_v_inputs_hyper=None,
            num_layers=num_layers,
            emb_dim=5,
            num_heads=8,
            qkv_dim=8,
            mlp_dim=7,
            quant_type=QuantType.fake_quant)

        transformer_kwargs = self.transformer_full_kwargs
        transformer_kwargs['hparams'] = hparams
        input_shape = (2, 4)
        target_shape = input_shape
        model, init_state = self.init_model(transformer_kwargs)
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            model, init_state, [input_shape, target_shape])
        floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')

        embedding_floor_ops = self._num_embedding_floors(
            (embedding_weight_prec is not None),
            (logits_inputs_prec is not None))

        self.assertEqual(floor_count, embedding_floor_ops)
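The "floor ops as a proxy for quantization ops" trick relies on fake quantization rounding scaled values, where a round written as floor(x + 0.5) lowers to a single HLO floor instruction. A minimal, hypothetical illustration (the import path and the exact quantization formula the library uses are assumptions):

import jax.numpy as jnp
from aqt.jax import hlo_utils  # import path assumed; match the tests' own import

def fake_quant(x, scale=127.0):
  # Round-to-nearest via floor; this emits exactly one HLO 'floor' op.
  return jnp.floor(x * scale + 0.5) / scale

hlo_proto = hlo_utils.load_hlo_proto_from_jax_fn(fake_quant, jnp.ones((2, 2)))
assert hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=r'floor') == 1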
Example #3
  def test_load_hlo_proto_from_model_and_count_ops(self):
    input_shapes = [(1, 2)]
    test_model = self.TestModelWith2DenseLayers()
    init_state = test_model.init(
        random.PRNGKey(0), *[jnp.ones(shape) for shape in input_shapes])

    hlo_proto = hlo_utils.load_hlo_proto_from_model(test_model, init_state,
                                                    input_shapes)
    count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=r'dot')
    self.assertEqual(count, 2)
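self.TestModelWith2DenseLayers is defined elsewhere in the test class; a minimal Flax module consistent with the expected count of two 'dot' ops might look like this (feature sizes are arbitrary):

import flax.linen as nn

class TestModelWith2DenseLayers(nn.Module):
  """Two dense layers, so lowering produces exactly two 'dot' instructions."""

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=3)(x)
    return nn.Dense(features=2)(x)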
Example #4
  def test_count_floor_ops(self, base_config_filename, expected_floor_count):
    hparams = hparams_utils.load_hparams_from_config_dict(
        hparams_config.TrainingHParams, models.ResNet.HParams,
        base_config_filename.get_config())
    input_shape = (32, 16, 16, 3)
    model, init_state = create_model(
        self.rng_key,
        input_shape[0],
        input_shape[1],
        jnp.float32,
        hparams.model_hparams,
        train=False)
    hlo_proto = hlo_utils.load_hlo_proto_from_model(
        model, init_state, [input_shape])
    floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')
    self.assertEqual(floor_count, expected_floor_count)
    # Cross-check against the floor count implied by the hparams.
    expected_floor_count_from_hparams = 0
    expected_floor_count_from_hparams += self._num_dense_floors(
        hparams.model_hparams)
    expected_floor_count_from_hparams += self._num_conv_floors(
        hparams.model_hparams)
    self.assertEqual(floor_count, expected_floor_count_from_hparams)
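All of these tests hinge on count_ops_in_hlo_proto matching HLO instruction opcodes against a regex. A rough sketch of what such a traversal could look like, assuming the proto is an XLA HloProto (the real helper may differ):

import re

def count_matching_ops(hlo_proto, ops_regex):
  """Counts HLO instructions whose opcode matches ops_regex."""
  count = 0
  for computation in hlo_proto.hlo_module.computations:
    for instruction in computation.instructions:
      if re.match(ops_regex, instruction.opcode):
        count += 1
  return count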
Example #5
    def test_number_of_floor_ops(
            self, num_layers, mlp_weight_prec, mlp_pos_inputs_prec,
            mlp_signed_inputs_prec, attention_kqv_inputs_prec,
            attention_out_inputs_prec, embedding_weight_prec,
            attention_weight_prec, logits_inputs_prec,
            attention_act_q_inputs_prec, attention_act_k_inputs_prec,
            attention_act_probs_inputs_prec, attention_act_v_inputs_prec):
        # Counts number of floor ops as a proxy for quantization ops.
        act_fixed_clip_bound = 3.0
        hparams = training_hparams_generator_lib.create_base_transformer_hparams(
            mlp_weight_prec=mlp_weight_prec,
            embedding_weight_prec=embedding_weight_prec,
            attention_weight_prec=attention_weight_prec,
            mlp_pos_inputs_prec=mlp_pos_inputs_prec,
            mlp_pos_inputs_hyper=act_fixed_clip_bound,
            mlp_signed_inputs_prec=mlp_signed_inputs_prec,
            mlp_signed_inputs_hyper=act_fixed_clip_bound,
            attention_kqv_inputs_prec=attention_kqv_inputs_prec,
            attention_kqv_inputs_hyper=act_fixed_clip_bound,
            attention_out_inputs_prec=attention_out_inputs_prec,
            attention_out_inputs_hyper=act_fixed_clip_bound,
            logits_inputs_prec=logits_inputs_prec,
            logits_inputs_hyper=act_fixed_clip_bound,
            logits_via_embeddings=True,
            attention_act_q_inputs_prec=attention_act_q_inputs_prec,
            attention_act_q_inputs_hyper=act_fixed_clip_bound,
            attention_act_k_inputs_prec=attention_act_k_inputs_prec,
            attention_act_k_inputs_hyper=act_fixed_clip_bound,
            attention_act_probs_inputs_prec=attention_act_probs_inputs_prec,
            attention_act_v_inputs_prec=attention_act_v_inputs_prec,
            attention_act_v_inputs_hyper=act_fixed_clip_bound,
            num_layers=num_layers,
            emb_dim=5,
            num_heads=8,
            qkv_dim=8,
            mlp_dim=7,
            quant_type=QuantType.fake_quant)

        transformer_kwargs = self.transformer_full_kwargs
        transformer_kwargs['hparams'] = hparams
        input_shape = (2, 4)
        target_shape = input_shape
        model, init_state = self.init_model(transformer_kwargs)
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            model, init_state, [input_shape, target_shape])
        floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')

        mlp_floors_per_layer = self._num_mlp_floors(
            (mlp_weight_prec is not None), (mlp_pos_inputs_prec is not None),
            (mlp_signed_inputs_prec is not None))

        attention_floors_per_layer = self._num_attention_floors(
            (attention_weight_prec is not None),
            (attention_kqv_inputs_prec is not None),
            (attention_out_inputs_prec is not None),
            (attention_act_q_inputs_prec is not None),
            (attention_act_k_inputs_prec is not None),
            (attention_act_probs_inputs_prec is not None),
            (attention_act_v_inputs_prec is not None))

        embedding_floors = self._num_embedding_floors(
            (embedding_weight_prec is not None),
            (logits_inputs_prec is not None))

        expected_floor_count = num_layers * (
            mlp_floors_per_layer +
            attention_floors_per_layer) + embedding_floors
        self.assertEqual(floor_count, expected_floor_count)