Example no. 1
    def test_auto_acts_parameter(self):
        # If use_auto_acts is False, then the bounds should be a single scalar that
        # specifies the fixed bound; 'None' by default.
        config = config_schema_utils.get_base_config(use_auto_acts=False)
        self.assertIsNone(config.quant_act.bounds)
        # If use_auto_acts is True, it should have the same structure as the
        # GetBounds.Hyper dataclass.
        config = config_schema_utils.get_base_config(use_auto_acts=True)
        self.assertIn('initial_bound', config.quant_act.bounds)

        # Because the config dict is locked, it shouldn't be possible to change it
        # back to fixed bounds if it was created with use_auto_acts=True.
        with self.assertRaises(TypeError):
            config.quant_act.bounds = 1.0
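
# A minimal sketch (not from the library) of the ml_collections.ConfigDict behaviour
# the test above relies on: the config is type-checked and locked, so a nested
# bounds dict cannot later be replaced by a scalar fixed bound.
import ml_collections

sketch = ml_collections.ConfigDict(
    {'quant_act': {'bounds': {'initial_bound': 6.0}}})
sketch.lock()  # no new fields may be added after this point
try:
    sketch.quant_act.bounds = 1.0  # ConfigDict -> float is a type mismatch
except TypeError:
    pass  # switching back to a fixed scalar bound is rejected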
Example no. 2
    def test_precision_propagates(self, use_auto_acts):
        config = config_schema_utils.get_base_config(use_auto_acts)

        # Set the global precision to 4 bits.
        config.prec = 4
        # Test that this sets the weight and activation precision to 4 as well.
        self.assertEqual(config.weight_prec, 4)
        self.assertEqual(config.quant_act.prec, 4)
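
# Hedged sketch (not the config_schema_utils implementation) of how a global field
# can propagate to dependent fields via ml_collections field references, which is
# the mechanism the test above exercises.
import ml_collections

cfg = ml_collections.ConfigDict({'prec': 8})
cfg.weight_prec = cfg.get_ref('prec')  # weight_prec tracks the global 'prec'
cfg.prec = 4
assert cfg.weight_prec == 4  # the reference resolves to the updated value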
def get_base_config(use_auto_acts):
    """Base ConfigDict for resnet, does not yet have fields for individual layers."""
    base_config = config_schema_utils.get_base_config(use_auto_acts,
                                                      fp_quant=False)
    base_config.update({
        "base_learning_rate": float_ph(),
        "momentum": float_ph(),
        "model_hparams": {},
        "act_function": str_ph(),  # add a new field that controls act function
        "shortcut_ch_shrink_method": str_ph(),
        "shortcut_ch_expand_method": str_ph(),
        "shortcut_spatial_method": str_ph(),
        "lr_scheduler": {
            "warmup_epochs": int_ph(),
            "cooldown_epochs": int_ph(),
            "scheduler": str_ph(),  # "cosine", "linear", or "step" lr decay
            "num_epochs": int_ph(),
        },
        "optimizer": str_ph(),
        "adam": {
            "beta1": float_ph(),
            "beta2": float_ph()
        },
        "early_stop_steps": int_ph(),
    })
    if use_auto_acts:
        # config_schema_utils is shared with wmt. To avoid affecting other code that
        # uses it, add the new bound coefficients here.
        base_config.quant_act.bounds.update({
            "fixed_bound": float_ph(),
            "cams_coeff": float_ph(),
            "cams_stddev_coeff": float_ph(),
            "mean_of_max_coeff": float_ph(),
            "use_old_code": bool_ph(),
        })

    base_config.dense_layer = config_schema_utils.get_dense_config(base_config)
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    base_config.dense_layer.quant_act.input_distribution = "positive"
    base_config.conv = config_schema_utils.get_conv_config(base_config)
    base_config.residual = get_residual_config(base_config)
    # Make the activation function and the shortcut methods in a residual block
    # default to the corresponding global options.
    config_schema_utils.set_default_reference(base_config.residual,
                                              base_config,
                                              field=[
                                                  "act_function",
                                                  "shortcut_ch_shrink_method",
                                                  "shortcut_ch_expand_method",
                                                  "shortcut_spatial_method"
                                              ])
    return base_config
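
# Hedged usage sketch (values are illustrative only): after constructing the ResNet
# base config, the placeholder fields declared above are filled in by the caller.
config = get_base_config(use_auto_acts=False)
config.base_learning_rate = 0.1
config.momentum = 0.9
config.optimizer = "sgd"
config.act_function = "relu"
config.lr_scheduler.scheduler = "cosine"
config.lr_scheduler.warmup_epochs = 5
config.lr_scheduler.cooldown_epochs = 0
config.lr_scheduler.num_epochs = 90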
    def test_precision_propagates(self, use_auto_acts):
        config = config_schema_utils.get_base_config(use_auto_acts,
                                                     fp_quant=False)

        # Set the global precision to 4 bits.
        config.prec = 4
        # Set the global half_shift flag to False
        config.half_shift = False
        # Test that this sets the weight and activation precision to 4 as well.
        self.assertEqual(config.weight_prec, 4)
        self.assertEqual(config.quant_act.prec, 4)
        # Test that this sets weight_half_shift and quant_act.half_shift to False as well.
        self.assertEqual(config.weight_half_shift, False)
        self.assertEqual(config.quant_act.half_shift, False)
def get_base_config(use_auto_acts):
    """Base ConfigDict for resnet, does not yet have fields for individual layers."""
    base_config = config_schema_utils.get_base_config(use_auto_acts)
    base_config.update({
        "base_learning_rate": float_ph(),
        "momentum": float_ph(),
        "model_hparams": {},
    })

    base_config.dense_layer = config_schema_utils.get_dense_config(base_config)
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    base_config.dense_layer.quant_act.input_distribution = "positive"
    base_config.conv = config_schema_utils.get_conv_config(base_config)
    base_config.residual = get_residual_config(base_config)
    return base_config
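
# Hedged usage sketch: with use_auto_acts=True the activation bounds form a nested
# config mirroring GetBounds.Hyper (field names as in the schema test further below);
# the values here are illustrative only.
config = get_base_config(use_auto_acts=True)
config.quant_act.bounds.initial_bound = 6.0
config.quant_act.bounds.stddev_coeff = 2.0
config.quant_act.bounds.ema_coeff = 0.9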
Example no. 6
def get_wmt_base_config(use_auto_acts, fp_quant):
    """Return a base ConfigDict which does not yet have fields for individual layers."""
    base_config = config_schema_utils.get_base_config(use_auto_acts,
                                                      fp_quant=fp_quant)
    base_config.update({
        "learning_rate_schedule": {
            "factors": str_ph(),
            "base_learning_rate": float_ph(),
            "warmup_steps": int_ph(),
            "decay_factor": float_ph(),
            "steps_per_decay": int_ph(),
            "steps_per_cycle": int_ph(),
        },
        "per_host_batch_size": int_ph(),
        "num_train_steps": int_ph(),
        "beta1": float_ph(),
        "beta2": float_ph(),
        "eps": float_ph(),
        "random_seed": int_ph(),
        "hardware_rng": bool_ph(),
        "weight_outlier_regularization": float_ph(),
        "weight_outlier_regularization_regex": str_ph(),
        "prefer_int8_to_int32_dot": bool_ph(),
        "model_hparams": {
            "emb_dim": int_ph(),
            "num_heads": int_ph(),
            "qkv_dim": int_ph(),
            "mlp_dim": int_ph(),
            "share_embeddings": bool_ph(),
            "logits_via_embedding": bool_ph(),
        },
    })

    base_config.dense = config_schema_utils.get_dense_config(base_config)
    base_config.mlp_block = get_mlp_block_config(base_config)
    base_config.embedding = get_embedding_config(base_config)
    base_config.attention = get_attention_config(base_config)

    return base_config
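
# Hedged usage sketch (illustrative values only): fill a few of the WMT placeholder
# fields declared above.
config = get_wmt_base_config(use_auto_acts=False, fp_quant=False)
config.per_host_batch_size = 256
config.num_train_steps = 100000
config.learning_rate_schedule.base_learning_rate = 0.0625
config.learning_rate_schedule.warmup_steps = 1000
config.model_hparams.emb_dim = 1024
config.model_hparams.num_heads = 16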
    def test_fp_precision_propagates(self, use_auto_acts):
        config = config_schema_utils.get_base_config(use_auto_acts,
                                                     fp_quant=True)

        config.prec.is_scaled = False
        # Set the global floating-point precision format.
        config.prec.fp_spec.update({
            'exp_min': -3,
            'exp_max': 5,
            'sig_bits': 2
        })

        expected_prec_dict = {
            'is_scaled': False,
            'fp_spec': {
                'exp_min': -3,
                'exp_max': 5,
                'sig_bits': 2
            }
        }
        # Test that this propagates to the weight and activation precision as well.
        self.assertEqual(config.weight_prec.to_dict(), expected_prec_dict)
        self.assertEqual(config.quant_act.prec.to_dict(), expected_prec_dict)
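
# Hedged sketch: with fp_quant=False the precision fields are single integer bit
# widths, while with fp_quant=True they are nested configs with 'is_scaled' and
# 'fp_spec', as exercised above and in the schema test below.
int_config = config_schema_utils.get_base_config(use_auto_acts=False, fp_quant=False)
int_config.prec = 8  # integer bit width

fp_config = config_schema_utils.get_base_config(use_auto_acts=False, fp_quant=True)
fp_config.prec.fp_spec.sig_bits = 3  # significand bits of the floating-point format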
    def test_schema_matches_expected(self, use_auto_acts, fp_quant):
        # This tests that the schema of the ConfigDict returned by 'get_base_config',
        # once all references are resolved, matches an expected schema. 'Schema' here
        # means the names and structure of fields at each level of the configuration
        # hierarchy. A value of 'None' in the expected schemas defined below indicates
        # that a real configuration would have a concrete scalar value there.

        if fp_quant:
            prec = {
                'fp_spec': {
                    'exp_min': None,
                    'exp_max': None,
                    'sig_bits': None,
                },
                'is_scaled': None,
            }
        else:
            prec = None

        if use_auto_acts:
            quant_act_schema = {
                'bounds': {
                    'initial_bound': None,
                    'stddev_coeff': None,
                    'absdev_coeff': None,
                    'mix_coeff': None,
                    'reset_stats': None,
                    'ema_coeff': None,
                    'use_cams': None,
                    'exclude_zeros': None,
                    'use_mean_of_max': None,
                    'granularity': None
                },
                'input_distribution': None,
                'prec': prec,
                'half_shift': None,
            }
        else:
            quant_act_schema = {
                'bounds': None,
                'input_distribution': None,
                'prec': prec,
                'half_shift': None,
            }

        expected_top_level_schema = {
            'metadata': {
                'description': None,
                'hyper_str': None
            },
            'weight_decay': None,
            'activation_bound_update_freq': None,
            'activation_bound_start_step': None,
            'prec': prec,
            'half_shift': None,
            'weight_prec': prec,
            'weight_half_shift': None,
            'quant_type': None,
            'quant_act': quant_act_schema,
            'weight_quant_granularity': None,
        }

        config = config_schema_utils.get_base_config(
            use_auto_acts=use_auto_acts, fp_quant=fp_quant)
        # This round-trip conversion to JSON and back forces all references to
        # resolve to concrete values.
        config_reified = json.loads(config.to_json())

        # This test is not interested in checking the specific values of fields in
        # the configuration, only that the schemas of the hierarchies are the same.
        # Thus we set every leaf node in the config to 'None' before checking that
        # the actual and expected configuration structures are the same.
        def set_leaves_to_none(config):
            # We are at an intermediate node in the tree-structured input, which could
            # either be in the form of a dictionary or a list of other nodes in the
            # tree.
            if isinstance(config, dict):
                return {
                    key: set_leaves_to_none(value)
                    for key, value in config.items()
                }
            elif isinstance(config, list):
                return [set_leaves_to_none(value) for value in config]

            # We are at a leaf node in the tree-structured input.
            else:
                return None

        self.assertSameStructure(set_leaves_to_none(config_reified),
                                 expected_top_level_schema)
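
# Hedged sketch of the reference-resolution trick used above: serialising a
# ConfigDict to JSON and parsing it back replaces every field reference with its
# current concrete value.
import json
import ml_collections

cfg = ml_collections.ConfigDict({'prec': 4})
cfg.weight_prec = cfg.get_ref('prec')  # weight_prec follows the global 'prec'
resolved = json.loads(cfg.to_json())
assert resolved == {'prec': 4, 'weight_prec': 4}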