def test_load_hlo_proto_from_jax_fn_and_count_ops(self, fn, fn_args,
                                                  ops_regex, exp_count):
  hlo_proto = hlo_utils.load_hlo_proto_from_jax_fn(fn, *fn_args)
  count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=ops_regex)
  self.assertEqual(count, exp_count)
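# For reference, a minimal sketch of what count_ops_in_hlo_proto is expected
# to do (an assumption for illustration, not hlo_utils' actual code): walk
# every instruction of every computation in the HloModuleProto and
# regex-match its opcode string. Assumes `import re` at module scope; the
# `computations`/`instructions`/`opcode` fields are part of XLA's hlo.proto.
def _count_ops_in_hlo_proto_sketch(self, hlo_proto, ops_regex):
  return sum(
      1 for computation in hlo_proto.computations
      for instruction in computation.instructions
      if re.match(ops_regex, instruction.opcode))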
def test_number_of_floor_ops_embedding(self, num_layers,
                                       embedding_weight_prec,
                                       logits_inputs_prec,
                                       logits_inputs_hyper_is_float,
                                       logits_via_embeddings):
  # Counts number of floor ops as a proxy for quantization ops.
  if logits_inputs_hyper_is_float:
    logits_inputs_hyper = 6.0
  else:
    logits_inputs_hyper = get_bounds.GetBounds.Hyper(
        initial_bound=6.0,
        stddev_coeff=3.0,
        absdev_coeff=2.0,
        mix_coeff=0.5,
        granularity=quant_config.QuantGranularity.per_tensor)
  hparams = training_hparams_generator_lib.create_base_transformer_hparams(
      mlp_weight_prec=None,
      embedding_weight_prec=embedding_weight_prec,
      attention_weight_prec=None,
      mlp_pos_inputs_prec=None,
      mlp_pos_inputs_hyper=None,
      mlp_signed_inputs_prec=None,
      mlp_signed_inputs_hyper=None,
      attention_kqv_inputs_prec=None,
      attention_kqv_inputs_hyper=None,
      attention_out_inputs_prec=None,
      attention_out_inputs_hyper=None,
      logits_inputs_prec=logits_inputs_prec,
      logits_inputs_hyper=logits_inputs_hyper,
      logits_via_embeddings=logits_via_embeddings,
      attention_act_q_inputs_prec=None,
      attention_act_q_inputs_hyper=None,
      attention_act_k_inputs_prec=None,
      attention_act_k_inputs_hyper=None,
      attention_act_probs_inputs_prec=None,
      attention_act_v_inputs_prec=None,
      attention_act_v_inputs_hyper=None,
      num_layers=num_layers,
      emb_dim=5,
      num_heads=8,
      qkv_dim=8,
      mlp_dim=7,
      quant_type=QuantType.fake_quant)
  transformer_kwargs = self.transformer_full_kwargs
  transformer_kwargs['hparams'] = hparams
  input_shape = (2, 4)
  target_shape = input_shape
  model, init_state = self.init_model(transformer_kwargs)
  hlo_proto = hlo_utils.load_hlo_proto_from_model(
      model, init_state, [input_shape, target_shape])
  floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')
  embedding_floor_ops = self._num_embedding_floors(
      (embedding_weight_prec is not None), (logits_inputs_prec is not None))
  self.assertEqual(floor_count, embedding_floor_ops)
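# Note on the two hyper settings in the test above: a plain float fixes the
# activation clip bound, while GetBounds.Hyper derives the bound from running
# activation statistics (stddev/absdev mixing). With fake_quant either choice
# lowers to the same floor ops, which is why the expected count depends only
# on whether each precision is set, not on the bound type.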
def test_load_hlo_proto_from_model_and_count_ops(self):
  input_shapes = [(1, 2)]
  test_model = self.TestModelWith2DenseLayers()
  init_state = test_model.init(
      random.PRNGKey(0), *[jnp.ones(shape) for shape in input_shapes])
  hlo_proto = hlo_utils.load_hlo_proto_from_model(test_model, init_state,
                                                  input_shapes)
  count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, ops_regex=r'dot')
  # Two Dense layers, each lowering to a single dot op.
  self.assertEqual(count, 2)
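# A hedged sketch of what the TestModelWith2DenseLayers fixture above
# plausibly looks like (an assumption for illustration; the real fixture is
# defined elsewhere in this class, and flax.linen is assumed imported as nn).
# Each nn.Dense lowers to one XLA `dot`, which is what the regex count
# asserts.
class TestModelWith2DenseLayersSketch(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=3)(x)
    return nn.Dense(features=3)(x)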
def test_count_floor_ops(self, base_config_filename, expected_floor_count):
  hparams = hparams_utils.load_hparams_from_config_dict(
      hparams_config.TrainingHParams, models.ResNet.HParams,
      base_config_filename.get_config())
  input_shape = (32, 16, 16, 3)
  model, init_state = create_model(
      self.rng_key,
      input_shape[0],
      input_shape[1],
      jnp.float32,
      hparams.model_hparams,
      train=False)
  hlo_proto = hlo_utils.load_hlo_proto_from_model(model, init_state,
                                                  [input_shape])
  floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')
  self.assertEqual(floor_count, expected_floor_count)

  # Cross-check against the floor count derived from the model hparams.
  expected_floor_count_from_hparams = 0
  expected_floor_count_from_hparams += self._num_dense_floors(
      hparams.model_hparams)
  expected_floor_count_from_hparams += self._num_conv_floors(
      hparams.model_hparams)
  self.assertEqual(floor_count, expected_floor_count_from_hparams)
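# A hedged sketch of load_hlo_proto_from_model (an assumption for
# illustration; the real helper lives in hlo_utils): bind the initialized
# variables to the model's forward pass and lower that callable to an HLO
# proto by delegating to load_hlo_proto_from_jax_fn.
def _load_hlo_proto_from_model_sketch(self, model, init_state, input_shapes):
  def forward(*inputs):
    return model.apply(init_state, *inputs)

  return hlo_utils.load_hlo_proto_from_jax_fn(
      forward, *[jnp.ones(shape) for shape in input_shapes])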
def test_number_of_floor_ops(
    self, num_layers, mlp_weight_prec, mlp_pos_inputs_prec,
    mlp_signed_inputs_prec, attention_kqv_inputs_prec,
    attention_out_inputs_prec, embedding_weight_prec, attention_weight_prec,
    logits_inputs_prec, attention_act_q_inputs_prec,
    attention_act_k_inputs_prec, attention_act_probs_inputs_prec,
    attention_act_v_inputs_prec):
  # Counts number of floor ops as a proxy for quantization ops.
  act_fixed_clip_bound = 3.0
  hparams = training_hparams_generator_lib.create_base_transformer_hparams(
      mlp_weight_prec=mlp_weight_prec,
      embedding_weight_prec=embedding_weight_prec,
      attention_weight_prec=attention_weight_prec,
      mlp_pos_inputs_prec=mlp_pos_inputs_prec,
      mlp_pos_inputs_hyper=act_fixed_clip_bound,
      mlp_signed_inputs_prec=mlp_signed_inputs_prec,
      mlp_signed_inputs_hyper=act_fixed_clip_bound,
      attention_kqv_inputs_prec=attention_kqv_inputs_prec,
      attention_kqv_inputs_hyper=act_fixed_clip_bound,
      attention_out_inputs_prec=attention_out_inputs_prec,
      attention_out_inputs_hyper=act_fixed_clip_bound,
      logits_inputs_prec=logits_inputs_prec,
      logits_inputs_hyper=act_fixed_clip_bound,
      logits_via_embeddings=True,
      attention_act_q_inputs_prec=attention_act_q_inputs_prec,
      attention_act_q_inputs_hyper=act_fixed_clip_bound,
      attention_act_k_inputs_prec=attention_act_k_inputs_prec,
      attention_act_k_inputs_hyper=act_fixed_clip_bound,
      attention_act_probs_inputs_prec=attention_act_probs_inputs_prec,
      attention_act_v_inputs_prec=attention_act_v_inputs_prec,
      attention_act_v_inputs_hyper=act_fixed_clip_bound,
      num_layers=num_layers,
      emb_dim=5,
      num_heads=8,
      qkv_dim=8,
      mlp_dim=7,
      quant_type=QuantType.fake_quant)
  transformer_kwargs = self.transformer_full_kwargs
  transformer_kwargs['hparams'] = hparams
  input_shape = (2, 4)
  target_shape = input_shape
  model, init_state = self.init_model(transformer_kwargs)
  hlo_proto = hlo_utils.load_hlo_proto_from_model(
      model, init_state, [input_shape, target_shape])
  floor_count = hlo_utils.count_ops_in_hlo_proto(hlo_proto, r'floor')

  mlp_floors_per_layer = self._num_mlp_floors(
      (mlp_weight_prec is not None), (mlp_pos_inputs_prec is not None),
      (mlp_signed_inputs_prec is not None))
  attention_floors_per_layer = self._num_attention_floors(
      (attention_weight_prec is not None),
      (attention_kqv_inputs_prec is not None),
      (attention_out_inputs_prec is not None),
      (attention_act_q_inputs_prec is not None),
      (attention_act_k_inputs_prec is not None),
      (attention_act_probs_inputs_prec is not None),
      (attention_act_v_inputs_prec is not None))
  embedding_floors = self._num_embedding_floors(
      (embedding_weight_prec is not None), (logits_inputs_prec is not None))
  expected_floor_count = num_layers * (
      mlp_floors_per_layer + attention_floors_per_layer) + embedding_floors
  self.assertEqual(floor_count, expected_floor_count)
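# Worked example of the expected-count formula above, with illustrative
# (hypothetical) per-component counts rather than values from the real
# helpers: with num_layers=2, 3 MLP floors and 7 attention floors per layer,
# and 2 embedding floors, the test would expect 2 * (3 + 7) + 2 == 22 floor
# ops in the lowered HLO.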