def test_fp_precision_param(self):
        """Checks that a floating-point precision spec reaches every config level.

        The top-level weight and activation precision, the generic 'dense_1'
        layer type and one individual decoder layer are each converted with
        `to_dict()` and compared against the same `prec_dict`.

        NOTE(review): several statements in this test appear truncated in this
        view (unclosed call and dict literals, no visible assignment of the
        global precision) -- confirm against the complete file.
        """
        config = config_schema.get_config(n_layers=3,
        prec_dict = {
            'is_scaled': False,
            'fp_spec': {
                'exp_min': -3,
                'exp_max': 5,
                'sig_bits': 2
        # Set the global precision.
        # NOTE(review): this comment originally said "4 bits", which does not
        # match the floating-point spec above; the assignment itself is not
        # visible here -- confirm.

        # Test that this sets the weight and activation prec as well.
        self.assertEqual(config.weight_prec.to_dict(), prec_dict)
        self.assertEqual(config.quant_act.prec.to_dict(), prec_dict)

        # Test that propagates all the way down to the weight precision of layer
        # types and individual layers. As an example of an individual layer, we take
        # the dense1 matmul of the second block of the decoder.
        dense1_block2 = config.model_hparams.decoder.encoder_decoder_1d_blocks[
        # Meanwhile, 'dense1' represents the generic configuration of all dense1
        # layers throughout the model.
        dense1 = config.mlp_block.dense_1
        self.assertEqual(dense1.weight_prec.to_dict(), prec_dict)
        self.assertEqual(dense1_block2.weight_prec.to_dict(), prec_dict)
 def test_softmax_config(self, quantized_reductions):
     """Checks that a global softmax config is applied to a decoder block.

     A quantized softmax config is built with `get_softmax_config`, applied to
     a base config via `set_global_softmax_config`, and the resulting
     `prec` (and, when requested, `reduction_prec`) is compared against the
     expected floating-point spec.

     Args:
       quantized_reductions: Forwarded to `get_softmax_config`; when truthy,
         the reduction precision is also asserted. Presumably supplied by a
         parameterized-test decorator outside this view -- confirm.

     NOTE(review): several literals and calls in this test appear truncated
     in this view -- confirm against the complete file.
     """
     base_config = config_schema.get_config(n_layers=3,
     # Build a quantized softmax config to apply on top of the base config.
     softmax_config = config_schema.get_softmax_config(
         quantized=True, quantized_reductions=quantized_reductions)
         'exp_min': -3,
         'exp_max': 5,
         'sig_bits': 2
     if quantized_reductions:
             'exp_min': 1,
             'exp_max': 4,
             'sig_bits': 6
     # Apply the softmax config globally, yielding a new config object.
     new_config = config_schema.set_global_softmax_config(
         base_config=base_config, softmax_config=softmax_config)
     # NOTE(review): the index/attribute path selecting the quant hparams is cut
     # off in this view.
     quant_hparams = new_config.model_hparams.decoder.encoder_decoder_1d_blocks[
     self.assertEqual(quant_hparams.prec.to_dict(), {
         'exp_min': -3,
         'exp_max': 5,
         'sig_bits': 2
     # The reduction precision is only checked when reductions are quantized.
     if quantized_reductions:
         self.assertEqual(quant_hparams.reduction_prec.to_dict(), {
             'exp_min': 1,
             'exp_max': 4,
             'sig_bits': 6
 def test_n_layers_parameter(self, n_layers):
     """Builds a config for the requested number of layers.

     Args:
       n_layers: Forwarded to `config_schema.get_config` as `n_layers`.

     NOTE(review): only the (truncated) `get_config` call is visible here; the
     assertions of this test are not shown -- confirm against the full file.
     """
     config = config_schema.get_config(n_layers=n_layers,
    def test_auto_acts_parameter(self):
        """Checks how the auto-acts setting shapes `config.quant_act.bounds`.

        With automatic bounds, `bounds` is expected to contain an
        'initial_bound' entry, and assigning a plain scalar to it afterwards is
        expected to raise `TypeError`.

        NOTE(review): both `get_config` calls are truncated in this view, so
        the `use_auto_acts` values actually passed are not visible -- confirm.
        """
        # If use_auto_acts is False, then the bounds should be a single scalar that
        # specifies the fixed bound; 'None' by default.
        config = config_schema.get_config(n_layers=3,
        # If use_auto_acts is True, it should have the same structure as the
        # GetBounds.Hyper dataclass.
        config = config_schema.get_config(n_layers=3,
        self.assertIn('initial_bound', config.quant_act.bounds)

        # Because the config dict is locked, it shouldn't be possible to change it
        # back to fixed bounds if it was created with use_auto_acts=True.
        with self.assertRaises(TypeError):
            config.quant_act.bounds = 1.0
    def test_precision_propagates(self):
        """Checks that precision changes propagate down the hierarchy, never up.

        The test sets the global `config.prec`, then overrides `weight_prec` on
        the generic 'dense_1' layer type and finally on a single decoder layer,
        asserting after each step which levels changed and which kept their
        previous value.

        NOTE(review): the `get_config` call and the expression selecting
        `dense1_block2` are truncated in this view -- confirm against the
        complete file.
        """
        config = config_schema.get_config(n_layers=3,

        # Set the global precision to 4 bits.
        config.prec = 4
        # Test that this sets the weight and activation to 4 as well.
        self.assertEqual(config.weight_prec, 4)
        self.assertEqual(config.quant_act.prec, 4)
        # Test that propagates all the way down to the weight precision of layer
        # types and individual layers. As an example of an individual layer, we take
        # the dense1 matmul of the second block of the decoder.
        dense1_block2 = config.model_hparams.decoder.encoder_decoder_1d_blocks[
        # Meanwhile, 'dense1' represents the generic configuration of all dense1
        # layers throughout the model.
        dense1 = config.mlp_block.dense_1
        self.assertEqual(dense1.weight_prec, 4)
        self.assertEqual(dense1_block2.weight_prec, 4)

        # Test if we take the same config instance and alter the global precision to
        # 8, it automatically propagates to individual layers.
        config.prec = 8
        self.assertEqual(dense1.weight_prec, 8)
        self.assertEqual(dense1_block2.weight_prec, 8)

        # Test that the precision can be overridden for a specific layer type. We
        # want to verify that the change doesn't back-propagate back to the global
        # precision field but does propagate down to individual layers of that layer
        # type. We only want changes to fields to automatically propagate down the
        # parameter hierarchy, not up.
        dense1.weight_prec = 2
        self.assertEqual(dense1.weight_prec, 2)
        self.assertEqual(dense1_block2.weight_prec, 2)
        self.assertEqual(config.prec, 8)

        # Now update the precision for just a specific layer and check that it
        # doesn't propagate upwards.
        dense1_block2.weight_prec = 1
        self.assertEqual(dense1_block2.weight_prec, 1)
        self.assertEqual(dense1.weight_prec, 2)
        self.assertEqual(config.prec, 8)
def get_base_config(n_layers, use_auto_acts, fp_quant):
    """Returns config that sets hyperparameters common to all quantization targets.

    Fields in that config can then be overridden to customize a configuration.

    Note that two hyperparameters, the number of layers and whether to
    automatically find clipping bounds for activations, have to be specified in
    advance as arguments to this function instead of being overridden in the
    returned configdict. That is because these parameters affect the name and
    number of fields in the configdict instance, which can't be changed after
    creation: there will be one set of overridable parameters per layer in the
    configdict, and the field names in the 'quant_act' fields change depending on
    whether automatic clipping bounds are used.

    Args:
      n_layers: Number of layers in the encoder and decoder (e.g., n_layers=3
        means three encoder layers and three decoder layers).
      use_auto_acts: Whether to use automatic bounds calculation for activations
        (True) or fixed bounds (False).
      fp_quant: Whether to use floating point quantization (True) or integer
        quantization (False). This parameter has no default value.

    Returns:
      A ConfigDict instance suitable for WMT training.
    """
    config = config_schema.get_config(use_auto_acts=use_auto_acts,
    config.half_shift = False
    # NOTE(review): the statement that applies the overrides below (and several
    # closing brackets) is not visible in this view -- confirm against the
    # complete file.
        "learning_rate_schedule": {
            "factors": "constant * linear_warmup * rsqrt_decay",
            "base_learning_rate": 0.0625,
            "warmup_steps": 1000,
            "decay_factor": 0.5,
            "steps_per_decay": 20000,
            "steps_per_cycle": 100000,
        "per_host_batch_size": 256,
        "num_train_steps": 200000,
        "weight_decay": 0.25,
        "beta1": 0.9,
        "beta2": 0.98,
        "eps": 1e-9,
        "random_seed": 0,
        "hardware_rng": True,
        "activation_bound_update_freq": -1,
        "activation_bound_start_step": -1,
        "weight_outlier_regularization": 0.0,
        "prefer_int8_to_int32_dot": True,
        "model_hparams": {
            "emb_dim": 1024,
            "num_heads": 16,
            "qkv_dim": 1024,
            "mlp_dim": 4096,
            "share_embeddings": True,
            "logits_via_embedding": True,
        "weight_outlier_regularization_regex": "^.*kernel$",
        "weight_quant_granularity": "per_channel"
    # NOTE(review): as shown, all four assignments below sit in the
    # `not fp_quant` branch, so `prec` is set to None and then dereferenced, and
    # `quant_type` is overwritten. An `else:` before the last two assignments
    # appears to be missing from this view -- confirm against the complete file.
    if not fp_quant:
        config.prec = None
        config.quant_type = "aqt"
        config.prec.is_scaled = False
        config.quant_type = "fake_quant"
    # Apply an unquantized layer-norm configuration to the whole config.
    layernorm_config = config_schema.get_layer_norm_config(
        quantized=False, quantized_reductions=False)
    config = config_schema.set_global_layer_norm_config(
        config, layernorm_config)
    return config
    def test_schema_matches_expected(self, n_layers):
        # This tests that the schema of the configdict returned by 'config_schema',
        # once all references are resolved, matches an expected schema. 'Schema'
        # here means the names and structure of fields at each level of the
        # configuration hierarchy. A value of 'None' in the expected schemas defined
        # below indicates a real configuration would have a concrete scalar value
        # there.

        quant_act_schema = {
            'bounds': {
                'initial_bound': None,
                'stddev_coeff': None,
                'absdev_coeff': None,
                'mix_coeff': None,
                'reset_stats': None,
                'ema_coeff': None,
                'use_cams': None,
                'exclude_zeros': None,
                'use_mean_of_max': None,
                'granularity': None
            'input_distribution': None,
            'prec': None,
            'half_shift': None,

        dense_schema = {
            'weight_prec': None,
            'weight_quant_granularity': None,
            'quant_type': None,
            'quant_act': quant_act_schema,
            'weight_half_shift': None,

        embedding_schema = {
            'weight_prec': None,
            'quant_type': None,
            'quant_act': quant_act_schema,
            'weight_half_shift': None,

        mlp_block_schema = {
            'dense_1': dense_schema,
            'dense_2': dense_schema,

        fp_schema = {'exp_min': None, 'exp_max': None, 'sig_bits': None}

        layernorm_schema = {
            'quant_hparams': {
                'prec': fp_schema,
                'reduction_prec': fp_schema

        attention_schema = {
            'dense_kqv': dense_schema,
            'dense_out': dense_schema,
            'quant_type': None,
            'quant_act': quant_act_schema,
            'attn_acts': {
                'quant_type': None,
                'attn_act_q': quant_act_schema,
                'attn_act_k': quant_act_schema,
                'attn_act_v': quant_act_schema,
                'attn_act_probs': {
                    'bounds': None,
                    'input_distribution': None,
                    'prec': None,
                    'half_shift': None,

        expected_top_level_schema = {
            'metadata': {
                'description': None,
                'hyper_str': None
            'learning_rate_schedule': {
                'factors': None,
                'base_learning_rate': None,
                'warmup_steps': None,
                'decay_factor': None,
                'steps_per_decay': None,
                'steps_per_cycle': None,
            'per_host_batch_size': None,
            'num_train_steps': None,
            'weight_decay': None,
            'beta1': None,
            'beta2': None,
            'eps': None,
            'random_seed': None,
            'hardware_rng': None,
            'activation_bound_update_freq': None,
            'activation_bound_start_step': None,
            'weight_outlier_regularization': None,
            'weight_outlier_regularization_regex': None,
            'prefer_int8_to_int32_dot': None,
            'prec': None,
            'half_shift': None,
            'weight_prec': None,
            'weight_half_shift': None,
            'quant_type': None,
            'quant_act': quant_act_schema,
            'weight_quant_granularity': None,
            'dense': dense_schema,
            'embedding': embedding_schema,
            'mlp_block': mlp_block_schema,
            'attention': attention_schema,
            'model_hparams': {
                'emb_dim': None,
                'num_heads': None,
                'qkv_dim': None,
                'mlp_dim': None,
                'share_embeddings': None,
                'logits_via_embedding': None,
                'encoder': {
                    'encoder_1d_blocks': [{
                        'mlp_block': mlp_block_schema,
                        'attention': attention_schema,
                        'layer_norm': layernorm_schema
                    }] * n_layers,
                'decoder': {
                    'encoder_decoder_1d_blocks': [
                            'mlp_block': mlp_block_schema,
                            'self_attention': attention_schema,
                            'enc_dec_attention': attention_schema,
                            'layer_norm': layernorm_schema
                    ] * n_layers,

        config = config_schema.get_config(n_layers=n_layers,
        layer_norm_config = config_schema.get_layer_norm_config(
            quantized=True, quantized_reductions=True)
        config = config_schema.set_global_layer_norm_config(
            config, layer_norm_config)
        # This round-trip conversion from JSON forces all references to resolve to
        # concrete values.
        config_reified = json.loads(config.to_json())

        # This test is not interested in checking the specific values of fields in
        # the configuration, but only that the schemas of the hierarchies are the
        # same. Thus we set the values of all leaf nodes in the config to 'None'
        # before checking that the actual and expected configuration structures
        # are the same.
        def set_leaves_to_none(config):
            # We are at an intermediate node in the tree-structured input, which could
            # either be in the form of a dictionary or a list of other nodes in the
            # tree.
            if isinstance(config, dict):
                return {
                    key: set_leaves_to_none(value)
                    for key, value in config.items()
            elif isinstance(config, list):
                return [set_leaves_to_none(value) for value in config]

            # We are at a leaf node in the tree-structured input.
                return None
