Example 1
def get_attention_config(parent_config):
    """Create a ConfigDict corresponding to aqt.flax_attention.MutliHeadDotProductAttentionAqt.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(config,
                                              parent_config,
                                              ["dense_kqv", "dense_out"],
                                              parent_field="dense")

    config.attn_acts = ml_collections.ConfigDict({})

    config_schema_utils.set_default_reference(config, parent_config,
                                              ["quant_type", "quant_act"])
    config_schema_utils.set_default_reference(config.attn_acts, config,
                                              ["quant_type"])
    config_schema_utils.set_default_reference(
        config.attn_acts,
        config, ["attn_act_q", "attn_act_k", "attn_act_v"],
        parent_field="quant_act")
    config.attn_acts.attn_act_probs = ml_collections.ConfigDict({
        "input_distribution": "positive",
        "bounds": 1.0,
        "half_shift": False,  # Set half_shift to False for the positive distribution.
    })
    config_schema_utils.set_default_reference(config.attn_acts.attn_act_probs,
                                              parent_config.quant_act, "prec")
    config.lock()
    return config
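
The builders above delegate all wiring to config_schema_utils.set_default_reference. Below is a minimal standalone sketch of the one-way referencing pattern it provides, using only ml_collections primitives; this illustrates the semantics the tests further below verify, not the real implementation, which also handles nested ConfigDicts and lists of fields:

import ml_collections

parent = ml_collections.ConfigDict({"prec": 8})
child = ml_collections.ConfigDict()

# Wrapping the parent's reference in a fresh FieldReference gives a one-way
# link: the child lazily resolves the parent's value until it is overridden.
child.prec = ml_collections.FieldReference(parent.get_ref("prec"))

parent.prec = 4
assert child.prec == 4   # Parent changes propagate downstream.

child.prec = 2           # Overriding the child severs the link for this field.
assert parent.prec == 4  # The parent is unaffected.
parent.prec = 16
assert child.prec == 2   # The child keeps its override.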
Example 2
def get_embedding_config(parent_config):
    """Create a ConfigDict corresponding to aqt.flax_layers.Embedding.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(
        config, parent_config,
        ["weight_prec", "quant_type", "quant_act", "weight_half_shift"])
    config.lock()
    return config
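
Every builder ends with config.lock(). Locking freezes the schema rather than the values: existing fields can still be overridden, but new fields can no longer be added. A quick sketch of that ml_collections behavior:

import ml_collections

config = ml_collections.ConfigDict({"weight_prec": 8})
config.lock()

config.weight_prec = 4  # Overriding an existing field is still allowed.
try:
    config.new_field = 1  # Adding a new field to a locked config raises.
except AttributeError:
    print("locked: cannot add new fields")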
Example 3
def get_mlp_block_config(parent_config):
    """Create a ConfigDict corresponding to wmt_mlperf.models.MlpBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(config,
                                              parent_config,
                                              ["dense_1", "dense_2"],
                                              parent_field="dense")
    config.dense_2.quant_act.input_distribution = "positive"
    config.lock()
    return config
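
The dense_2 override above changes one child without touching its sibling or the shared parent default. A simplified sketch of that isolation follows; make_child is a hypothetical stand-in for the per-field wiring set_default_reference performs on nested ConfigDicts, and the distribution names are illustrative:

import ml_collections

def make_child(parent):
    # Each leaf of the child lazily tracks the corresponding parent leaf.
    child = ml_collections.ConfigDict()
    for key in parent.keys():
        child[key] = ml_collections.FieldReference(parent.get_ref(key))
    return child

dense = ml_collections.ConfigDict({"input_distribution": "symmetric"})
dense_1 = make_child(dense)
dense_2 = make_child(dense)

dense_2.input_distribution = "positive"           # Local override only.
assert dense_1.input_distribution == "symmetric"  # The sibling is unaffected.
dense.input_distribution = "gaussian"
assert dense_1.input_distribution == "gaussian"   # Still tracks the default.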
Example 4
def get_base_config(use_auto_acts):
    """Base ConfigDict for resnet, does not yet have fields for individual layers."""
    base_config = config_schema_utils.get_base_config(use_auto_acts,
                                                      fp_quant=False)
    base_config.update({
        "base_learning_rate": float_ph(),
        "momentum": float_ph(),
        "model_hparams": {},
        "act_function": str_ph(),  # add a new field that controls act function
        "shortcut_ch_shrink_method": str_ph(),
        "shortcut_ch_expand_method": str_ph(),
        "shortcut_spatial_method": str_ph(),
        "lr_scheduler": {
            "warmup_epochs": int_ph(),
            "cooldown_epochs": int_ph(),
            "scheduler": str_ph(),  # "cosine", "linear", or "step" lr decay
            "num_epochs": int_ph(),
        },
        "optimizer": str_ph(),
        "adam": {
            "beta1": float_ph(),
            "beta2": float_ph()
        },
        "early_stop_steps": int_ph(),
    })
    if use_auto_acts:
        # config_schema_utils is shared with WMT. To avoid affecting other
        # code that depends on it, add the new bound coefficients here.
        base_config.quant_act.bounds.update({
            "fixed_bound": float_ph(),
            "cams_coeff": float_ph(),
            "cams_stddev_coeff": float_ph(),
            "mean_of_max_coeff": float_ph(),
            "use_old_code": bool_ph(),
        })

    base_config.dense_layer = config_schema_utils.get_dense_config(base_config)
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    base_config.dense_layer.quant_act.input_distribution = "positive"
    base_config.conv = config_schema_utils.get_conv_config(base_config)
    base_config.residual = get_residual_config(base_config)
    # Make the activation function in a residual block consistent with the
    # global option.
    config_schema_utils.set_default_reference(base_config.residual,
                                              base_config,
                                              field=[
                                                  "act_function",
                                                  "shortcut_ch_shrink_method",
                                                  "shortcut_ch_expand_method",
                                                  "shortcut_spatial_method"
                                              ])
    return base_config
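
The float_ph/int_ph/str_ph/bool_ph helpers used above are not defined in this excerpt; they presumably come from config_schema_utils and create typed placeholder fields. ml_collections ships exactly this primitive, so plausible definitions are:

from ml_collections import config_dict

# Presumed definitions: a typed placeholder is a FieldReference whose value is
# None until the experiment fills it in, with the type enforced on assignment.
def float_ph():
    return config_dict.placeholder(float)

def int_ph():
    return config_dict.placeholder(int)

def str_ph():
    return config_dict.placeholder(str)

def bool_ph():
    return config_dict.placeholder(bool)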
Example 5
def get_config(num_blocks, use_auto_acts):
    """Returns a ConfigDict instance for a Imagenet (Resnet50 and Resnet101).

  The ConfigDict is wired up so that changing a field at one level of the
  hierarchy changes the value of that field everywhere downstream in the
  hierarchy. For example, changing the top-level 'prec' parameter
  (e.g., config.prec=4) will cause the precision of all layers to change.
  Changing the precision of a specific layer type
  (e.g., config.residual_block.conv_1.weight_prec=4) will cause the weight
  precision of all conv_1 layers to change, overriding the global
  config.prec value.

  See config_schema_test.test_schema_matches_expected to see the structure
  of the ConfigDict instance this will return.

  Args:
    num_blocks: Number of residual blocks in the architecture.
    use_auto_acts: Whether to use automatic clipping bounds for activations or
      fixed bounds. Unlike other properties of the configuration which can be
      overridden directly in the ConfigDict instance, this affects the immutable
      schema of the ConfigDict and so has to be specified before the ConfigDict
      is created.

  Returns:
    A ConfigDict instance which parallels the hierarchy of TrainingHParams.
  """
    base_config = get_base_config(use_auto_acts=use_auto_acts)
    model_hparams = base_config.model_hparams
    config_schema_utils.set_default_reference(model_hparams, base_config,
                                              "dense_layer")

    config_schema_utils.set_default_reference(model_hparams,
                                              base_config,
                                              "conv_init",
                                              parent_field="conv")
    model_hparams.residual_blocks = [
        config_schema_utils.make_reference(base_config, "residual")
        for _ in range(num_blocks)
    ]
    model_hparams.update({
        # Controls the number of parameters in the model by multiplying the number
        # of conv filters in each layer by this number.
        "filter_multiplier": float_ph(),
    })

    base_config.lock()
    return base_config
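
A usage sketch of the propagation the docstring describes, assuming this module is importable along with its config_schema_utils dependency, and that config.prec is the top-level precision field provided by config_schema_utils.get_base_config:

config = get_config(num_blocks=16, use_auto_acts=True)

# A global override flows into every layer that has not been overridden...
config.prec = 4

# ...while a more specific field takes precedence for its own scope only.
first_block = config.model_hparams.residual_blocks[0]
first_block.conv_1.weight_prec = 8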
Example 6
    def test_when_child_field_is_list(self):
        # Test when 'field' parameter of set_default_reference is a list
        # of specific fields. We expect a new reference to be created for each
        # element in the list.
        parent = ml_collections.ConfigDict({'x': 1, 'y': 2, 'z': 3})
        child = ml_collections.ConfigDict()
        config_schema_utils.set_default_reference(child, parent, ['x', 'y'])
        self.assertEqual((parent.x, parent.y), (1, 2))
        self.assertEqual((child.x, child.y), (1, 2))

        parent.y = 5
        self.assertEqual((parent.x, parent.y), (1, 5))
        self.assertEqual((child.x, child.y), (1, 5))

        child.y = 10
        self.assertEqual((parent.x, parent.y), (1, 5))
        self.assertEqual((child.x, child.y), (1, 10))
Example 7
def get_residual_config(parent_config):
    """Creates ConfigDict corresponding to imagenet.models.ResidualBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(
        config,
        parent_config, ["conv_proj", "conv_1", "conv_2", "conv_3"],
        parent_field="conv")
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    config.conv_proj.quant_act.input_distribution = "positive"
    config.conv_2.quant_act.input_distribution = "positive"
    config.conv_3.quant_act.input_distribution = "positive"

    config.lock()
    return config
Example 8
    def test_reference_to_self(self):
        # Test adding a new field to a configdict which is a reference to an
        # existing field in the same configdict instance.
        config = ml_collections.ConfigDict({'parent': 1})
        config_schema_utils.set_default_reference(config,
                                                  config,
                                                  'child',
                                                  parent_field='parent')
        self.assertEqual(config.child, 1)
        self.assertEqual(config.parent, 1)

        config.parent = 5
        self.assertEqual(config.parent, 5)
        self.assertEqual(config.child, 5)

        config.child = 10
        self.assertEqual(config.parent, 5)
        self.assertEqual(config.child, 10)
Example 9
def get_residual_config(parent_config):
    """Creates ConfigDict corresponding to imagenet.models.ResidualBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(
        config,
        parent_config, ["conv_se", "conv_proj", "conv_1", "conv_2", "conv_3"],
        parent_field="conv")
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    config.conv_proj.quant_act.input_distribution = "positive"
    config.conv_2.quant_act.input_distribution = "positive"
    config.conv_3.quant_act.input_distribution = "positive"
    # Add new fields to the residual block that control the activation
    # function and the shortcut methods.
    config.update({
        "act_function": str_ph(),
        "shortcut_ch_shrink_method": str_ph(),
        "shortcut_ch_expand_method": str_ph(),
        "shortcut_spatial_method": str_ph(),
    })

    config.lock()
    return config
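
Combined with the set_default_reference call at the end of get_base_config, these new fields track the global setting until a block overrides them. A sketch, assuming get_base_config from this module ("relu" and "swish" are illustrative values):

base = get_base_config(use_auto_acts=False)

base.act_function = "relu"                   # The global choice...
assert base.residual.act_function == "relu"  # ...propagates into the block.

base.residual.act_function = "swish"         # A per-block override wins locally.
assert base.act_function == "relu"           # The global setting is unchanged.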
Example 10
def get_config(n_layers, use_auto_acts, fp_quant):
    """Returns a ConfigDict instance for a WMT transformer.

  The ConfigDict is wired up so that changing a field at one level of the
  hierarchy changes the value of that field everywhere downstream in the
  hierarchy. For example, changing the top-level 'prec' parameter
  (e.g., config.prec=4) will cause the precision of all layers to change.
  Changing the precision of a specific layer type
  (e.g., config.mlp_block.dense_1.weight_prec=4) will cause the weight
  precision of all dense_1 layers to change, overriding the global
  config.prec value.

  See config_schema_test.test_schema_matches_expected to see the structure
  of the ConfigDict instance this will return.

  Args:
    n_layers: Number of layers in the encoder and the decoder.
    use_auto_acts: Whether to use automatic clipping bounds for activations or
      fixed bounds. Unlike other properties of the configuration which can be
      overridden directly in the ConfigDict instance, this affects the immutable
      schema of the ConfigDict and so has to be specified before the ConfigDict
      is created.
    fp_quant: Whether to use floating-point quantization; if False, integer
      quantization is used.

  Returns:
    A ConfigDict instance which parallels the hierarchy of TrainingHParams.
  """
    base_config = get_wmt_base_config(use_auto_acts=use_auto_acts,
                                      fp_quant=fp_quant)
    model_hparams = base_config.model_hparams
    model_hparams.encoder = {
        "encoder_1d_blocks": [
            get_block_config(base_config, BlockKind.encoder)
            for _ in range(n_layers)
        ]
    }
    config_schema_utils.set_default_reference(model_hparams.encoder,
                                              base_config, "embedding")
    model_hparams.decoder = {
        "encoder_decoder_1d_blocks": [
            get_block_config(base_config, BlockKind.decoder)
            for _ in range(n_layers)
        ]
    }
    config_schema_utils.set_default_reference(model_hparams.decoder,
                                              base_config, "embedding")

    config_schema_utils.set_default_reference(model_hparams.decoder,
                                              base_config,
                                              "logits",
                                              parent_field="dense")
    base_config.lock()
    return base_config
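
A usage sketch mirroring the docstring's examples, assuming this module and its dependencies are importable:

config = get_config(n_layers=6, use_auto_acts=True, fp_quant=False)

config.prec = 4  # Propagates into every encoder and decoder block.

# Override one layer type inside the first encoder block only.
block = config.model_hparams.encoder.encoder_1d_blocks[0]
block.mlp_block.dense_1.weight_prec = 8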
Example 11
def get_block_config(parent_config, block_kind):
    """Create a ConfigDict corresponding to wmt_mlperf.models.Encoder[Decoder]1DBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(config, parent_config,
                                              "mlp_block")
    if block_kind == BlockKind.encoder:
        config_schema_utils.set_default_reference(config, parent_config,
                                                  "attention")
    elif block_kind == BlockKind.decoder:
        config_schema_utils.set_default_reference(
            config,
            parent_config, ["self_attention", "enc_dec_attention"],
            parent_field="attention")
    else:
        raise ValueError(f"Unknown block_kind {block_kind}")
    config.lock()
    return config
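
BlockKind is referenced here but not defined in this excerpt; a minimal stand-in consistent with its usage would be:

import enum

class BlockKind(enum.Enum):
    encoder = enum.auto()
    decoder = enum.auto()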