Code example #1
def _create_albert_model(cfg):
    """Creates an ALBERT keras core model from BERT configuration.

  Args:
    cfg: A `AlbertConfig` to create the core model.

  Returns:
    A keras model.
  """
    albert_encoder = networks.AlbertEncoder(
        vocab_size=cfg.vocab_size,
        hidden_size=cfg.hidden_size,
        embedding_width=cfg.embedding_size,
        num_layers=cfg.num_hidden_layers,
        num_attention_heads=cfg.num_attention_heads,
        intermediate_size=cfg.intermediate_size,
        activation=tf_utils.get_activation(cfg.hidden_act),
        dropout_rate=cfg.hidden_dropout_prob,
        attention_dropout_rate=cfg.attention_probs_dropout_prob,
        max_sequence_length=cfg.max_position_embeddings,
        type_vocab_size=cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=cfg.initializer_range))
    return albert_encoder
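The helper above only wires an `AlbertConfig` into `networks.AlbertEncoder`. A minimal usage sketch follows; the import paths and config values are assumptions based on the TensorFlow Model Garden layout, not part of the original snippet:

# Hypothetical usage of _create_albert_model; the paths below assume the
# TensorFlow Model Garden package layout.
import tensorflow as tf
from official.modeling import tf_utils                      # used inside the helper
from official.nlp.albert import configs as albert_configs   # assumed config module
from official.nlp.modeling import networks                  # used inside the helper

# Illustrative values; AlbertConfig defaults cover the remaining fields.
cfg = albert_configs.AlbertConfig(vocab_size=30000, embedding_size=128)
albert_encoder = _create_albert_model(cfg)
albert_encoder.summary()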
Code example #2
File: encoders.py  Project: xiangww00/models
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
    """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder.
    encoder_cls: an external encoder cls not included in the supported encoders,
      usually used by gin.configurable.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.
  """
    if bypass_config:
        return encoder_cls()
    encoder_type = config.type
    encoder_cfg = config.get()
    if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
        )
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
            dict_outputs=True)
        return encoder_cls(**kwargs)

    if encoder_type == "mobilebert":
        return networks.MobileBERTEncoder(
            word_vocab_size=encoder_cfg.word_vocab_size,
            word_embed_size=encoder_cfg.word_embed_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            max_sequence_length=encoder_cfg.max_sequence_length,
            num_blocks=encoder_cfg.num_blocks,
            hidden_size=encoder_cfg.hidden_size,
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_act_fn=encoder_cfg.hidden_activation,
            hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
            attention_probs_dropout_prob=(
                encoder_cfg.attention_probs_dropout_prob),
            intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
            initializer_range=encoder_cfg.initializer_range,
            use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
            key_query_shared_bottleneck=(
                encoder_cfg.key_query_shared_bottleneck),
            num_feedforward_networks=encoder_cfg.num_feedforward_networks,
            normalization_type=encoder_cfg.normalization_type,
            classifier_activation=encoder_cfg.classifier_activation,
            input_mask_dtype=encoder_cfg.input_mask_dtype)

    if encoder_type == "albert":
        return networks.AlbertEncoder(
            vocab_size=encoder_cfg.vocab_size,
            embedding_width=encoder_cfg.embedding_width,
            hidden_size=encoder_cfg.hidden_size,
            num_layers=encoder_cfg.num_layers,
            num_attention_heads=encoder_cfg.num_attention_heads,
            max_sequence_length=encoder_cfg.max_position_embeddings,
            type_vocab_size=encoder_cfg.type_vocab_size,
            intermediate_size=encoder_cfg.intermediate_size,
            activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dict_outputs=True)

    if encoder_type == "bigbird":
        # TODO(frederickliu): Support use_gradient_checkpointing and update
        # experiments to use the EncoderScaffold only.
        if encoder_cfg.use_gradient_checkpointing:
            return bigbird_encoder.BigBirdEncoder(
                vocab_size=encoder_cfg.vocab_size,
                hidden_size=encoder_cfg.hidden_size,
                num_layers=encoder_cfg.num_layers,
                num_attention_heads=encoder_cfg.num_attention_heads,
                intermediate_size=encoder_cfg.intermediate_size,
                activation=tf_utils.get_activation(
                    encoder_cfg.hidden_activation),
                dropout_rate=encoder_cfg.dropout_rate,
                attention_dropout_rate=encoder_cfg.attention_dropout_rate,
                num_rand_blocks=encoder_cfg.num_rand_blocks,
                block_size=encoder_cfg.block_size,
                max_position_embeddings=encoder_cfg.max_position_embeddings,
                type_vocab_size=encoder_cfg.type_vocab_size,
                initializer=tf.keras.initializers.TruncatedNormal(
                    stddev=encoder_cfg.initializer_range),
                embedding_width=encoder_cfg.embedding_width,
                use_gradient_checkpointing=(
                    encoder_cfg.use_gradient_checkpointing))
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        attention_cfg = dict(
            num_heads=encoder_cfg.num_attention_heads,
            key_dim=int(encoder_cfg.hidden_size //
                        encoder_cfg.num_attention_heads),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            max_rand_mask_length=encoder_cfg.max_position_embeddings,
            num_rand_blocks=encoder_cfg.num_rand_blocks,
            from_block_size=encoder_cfg.block_size,
            to_block_size=encoder_cfg.block_size,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            attention_cls=layers.BigBirdAttention,
            attention_cfg=attention_cfg)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.TransformerScaffold,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            mask_cls=layers.BigBirdMasks,
            mask_cfg=dict(block_size=encoder_cfg.block_size),
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            layer_idx_as_attention_seed=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "kernel":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        attention_cfg = dict(
            num_heads=encoder_cfg.num_attention_heads,
            key_dim=int(encoder_cfg.hidden_size //
                        encoder_cfg.num_attention_heads),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            feature_transform=encoder_cfg.feature_transform,
            num_random_features=encoder_cfg.num_random_features,
            redraw=encoder_cfg.redraw,
            is_short_seq=encoder_cfg.is_short_seq,
            begin_kernel=encoder_cfg.begin_kernel,
            scale=encoder_cfg.scale,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            attention_cls=layers.KernelAttention,
            attention_cfg=attention_cfg)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.TransformerScaffold,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            mask_cls=layers.KernelMask,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            layer_idx_as_attention_seed=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "xlnet":
        return networks.XLNetBase(
            vocab_size=encoder_cfg.vocab_size,
            num_layers=encoder_cfg.num_layers,
            hidden_size=encoder_cfg.hidden_size,
            num_attention_heads=encoder_cfg.num_attention_heads,
            head_size=encoder_cfg.head_size,
            inner_size=encoder_cfg.inner_size,
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            attention_type=encoder_cfg.attention_type,
            bi_data=encoder_cfg.bi_data,
            two_stream=encoder_cfg.two_stream,
            tie_attention_biases=encoder_cfg.tie_attention_biases,
            memory_length=encoder_cfg.memory_length,
            clamp_length=encoder_cfg.clamp_length,
            reuse_length=encoder_cfg.reuse_length,
            inner_activation=encoder_cfg.inner_activation,
            use_cls_mask=encoder_cfg.use_cls_mask,
            embedding_width=encoder_cfg.embedding_width,
            initializer=tf.keras.initializers.RandomNormal(
                stddev=encoder_cfg.initializer_range))

    if encoder_type == "teams":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            embedding_width=encoder_cfg.embedding_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate,
        )
        embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
        )
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            embedding_cls=embedding_network,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
            dict_outputs=True)
        return networks.EncoderScaffold(**kwargs)

    # Falls back to the default BertEncoder configuration schema to create the
    # encoder. If your encoder does not match this schema, please add a branch
    # above keyed on the encoder type.
    return networks.BertEncoder(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_size,
        embedding_layer=embedding_layer,
        return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True,
        norm_first=encoder_cfg.norm_first)
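Below is a hedged usage sketch for build_encoder. The `EncoderConfig` field name (`albert`), the `AlbertEncoderConfig` class name, and the import path are assumptions inferred from the branch names above:

# Hypothetical usage of build_encoder; EncoderConfig lives in the same
# encoders.py module (import path assumed from the Model Garden layout).
from official.nlp.configs import encoders

# Pick the ALBERT branch of the one-of config and override one field;
# AlbertEncoderConfig is an assumed name for the per-encoder config class.
config = encoders.EncoderConfig(
    type="albert",
    albert=encoders.AlbertEncoderConfig(num_layers=6))
encoder = encoders.build_encoder(config)

# With bypass_config=True the config is ignored and encoder_cls() is called
# directly, e.g. when gin has already bound all of its constructor arguments.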
Code example #3
def get_transformer_encoder(bert_config,
                            sequence_length=None,
                            transformer_encoder_cls=None,
                            output_range=None):
    """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: [Deprecated].
    transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.
    output_range: the sequence output range, [0, output_range). Default setting
      is to return the entire sequence output.

  Returns:
    A encoder object.
  """
    del sequence_length
    if transformer_encoder_cls is not None:
        # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
        embedding_cfg = dict(
            vocab_size=bert_config.vocab_size,
            type_vocab_size=bert_config.type_vocab_size,
            hidden_size=bert_config.hidden_size,
            max_seq_length=bert_config.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=bert_config.initializer_range),
            dropout_rate=bert_config.hidden_dropout_prob,
        )
        hidden_cfg = dict(
            num_attention_heads=bert_config.num_attention_heads,
            intermediate_size=bert_config.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                bert_config.hidden_act),
            dropout_rate=bert_config.hidden_dropout_prob,
            attention_dropout_rate=bert_config.attention_probs_dropout_prob,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=bert_config.initializer_range),
        )
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=bert_config.num_hidden_layers,
            pooled_output_dim=bert_config.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=bert_config.initializer_range))

        # Relies on gin configuration to define the Transformer encoder arguments.
        return transformer_encoder_cls(**kwargs)

    kwargs = dict(
        vocab_size=bert_config.vocab_size,
        hidden_size=bert_config.hidden_size,
        num_layers=bert_config.num_hidden_layers,
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        max_sequence_length=bert_config.max_position_embeddings,
        type_vocab_size=bert_config.type_vocab_size,
        embedding_width=bert_config.embedding_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range))
    if isinstance(bert_config, albert_configs.AlbertConfig):
        return networks.AlbertEncoder(**kwargs)
    else:
        assert isinstance(bert_config, configs.BertConfig)
        kwargs['output_range'] = output_range
        return networks.BertEncoder(**kwargs)
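A short usage sketch for get_transformer_encoder, assuming the `configs` and `albert_configs` modules referenced above come from the TensorFlow Model Garden BERT/ALBERT packages; the concrete values are illustrative only:

# Hypothetical usage of get_transformer_encoder (import paths assumed).
from official.nlp.bert import configs
from official.nlp.albert import configs as albert_configs

# BERT branch: output_range=1 keeps only position 0 of the sequence output.
bert_config = configs.BertConfig(vocab_size=30522, num_hidden_layers=4)
bert_encoder = get_transformer_encoder(bert_config, output_range=1)

# ALBERT branch: an AlbertConfig routes to networks.AlbertEncoder, and
# output_range is not applied there.
albert_config = albert_configs.AlbertConfig(vocab_size=30000)
albert_encoder = get_transformer_encoder(albert_config)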