def test_auto_acts_parameter(self):
  # If use_auto_acts is False, then the bounds should be a single scalar that
  # specifies the fixed bound; 'None' by default.
  config = config_schema_utils.get_base_config(use_auto_acts=False)
  self.assertIsNone(config.quant_act.bounds)

  # If use_auto_acts is True, it should have the same structure as the
  # GetBounds.Hyper dataclass.
  config = config_schema_utils.get_base_config(use_auto_acts=True)
  self.assertIn('initial_bound', config.quant_act.bounds)

  # Because the config dict is locked, it shouldn't be possible to change it
  # back to fixed bounds if it was created with use_auto_acts=True.
  with self.assertRaises(TypeError):
    config.quant_act.bounds = 1.0
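# For context, the TypeError asserted above comes from ml_collections'
# ConfigDict semantics: overwriting a field that holds a nested ConfigDict
# with a scalar is a type change, which type-safe ConfigDicts reject. Below is
# a minimal sketch assuming only the ml_collections library; the field names
# are illustrative, not the real config's.
import ml_collections

cfg = ml_collections.ConfigDict({'bounds': {'initial_bound': 6.0}})
cfg.lock()  # Locking additionally prevents adding brand-new fields.
try:
  cfg.bounds = 1.0  # Replacing a nested dict with a float is rejected.
except TypeError:
  print('type change rejected')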
def test_precision_propagates(self, use_auto_acts):
  config = config_schema_utils.get_base_config(use_auto_acts)

  # Set the global precision to 4 bits.
  config.prec = 4

  # Test that this sets the weight and activation precision to 4 as well.
  self.assertEqual(config.weight_prec, 4)
  self.assertEqual(config.quant_act.prec, 4)
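# The propagation tested above presumably relies on ml_collections field
# references: get_base_config sets weight_prec and quant_act.prec to
# *references* to the top-level prec field rather than copies of its value.
# A minimal sketch of that mechanism using ml_collections directly (an
# illustration of the idea, not the repo's actual wiring):
import ml_collections

config = ml_collections.ConfigDict()
config.prec = 8
config.weight_prec = config.get_ref('prec')  # Store a reference, not a copy.
config.prec = 4
assert config.weight_prec == 4  # The update is visible through the reference.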
def get_base_config(use_auto_acts):
  """Base ConfigDict for resnet; does not yet have fields for individual layers."""
  base_config = config_schema_utils.get_base_config(
      use_auto_acts, fp_quant=False)
  base_config.update({
      "base_learning_rate": float_ph(),
      "momentum": float_ph(),
      "model_hparams": {},
      "act_function": str_ph(),  # New field that controls the activation function.
      "shortcut_ch_shrink_method": str_ph(),
      "shortcut_ch_expand_method": str_ph(),
      "shortcut_spatial_method": str_ph(),
      "lr_scheduler": {
          "warmup_epochs": int_ph(),
          "cooldown_epochs": int_ph(),
          "scheduler": str_ph(),  # "cosine", "linear", or "step" LR decay.
          "num_epochs": int_ph(),
      },
      "optimizer": str_ph(),
      "adam": {
          "beta1": float_ph(),
          "beta2": float_ph()
      },
      "early_stop_steps": int_ph(),
  })
  if use_auto_acts:
    # config_schema_utils is shared with the wmt code. To avoid affecting
    # other libraries, add the new bound coefficients here.
    base_config.quant_act.bounds.update({
        "fixed_bound": float_ph(),
        "cams_coeff": float_ph(),
        "cams_stddev_coeff": float_ph(),
        "mean_of_max_coeff": float_ph(),
        "use_old_code": bool_ph(),
    })
  base_config.dense_layer = config_schema_utils.get_dense_config(base_config)
  # TODO(b/179063860): The input distribution is an intrinsic model
  # property and shouldn't be part of the model configuration. Update
  # the hparam dataclasses to eliminate the input_distribution field and
  # then delete this.
  base_config.dense_layer.quant_act.input_distribution = "positive"
  base_config.conv = config_schema_utils.get_conv_config(base_config)
  base_config.residual = get_residual_config(base_config)

  # Make the activation function in a residual block consistent with the
  # global option.
  config_schema_utils.set_default_reference(
      base_config.residual,
      base_config,
      field=[
          "act_function", "shortcut_ch_shrink_method",
          "shortcut_ch_expand_method", "shortcut_spatial_method"
      ])
  return base_config
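# The float_ph/int_ph/str_ph/bool_ph helpers used above are plausibly thin
# wrappers around typed ml_collections placeholders. A minimal sketch of what
# such a helper might look like (an assumption, not the repo's actual code):
import ml_collections

def float_ph():
  # A typed placeholder: the field exists in the schema but has no value yet,
  # and later assignments must be floats.
  return ml_collections.FieldReference(None, field_type=float)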
def test_precision_propagates(self, use_auto_acts):
  config = config_schema_utils.get_base_config(use_auto_acts, fp_quant=False)

  # Set the global precision to 4 bits.
  config.prec = 4
  # Set the global half_shift flag to False.
  config.half_shift = False

  # Test that this sets the weight and activation precision to 4 as well.
  self.assertEqual(config.weight_prec, 4)
  self.assertEqual(config.quant_act.prec, 4)
  # Test that this sets the weight and activation half_shift flags to False.
  self.assertEqual(config.weight_half_shift, False)
  self.assertEqual(config.quant_act.half_shift, False)
def get_base_config(use_auto_acts):
  """Base ConfigDict for resnet; does not yet have fields for individual layers."""
  base_config = config_schema_utils.get_base_config(use_auto_acts)
  base_config.update({
      "base_learning_rate": float_ph(),
      "momentum": float_ph(),
      "model_hparams": {},
  })
  base_config.dense_layer = config_schema_utils.get_dense_config(base_config)
  # TODO(b/179063860): The input distribution is an intrinsic model
  # property and shouldn't be part of the model configuration. Update
  # the hparam dataclasses to eliminate the input_distribution field and
  # then delete this.
  base_config.dense_layer.quant_act.input_distribution = "positive"
  base_config.conv = config_schema_utils.get_conv_config(base_config)
  base_config.residual = get_residual_config(base_config)
  return base_config
def get_wmt_base_config(use_auto_acts, fp_quant):
  """Returns a base ConfigDict which does not yet have fields for individual layers."""
  base_config = config_schema_utils.get_base_config(
      use_auto_acts, fp_quant=fp_quant)
  base_config.update({
      "learning_rate_schedule": {
          "factors": str_ph(),
          "base_learning_rate": float_ph(),
          "warmup_steps": int_ph(),
          "decay_factor": float_ph(),
          "steps_per_decay": int_ph(),
          "steps_per_cycle": int_ph(),
      },
      "per_host_batch_size": int_ph(),
      "num_train_steps": int_ph(),
      "beta1": float_ph(),
      "beta2": float_ph(),
      "eps": float_ph(),
      "random_seed": int_ph(),
      "hardware_rng": bool_ph(),
      "weight_outlier_regularization": float_ph(),
      "weight_outlier_regularization_regex": str_ph(),
      "prefer_int8_to_int32_dot": bool_ph(),
      "model_hparams": {
          "emb_dim": int_ph(),
          "num_heads": int_ph(),
          "qkv_dim": int_ph(),
          "mlp_dim": int_ph(),
          "share_embeddings": bool_ph(),
          "logits_via_embedding": bool_ph(),
      },
  })
  base_config.dense = config_schema_utils.get_dense_config(base_config)
  base_config.mlp_block = get_mlp_block_config(base_config)
  base_config.embedding = get_embedding_config(base_config)
  base_config.attention = get_attention_config(base_config)
  return base_config
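# Hypothetical usage sketch: the returned ConfigDict is a schema of typed
# placeholders, which an experiment definition would fill in before training.
# All values below are illustrative, not recommended settings.
config = get_wmt_base_config(use_auto_acts=False, fp_quant=False)
config.num_train_steps = 100_000
config.per_host_batch_size = 256
config.learning_rate_schedule.base_learning_rate = 0.0625
config.model_hparams.emb_dim = 1024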
def test_fp_precision_propagates(self, use_auto_acts):
  config = config_schema_utils.get_base_config(use_auto_acts, fp_quant=True)
  config.prec.is_scaled = False

  # Set the global floating-point precision spec.
  config.prec.fp_spec.update({'exp_min': -3, 'exp_max': 5, 'sig_bits': 2})
  expected_prec_dict = {
      'is_scaled': False,
      'fp_spec': {
          'exp_min': -3,
          'exp_max': 5,
          'sig_bits': 2
      }
  }

  # Test that this sets the weight and activation precision specs as well.
  self.assertEqual(config.weight_prec.to_dict(), expected_prec_dict)
  self.assertEqual(config.quant_act.prec.to_dict(), expected_prec_dict)
def test_schema_matches_expected(self, use_auto_acts, fp_quant):
  # This tests that the schema of the ConfigDict returned by 'base_config',
  # once all references are resolved, matches an expected schema. 'Schema'
  # here means the names and structure of fields at each level of the
  # configuration hierarchy. A value of 'None' in the expected schemas defined
  # below indicates a real configuration would have a concrete scalar value
  # there.
  if fp_quant:
    prec = {
        'fp_spec': {
            'exp_min': None,
            'exp_max': None,
            'sig_bits': None,
        },
        'is_scaled': None,
    }
  else:
    prec = None

  if use_auto_acts:
    quant_act_schema = {
        'bounds': {
            'initial_bound': None,
            'stddev_coeff': None,
            'absdev_coeff': None,
            'mix_coeff': None,
            'reset_stats': None,
            'ema_coeff': None,
            'use_cams': None,
            'exclude_zeros': None,
            'use_mean_of_max': None,
            'granularity': None
        },
        'input_distribution': None,
        'prec': prec,
        'half_shift': None,
    }
  else:
    quant_act_schema = {
        'bounds': None,
        'input_distribution': None,
        'prec': prec,
        'half_shift': None,
    }

  expected_top_level_schema = {
      'metadata': {
          'description': None,
          'hyper_str': None
      },
      'weight_decay': None,
      'activation_bound_update_freq': None,
      'activation_bound_start_step': None,
      'prec': prec,
      'half_shift': None,
      'weight_prec': prec,
      'weight_half_shift': None,
      'quant_type': None,
      'quant_act': quant_act_schema,
      'weight_quant_granularity': None,
  }

  config = config_schema_utils.get_base_config(
      use_auto_acts=use_auto_acts, fp_quant=fp_quant)
  # This round-trip conversion from JSON forces all references to resolve to
  # concrete values.
  config_reified = json.loads(config.to_json())

  # This test is not interested in checking the specific values of fields in
  # the configuration, but only that the schemas of the hierarchies are the
  # same. Thus we set the value of all leaf nodes in the config to 'None'
  # before checking that the actual and expected configuration structures are
  # the same.
  def set_leaves_to_none(config):
    # We are at an intermediate node in the tree-structured input, which could
    # either be in the form of a dictionary or a list of other nodes in the
    # tree.
    if isinstance(config, dict):
      return {
          key: set_leaves_to_none(value) for key, value in config.items()
      }
    elif isinstance(config, list):
      return [set_leaves_to_none(value) for value in config]
    # We are at a leaf node in the tree-structured input.
    else:
      return None

  self.assertSameStructure(
      set_leaves_to_none(config_reified), expected_top_level_schema)
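# Worked example of the leaf-stripping helper above (shown standalone for
# illustration; in the test it is a local function). It preserves field names
# and nesting while discarding every value, so the structural comparison
# never depends on concrete settings:
assert set_leaves_to_none(
    {'prec': 4, 'quant_act': {'bounds': [1.0, 2.0]}}
) == {'prec': None, 'quant_act': {'bounds': [None, None]}}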