Code example #1
  def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
               **kwargs):
    super(QAXLNetModel, self).__init__(**kwargs)
    self.run_config = run_config
    self.initializer = _get_initializer(run_config)
    self.xlnet_config = copy.deepcopy(xlnet_config)

    self.xlnet_model = networks.XLNetBase(
        vocab_size=self.xlnet_config.n_token,
        initializer=self.initializer,
        attention_type="bi",
        num_layers=self.xlnet_config.n_layer,
        hidden_size=self.xlnet_config.d_model,
        num_attention_heads=self.xlnet_config.n_head,
        head_size=self.xlnet_config.d_head,
        inner_size=self.xlnet_config.d_inner,
        tie_attention_biases=not self.xlnet_config.untie_r,
        inner_activation=self.xlnet_config.ff_activation,
        dropout_rate=self.run_config.dropout,
        attention_dropout_rate=self.run_config.dropout_att,
        two_stream=False,
        memory_length=self.run_config.mem_len,
        reuse_length=self.run_config.reuse_len,
        bi_data=self.run_config.bi_data,
        clamp_length=self.run_config.clamp_len,
        use_cls_mask=False,
        name="xlnet_model")

    self.qa_loss_layer = QALossLayer(
        hidden_size=self.xlnet_config.d_model,
        start_n_top=start_n_top,
        end_n_top=end_n_top,
        initializer=self.initializer,
        dropout_rate=self.run_config.dropout,
        name="qa_loss_layer")
Code example #2
  def __init__(self, xlnet_config, run_config, n_class, summary_type,
               use_legacy_mask=True, **kwargs):
    super(ClassificationXLNetModel, self).__init__(**kwargs)
    warnings.warn(
        "`ClassificationXLNetModel` is deprecated, please use "
        "`XLNetClassifier` instead.", DeprecationWarning, stacklevel=2)
    self.run_config = run_config
    self.initializer = _get_initializer(run_config)
    self.xlnet_config = copy.deepcopy(xlnet_config)
    self._use_legacy_mask = use_legacy_mask

    self.xlnet_model = networks.XLNetBase(
        vocab_size=self.xlnet_config.n_token,
        initializer=self.initializer,
        attention_type="bi",
        num_layers=self.xlnet_config.n_layer,
        hidden_size=self.xlnet_config.d_model,
        num_attention_heads=self.xlnet_config.n_head,
        head_size=self.xlnet_config.d_head,
        inner_size=self.xlnet_config.d_inner,
        two_stream=False,
        tie_attention_biases=not self.xlnet_config.untie_r,
        inner_activation=self.xlnet_config.ff_activation,
        dropout_rate=self.run_config.dropout,
        attention_dropout_rate=self.run_config.dropout_att,
        memory_length=self.run_config.mem_len,
        reuse_length=self.run_config.reuse_len,
        bi_data=self.run_config.bi_data,
        clamp_length=self.run_config.clamp_len,
        use_cls_mask=False,
        name="xlnet_model")

    self.summarization_layer = Summarization(
        hidden_size=self.xlnet_config.d_model,
        num_attention_heads=self.xlnet_config.n_head,
        head_size=self.xlnet_config.d_head,
        dropout_rate=self.run_config.dropout,
        attention_dropout_rate=self.run_config.dropout_att,
        initializer=self.initializer,
        use_proj=True,
        summary_type=summary_type,
        name="sequence_summary")

    self.cl_loss_layer = ClassificationLossLayer(
        n_class=n_class, initializer=self.initializer, name="classification")
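
The constructor itself flags this class as deprecated in favor of XLNetClassifier, so new code should use that instead. For reference, a hedged instantiation sketch of the deprecated class; the config objects are again placeholders.

# Hypothetical usage of the deprecated class; configs are placeholders.
classifier = ClassificationXLNetModel(
    xlnet_config=my_xlnet_config,
    run_config=my_run_config,
    n_class=2,            # e.g. binary sentiment classification
    summary_type="last",  # summarize the sequence from its last token
    name="classification_xlnet_model")
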
Code example #3
def _get_xlnet_base() -> tf.keras.layers.Layer:
    """Returns a trivial base XLNet model."""
    return networks.XLNetBase(
        vocab_size=100,
        num_layers=2,
        hidden_size=4,
        num_attention_heads=2,
        head_size=2,
        inner_size=2,
        dropout_rate=0.,
        attention_dropout_rate=0.,
        attention_type='bi',
        bi_data=True,
        initializer=tf.keras.initializers.RandomNormal(stddev=0.1),
        two_stream=False,
        tie_attention_biases=True,
        reuse_length=0,
        inner_activation='relu')
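
This helper is a test fixture: every dimension is deliberately tiny (2 layers, hidden size 4, 2 heads) so unit tests that build the network stay fast. A minimal sketch of its use:

# Minimal sketch: the helper returns a (very small) XLNetBase Keras layer.
tiny_xlnet = _get_xlnet_base()
assert isinstance(tiny_xlnet, tf.keras.layers.Layer)
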
Code example #4
def get_xlnet_base(model_config: xlnet_config.XLNetConfig,
                   run_config: xlnet_config.RunConfig, attention_type: str,
                   two_stream: bool, use_cls_mask: bool) -> tf.keras.Model:
    """Gets an 'XLNetBase' object.

  Args:
    model_config: the config that defines the core XLNet model.
    run_config: separate runtime configuration with extra parameters.
    attention_type: the attention type for the base XLNet model, "uni" or "bi".
    two_stream: whether or not to use two strema attention.
    use_cls_mask: whether or not cls mask is included in the input sequences.

  Returns:
    An XLNetBase object.
  """
    initializer = _get_initializer(
        initialization_method=run_config.init_method,
        initialization_range=run_config.init_range,
        initialization_std=run_config.init_std)
    kwargs = dict(vocab_size=model_config.n_token,
                  num_layers=model_config.n_layer,
                  hidden_size=model_config.d_model,
                  num_attention_heads=model_config.n_head,
                  head_size=model_config.d_head,
                  inner_size=model_config.d_inner,
                  dropout_rate=run_config.dropout,
                  attention_dropout_rate=run_config.dropout_att,
                  attention_type=attention_type,
                  bi_data=run_config.bi_data,
                  initializer=initializer,
                  two_stream=two_stream,
                  tie_attention_biases=not model_config.untie_r,
                  memory_length=run_config.mem_len,
                  clamp_length=run_config.clamp_len,
                  reuse_length=run_config.reuse_len,
                  inner_activation=model_config.ff_activation,
                  use_cls_mask=use_cls_mask)
    return networks.XLNetBase(**kwargs)
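
Unlike the constructors in examples #1, #2, and #5, which hard-code attention_type, two_stream, and use_cls_mask, this factory exposes them as arguments, so one function can build the backbone for both pretraining (two-stream) and fine-tuning (single-stream). A hedged call sketch; the config objects are placeholders.

# Hypothetical usage; `my_model_config` / `my_run_config` stand in for real
# XLNetConfig / RunConfig instances built elsewhere.
xlnet_base = get_xlnet_base(
    model_config=my_model_config,
    run_config=my_run_config,
    attention_type="bi",
    two_stream=False,  # single-stream attention for fine-tuning
    use_cls_mask=False)
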
Code example #5
  def __init__(self, use_proj, xlnet_config, run_config, use_legacy_mask=True,
               **kwargs):
    super(PretrainingXLNetModel, self).__init__(**kwargs)
    self.run_config = run_config
    self.initializer = _get_initializer(run_config)
    self.xlnet_config = copy.deepcopy(xlnet_config)
    self._use_legacy_mask = use_legacy_mask

    self.xlnet_model = networks.XLNetBase(
        vocab_size=self.xlnet_config.n_token,
        initializer=self.initializer,
        attention_type="bi",
        num_layers=self.xlnet_config.n_layer,
        hidden_size=self.xlnet_config.d_model,
        num_attention_heads=self.xlnet_config.n_head,
        head_size=self.xlnet_config.d_head,
        inner_size=self.xlnet_config.d_inner,
        two_stream=True,
        tie_attention_biases=not self.xlnet_config.untie_r,
        inner_activation=self.xlnet_config.ff_activation,
        dropout_rate=self.run_config.dropout,
        attention_dropout_rate=self.run_config.dropout_att,
        memory_length=self.run_config.mem_len,
        reuse_length=self.run_config.reuse_len,
        bi_data=self.run_config.bi_data,
        clamp_length=self.run_config.clamp_len,
        use_cls_mask=self.run_config.use_cls_mask,
        name="xlnet_model")

    self.lmloss_layer = LMLossLayer(
        vocab_size=self.xlnet_config.n_token,
        hidden_size=self.xlnet_config.d_model,
        initializer=self.initializer,
        tie_weight=True,
        bi_data=self.run_config.bi_data,
        use_one_hot=self.run_config.use_tpu,
        use_proj=use_proj,
        name="lm_loss")
Code example #6
File: encoders.py  Project: xiangww00/models
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
    """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder.
    encoder_cls: an external encoder cls not included in the supported encoders,
      usually used by gin.configurable.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.
  """
    if bypass_config:
        return encoder_cls()
    encoder_type = config.type
    encoder_cfg = config.get()
    if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
        )
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
            dict_outputs=True)
        return encoder_cls(**kwargs)

    if encoder_type == "mobilebert":
        return networks.MobileBERTEncoder(
            word_vocab_size=encoder_cfg.word_vocab_size,
            word_embed_size=encoder_cfg.word_embed_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            max_sequence_length=encoder_cfg.max_sequence_length,
            num_blocks=encoder_cfg.num_blocks,
            hidden_size=encoder_cfg.hidden_size,
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_act_fn=encoder_cfg.hidden_activation,
            hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
            attention_probs_dropout_prob=(
                encoder_cfg.attention_probs_dropout_prob),
            intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
            initializer_range=encoder_cfg.initializer_range,
            use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
            key_query_shared_bottleneck=(
                encoder_cfg.key_query_shared_bottleneck),
            num_feedforward_networks=encoder_cfg.num_feedforward_networks,
            normalization_type=encoder_cfg.normalization_type,
            classifier_activation=encoder_cfg.classifier_activation,
            input_mask_dtype=encoder_cfg.input_mask_dtype)

    if encoder_type == "albert":
        return networks.AlbertEncoder(
            vocab_size=encoder_cfg.vocab_size,
            embedding_width=encoder_cfg.embedding_width,
            hidden_size=encoder_cfg.hidden_size,
            num_layers=encoder_cfg.num_layers,
            num_attention_heads=encoder_cfg.num_attention_heads,
            max_sequence_length=encoder_cfg.max_position_embeddings,
            type_vocab_size=encoder_cfg.type_vocab_size,
            intermediate_size=encoder_cfg.intermediate_size,
            activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dict_outputs=True)

    if encoder_type == "bigbird":
        # TODO(frederickliu): Support use_gradient_checkpointing and update
        # experiments to use the EncoderScaffold only.
        if encoder_cfg.use_gradient_checkpointing:
            return bigbird_encoder.BigBirdEncoder(
                vocab_size=encoder_cfg.vocab_size,
                hidden_size=encoder_cfg.hidden_size,
                num_layers=encoder_cfg.num_layers,
                num_attention_heads=encoder_cfg.num_attention_heads,
                intermediate_size=encoder_cfg.intermediate_size,
                activation=tf_utils.get_activation(
                    encoder_cfg.hidden_activation),
                dropout_rate=encoder_cfg.dropout_rate,
                attention_dropout_rate=encoder_cfg.attention_dropout_rate,
                num_rand_blocks=encoder_cfg.num_rand_blocks,
                block_size=encoder_cfg.block_size,
                max_position_embeddings=encoder_cfg.max_position_embeddings,
                type_vocab_size=encoder_cfg.type_vocab_size,
                initializer=tf.keras.initializers.TruncatedNormal(
                    stddev=encoder_cfg.initializer_range),
                embedding_width=encoder_cfg.embedding_width,
                use_gradient_checkpointing=(
                    encoder_cfg.use_gradient_checkpointing))
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        attention_cfg = dict(
            num_heads=encoder_cfg.num_attention_heads,
            key_dim=int(encoder_cfg.hidden_size //
                        encoder_cfg.num_attention_heads),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            max_rand_mask_length=encoder_cfg.max_position_embeddings,
            num_rand_blocks=encoder_cfg.num_rand_blocks,
            from_block_size=encoder_cfg.block_size,
            to_block_size=encoder_cfg.block_size,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            attention_cls=layers.BigBirdAttention,
            attention_cfg=attention_cfg)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.TransformerScaffold,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            mask_cls=layers.BigBirdMasks,
            mask_cfg=dict(block_size=encoder_cfg.block_size),
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            layer_idx_as_attention_seed=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "kernel":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        attention_cfg = dict(
            num_heads=encoder_cfg.num_attention_heads,
            key_dim=int(encoder_cfg.hidden_size //
                        encoder_cfg.num_attention_heads),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            feature_transform=encoder_cfg.feature_transform,
            num_random_features=encoder_cfg.num_random_features,
            redraw=encoder_cfg.redraw,
            is_short_seq=encoder_cfg.is_short_seq,
            begin_kernel=encoder_cfg.begin_kernel,
            scale=encoder_cfg.scale,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            attention_cls=layers.KernelAttention,
            attention_cfg=attention_cfg)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.TransformerScaffold,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            mask_cls=layers.KernelMask,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            layer_idx_as_attention_seed=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "xlnet":
        return networks.XLNetBase(
            vocab_size=encoder_cfg.vocab_size,
            num_layers=encoder_cfg.num_layers,
            hidden_size=encoder_cfg.hidden_size,
            num_attention_heads=encoder_cfg.num_attention_heads,
            head_size=encoder_cfg.head_size,
            inner_size=encoder_cfg.inner_size,
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            attention_type=encoder_cfg.attention_type,
            bi_data=encoder_cfg.bi_data,
            two_stream=encoder_cfg.two_stream,
            tie_attention_biases=encoder_cfg.tie_attention_biases,
            memory_length=encoder_cfg.memory_length,
            clamp_length=encoder_cfg.clamp_length,
            reuse_length=encoder_cfg.reuse_length,
            inner_activation=encoder_cfg.inner_activation,
            use_cls_mask=encoder_cfg.use_cls_mask,
            embedding_width=encoder_cfg.embedding_width,
            initializer=tf.keras.initializers.RandomNormal(
                stddev=encoder_cfg.initializer_range))

    if encoder_type == "teams":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            embedding_width=encoder_cfg.embedding_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate,
        )
        embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
        )
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            embedding_cls=embedding_network,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
            dict_outputs=True)
        return networks.EncoderScaffold(**kwargs)

    # Uses the default BertEncoder configuration schema to create the encoder.
    # If it does not match, please add a branch for the new encoder type above.
    return networks.BertEncoder(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_size,
        embedding_layer=embedding_layer,
        return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True,
        norm_first=encoder_cfg.norm_first)
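
build_encoder dispatches on config.type and falls through to the final BertEncoder branch for any type without a dedicated case. A minimal sketch, assuming the TF Model Garden layout where EncoderConfig and build_encoder live in official/nlp/configs/encoders.py and `type` defaults to "bert":

from official.nlp.configs import encoders

config = encoders.EncoderConfig()          # type defaults to "bert"
encoder = encoders.build_encoder(config)   # hits the BertEncoder fallback
print(type(encoder).__name__)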