Example #1
    def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                       acts_prec):
        FLAGS.metadata_enabled = False
        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0)
        input_shape = (1, 8, 8, 3)
        module_no_annotation = aqt_flax_layers.ConvAqt(
            features=4,
            kernel_size=(3, 3),
            padding='VALID',
            paxis_name='batch',
            quant_context=quant_config.QuantContext(update_bounds=False),
            train=False,
            hparams=aqt_flax_layers.ConvAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
            dtype=jnp.float32)

        init_state = module_no_annotation.init(
            self.rng_key, jnp.ones(input_shape, jnp.float32))
        output_no_annotation = module_no_annotation.apply(
            init_state, jnp.ones(input_shape))

        hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
            module_no_annotation, init_state, [input_shape])
        del init_state

        FLAGS.metadata_enabled = True
        module_w_annotation = aqt_flax_layers.ConvAqt(
            features=4,
            kernel_size=(3, 3),
            padding='VALID',
            paxis_name='batch',
            quant_context=quant_config.QuantContext(update_bounds=False),
            train=False,
            hparams=aqt_flax_layers.ConvAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
            dtype=jnp.float32)

        init_state = module_w_annotation.init(
            self.rng_key, jnp.ones(input_shape, jnp.float32))
        output_w_annotation = module_w_annotation.apply(
            init_state, jnp.ones(input_shape))

        hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
            module_w_annotation, init_state, [input_shape])
        del init_state

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
Example #2
    def test_batch_size_has_no_effect_on_cost(self, modelclass,
                                              input_shape_per_sample,
                                              model_kwargs):
        expected_compute_cost = None
        expected_memory_cost = None
        batch_size_list = [32, 64, 128, 256, 512, 1024]

        module = modelclass()

        # Sweep over the batch size list
        for batch_size in batch_size_list:
            input_shape = (batch_size,) + input_shape_per_sample
            init_state = module.init(random.PRNGKey(0),
                                     jnp.ones(input_shape, jnp.float32),
                                     **model_kwargs)
            hlo_proto = hlo_utils.load_hlo_proto_from_model(
                module, init_state, [input_shape], **model_kwargs)
            del init_state
            compute_result = compute_cost_utils.estimate_compute_cost(
                hlo_proto)
            memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
            # Save the first cost and compare it with the rest
            if expected_compute_cost is None:
                expected_compute_cost = compute_result['compute_cost']
            else:
                self.assertEqual(compute_result['compute_cost'],
                                 expected_compute_cost)
            if expected_memory_cost is None:
                expected_memory_cost = memory_result['memory_cost']
            else:
                self.assertEqual(memory_result['memory_cost'],
                                 expected_memory_cost)
Example #3
    def test_estimate_simple_model_cost(
            self, modelclass, input_shapes, model_kwargs,
            expected_compute_cost, expected_compute_cost_ratio,
            expected_compute_cost_linear, expected_compute_cost_ratio_linear,
            expected_memory_cost, expected_memory_cost_ratio):
        module = modelclass()
        input_shapes_with_type = [(sh, jnp.float32) for sh in input_shapes]
        dummy_inputs = [
            jnp.ones(input_shape, dtype=dtype)
            for (input_shape, dtype) in input_shapes_with_type
        ]
        init_state = module.init(random.PRNGKey(0), *dummy_inputs,
                                 **model_kwargs)

        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            module, init_state, input_shapes, **model_kwargs)
        compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
        memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
        logging.info('compute cost result is %s', compute_result)
        logging.info('memory cost result is %s', memory_result)
        self.assertEqual(compute_result['compute_cost'], expected_compute_cost)
        self.assertEqual(memory_result['memory_cost'], expected_memory_cost)
        self.assertEqual(compute_result['compute_cost_ratio_to_bfloat16'],
                         expected_compute_cost_ratio)
        self.assertEqual(memory_result['memory_cost_ratio_to_bfloat16'],
                         expected_memory_cost_ratio)
        self.assertEqual(compute_result['compute_cost_linear'],
                         expected_compute_cost_linear)
        self.assertEqual(
            compute_result['compute_cost_ratio_to_bfloat16_linear'],
            expected_compute_cost_ratio_linear)
Example #4
    def test_number_of_floor_ops_embedding(self, num_layers,
                                           embedding_weight_prec,
                                           logits_inputs_prec,
                                           logits_inputs_hyper_is_float,
                                           logits_via_embeddings):
        # Counts the number of floor ops as a proxy for quantization ops.
        if logits_inputs_hyper_is_float:
            logits_inputs_hyper = 6.0
        else:
            logits_inputs_hyper = get_bounds.GetBounds.Hyper(
                initial_bound=6.0,
                stddev_coeff=3.0,
                absdev_coeff=2.0,
                mix_coeff=0.5,
                granularity=quant_config.QuantGranularity.per_tensor)

        hparams = training_hparams_generator_lib.create_base_transformer_hparams(
            mlp_weight_prec=None,
            embedding_weight_prec=embedding_weight_prec,
            attention_weight_prec=None,
            mlp_pos_inputs_prec=None,
            mlp_pos_inputs_hyper=None,
            mlp_signed_inputs_prec=None,
            mlp_signed_inputs_hyper=None,
            attention_kqv_inputs_prec=None,
            attention_kqv_inputs_hyper=None,
            attention_out_inputs_prec=None,
            attention_out_inputs_hyper=None,
            logits_inputs_prec=logits_inputs_prec,
            logits_inputs_hyper=logits_inputs_hyper,
            logits_via_embeddings=logits_via_embeddings,
            attention_act_q_inputs_prec=None,
            attention_act_q_inputs_hyper=None,
            attention_act_k_inputs_prec=None,
            attention_act_k_inputs_hyper=None,
            attention_act_probs_inputs_prec=None,
            attention_act_v_inputs_prec=None,
            attention_act_v_inputs_hyper=None,
            num_layers=num_layers,
            emb_dim=5,
            num_heads=8,
            qkv_dim=8,
            mlp_dim=7,
            quant_type=QuantType.fake_quant)

        transformer_kwargs = self.transformer_full_kwargs
        transformer_kwargs['hparams'] = hparams
        input_shape = (2, 4)
        target_shape = input_shape
        model, init_state = self.init_model(transformer_kwargs)
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            model, init_state, [input_shape, target_shape])
        floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')

        embedding_floor_ops = self._num_embedding_floors(
            (embedding_weight_prec is not None),
            (logits_inputs_prec is not None))

        self.assertEqual(floor_count, embedding_floor_ops)
Example #5
    def test_load_hlo_proto_from_model_and_count_ops(self):
        input_shapes = [(1, 2)]
        test_model = self.TestModelWith2DenseLayers()
        init_state = test_model.init(
            random.PRNGKey(0), *[jnp.ones(shape) for shape in input_shapes])

        hlo_proto = hlo_utils.load_hlo_proto_from_model(test_model, init_state,
                                                        input_shapes)
        count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=r'dot')
        self.assertEqual(count, 2)
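
Example #5 above is the smallest instance of the pattern every snippet on this page follows: initialize a Flax module, lower it to an HLO proto with load_hlo_proto_from_model, then query the proto. The sketch below distills that flow; the aqt.jax import path and the MyModule class are assumptions for illustration, not something the examples above pin down.

# Minimal sketch of the shared pattern. Assumptions: hlo_utils is the AQT
# repo's module (the exact import path may differ across repo versions) and
# MyModule is a hypothetical Flax linen module.
import jax.numpy as jnp
from jax import random

from aqt.jax import hlo_utils  # assumed import path

input_shapes = [(1, 2)]
module = MyModule()  # hypothetical module under test
# Initialize parameters from dummy inputs that match the input shapes.
init_state = module.init(
    random.PRNGKey(0), *[jnp.ones(shape) for shape in input_shapes])
# Lower the applied module to an HLO proto for static analysis.
hlo_proto = hlo_utils.load_hlo_proto_from_model(module, init_state,
                                                input_shapes)
# Count instructions whose opcode matches a regex, e.g. dot ops.
num_dots = hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=r'dot')
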
Example #6
    def _create_hlo_from_resnet_hparams(self, hparams, input_shape):
        """Create an HLO representation from ResNet model and input_shape."""

        # Create model
        rng = random.PRNGKey(0)
        model, init_state = create_model(rng,
                                         input_shape[0],
                                         input_shape[1],
                                         jnp.float32,
                                         hparams.model_hparams,
                                         train=False)

        # Create HLO
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            model, init_state, [input_shape])

        del model, init_state
        return hlo_proto
Example #7
    def test_count_floor_ops(self, base_config_filename, expected_floor_count):
        hparams = hparams_utils.load_hparams_from_config_dict(
            hparams_config.TrainingHParams, models.ResNet.HParams,
            base_config_filename.get_config())
        input_shape = (32, 16, 16, 3)
        model, init_state = create_model(self.rng_key,
                                         input_shape[0],
                                         input_shape[1],
                                         jnp.float32,
                                         hparams.model_hparams,
                                         train=False)
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            model, init_state, [input_shape])
        floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')
        self.assertEqual(floor_count, expected_floor_count)
        # Independently derive the expected floor count from the hparams.
        expected_floor_count_from_hparams = 0
        expected_floor_count_from_hparams += self._num_dense_floors(
            hparams.model_hparams)
        expected_floor_count_from_hparams += self._num_conv_floors(
            hparams.model_hparams)
        self.assertEqual(floor_count, expected_floor_count_from_hparams)
Example #8
def estimate_compute_and_memory_cost(image_size, model_dir, hparams):
    """Estimate compute and memory cost of model."""
    FLAGS.metadata_enabled = True
    input_shape = (1, image_size, image_size, 3)
    model, init_state = imagenet_train_utils.create_model(
        jax.random.PRNGKey(0),
        input_shape[0],
        input_shape[1],
        jnp.float32,
        hparams.model_hparams,
        train=False)
    hlo_proto = hlo_utils.load_hlo_proto_from_model(model, init_state,
                                                    [input_shape])
    del model, init_state
    cost_dict = compute_cost_utils.estimate_compute_cost(hlo_proto)
    memory_cost_dict = compute_cost_utils.estimate_memory_cost(hlo_proto)
    cost_dict.update(memory_cost_dict)
    FLAGS.metadata_enabled = False

    path = os.path.join(model_dir, COMPUTE_MEMORY_COST_FILENAME)
    with open(path, 'w') as file:
        json.dump(cost_dict, file, indent=2)
    logging.info('Estimated compute and memory costs and wrote to file')
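
Example #8 is the one non-test snippet: it merges the compute and memory estimates into a single dict and writes it next to the model. As a hypothetical follow-up, assuming the model_dir and COMPUTE_MEMORY_COST_FILENAME used above, the file can be read back as shown below; the 'compute_cost' and 'memory_cost' keys are the same ones the assertions in Examples #2 and #3 rely on.

# Hypothetical read-back of the cost file written in Example #8. Assumes
# model_dir and COMPUTE_MEMORY_COST_FILENAME from the example above.
import json
import os

path = os.path.join(model_dir, COMPUTE_MEMORY_COST_FILENAME)
with open(path) as f:
    cost_dict = json.load(f)
# Both estimates live in the merged dict.
print(cost_dict['compute_cost'], cost_dict['memory_cost'])
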
Example #9
    def test_annotation_only_changes_hlo_metadata_dense(
            self, weight_prec, acts_prec):
        FLAGS.metadata_enabled = False
        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0,
            half_shift=False)
        input_shape = (1, 16)
        module_no_annotation = aqt_flax_layers.DenseAqt(
            features=4,
            use_bias=False,
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            paxis_name='batch',
            train=False,
            hparams=aqt_flax_layers.DenseAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant,
                weight_quant_granularity=quant_config.QuantGranularity.per_channel,
                weight_half_shift=False),
            dtype=jnp.float32)

        init_state = module_no_annotation.init(self.rng_key,
                                               jnp.ones(
                                                   input_shape, jnp.float32),
                                               padding_mask=None)
        output_no_annotation = module_no_annotation.apply(
            init_state, jnp.ones(input_shape), padding_mask=None)

        hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
            module_no_annotation, init_state, [input_shape], padding_mask=None)
        del init_state

        FLAGS.metadata_enabled = True
        module_w_annotation = aqt_flax_layers.DenseAqt(
            features=4,
            use_bias=False,
            paxis_name='batch',
            train=False,
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            dtype=jnp.float32,
            hparams=aqt_flax_layers.DenseAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant,
                weight_quant_granularity=quant_config.QuantGranularity.per_channel,
                weight_half_shift=False),
        )

        init_state = module_w_annotation.init(self.rng_key,
                                              jnp.ones(input_shape,
                                                       jnp.float32),
                                              padding_mask=None)
        output_w_annotation = module_w_annotation.apply(init_state,
                                                        jnp.ones(input_shape),
                                                        padding_mask=None)

        hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
            module_w_annotation, init_state, [input_shape], padding_mask=None)
        del init_state

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
Example #10
    def test_number_of_floor_ops(
            self, num_layers, mlp_weight_prec, mlp_pos_inputs_prec,
            mlp_signed_inputs_prec, attention_kqv_inputs_prec,
            attention_out_inputs_prec, embedding_weight_prec,
            attention_weight_prec, logits_inputs_prec,
            attention_act_q_inputs_prec, attention_act_k_inputs_prec,
            attention_act_probs_inputs_prec, attention_act_v_inputs_prec):
        # Counts the number of floor ops as a proxy for quantization ops.
        act_fixed_clip_bound = 3.0
        hparams = training_hparams_generator_lib.create_base_transformer_hparams(
            mlp_weight_prec=mlp_weight_prec,
            embedding_weight_prec=embedding_weight_prec,
            attention_weight_prec=attention_weight_prec,
            mlp_pos_inputs_prec=mlp_pos_inputs_prec,
            mlp_pos_inputs_hyper=act_fixed_clip_bound,
            mlp_signed_inputs_prec=mlp_signed_inputs_prec,
            mlp_signed_inputs_hyper=act_fixed_clip_bound,
            attention_kqv_inputs_prec=attention_kqv_inputs_prec,
            attention_kqv_inputs_hyper=act_fixed_clip_bound,
            attention_out_inputs_prec=attention_out_inputs_prec,
            attention_out_inputs_hyper=act_fixed_clip_bound,
            logits_inputs_prec=logits_inputs_prec,
            logits_inputs_hyper=act_fixed_clip_bound,
            logits_via_embeddings=True,
            attention_act_q_inputs_prec=attention_act_q_inputs_prec,
            attention_act_q_inputs_hyper=act_fixed_clip_bound,
            attention_act_k_inputs_prec=attention_act_k_inputs_prec,
            attention_act_k_inputs_hyper=act_fixed_clip_bound,
            attention_act_probs_inputs_prec=attention_act_probs_inputs_prec,
            attention_act_v_inputs_prec=attention_act_v_inputs_prec,
            attention_act_v_inputs_hyper=act_fixed_clip_bound,
            num_layers=num_layers,
            emb_dim=5,
            num_heads=8,
            qkv_dim=8,
            mlp_dim=7,
            quant_type=QuantType.fake_quant)

        transformer_kwargs = self.transformer_full_kwargs
        transformer_kwargs['hparams'] = hparams
        input_shape = (2, 4)
        target_shape = input_shape
        model, init_state = self.init_model(transformer_kwargs)
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            model, init_state, [input_shape, target_shape])
        floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')

        mlp_floors_per_layer = self._num_mlp_floors(
            (mlp_weight_prec is not None), (mlp_pos_inputs_prec is not None),
            (mlp_signed_inputs_prec is not None))

        attention_floors_per_layer = self._num_attention_floors(
            (attention_weight_prec is not None),
            (attention_kqv_inputs_prec is not None),
            (attention_out_inputs_prec is not None),
            (attention_act_q_inputs_prec is not None),
            (attention_act_k_inputs_prec is not None),
            (attention_act_probs_inputs_prec is not None),
            (attention_act_v_inputs_prec is not None))

        embedding_floors = self._num_embedding_floors(
            (embedding_weight_prec is not None),
            (logits_inputs_prec is not None))

        expected_floor_count = num_layers * (
            mlp_floors_per_layer +
            attention_floors_per_layer) + embedding_floors
        self.assertEqual(floor_count, expected_floor_count)