def test_multiple_cls_outputs(self):
        """Validate that the Keras object can be created."""
        # Build a transformer network to use within the BERT trainer.
        vocab_size = 100
        sequence_length = 512
        hidden_size = 48
        num_layers = 2
        test_network = networks.BertEncoderV2(
            vocab_size=vocab_size,
            num_layers=num_layers,
            hidden_size=hidden_size,
            max_sequence_length=sequence_length)

        bert_trainer_model = bert_pretrainer.BertPretrainerV2(
            encoder_network=test_network,
            classification_heads=[
                layers.MultiClsHeads(inner_dim=5,
                                     cls_list=[('foo', 2), ('bar', 3)])
            ])
        num_token_predictions = 20
        # Create a set of 2-dimensional inputs (the first dimension is implicit).
        inputs = dict(input_word_ids=tf.keras.Input(shape=(sequence_length, ),
                                                    dtype=tf.int32),
                      input_mask=tf.keras.Input(shape=(sequence_length, ),
                                                dtype=tf.int32),
                      input_type_ids=tf.keras.Input(shape=(sequence_length, ),
                                                    dtype=tf.int32),
                      masked_lm_positions=tf.keras.Input(
                          shape=(num_token_predictions, ), dtype=tf.int32))

        # Invoke the trainer model on the inputs. This causes the layer to be built.
        outputs = bert_trainer_model(inputs)
        self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
        self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])

    def test_v2_serialize_deserialize(self):
        """Validate that the BERT trainer can be serialized and deserialized."""
        # Build a transformer network to use within the BERT trainer.
        test_network = networks.BertEncoderV2(vocab_size=100, num_layers=2)

        # Create a BERT trainer with the created network. (Note that all the args
        # are different, so we can catch any serialization mismatches.)
        bert_trainer_model = bert_pretrainer.BertPretrainerV2(
            encoder_network=test_network)

        # Create another BERT trainer via serialization and deserialization.
        config = bert_trainer_model.get_config()
        new_bert_trainer_model = bert_pretrainer.BertPretrainerV2.from_config(
            config)

        # Validate that the config can be forced to JSON.
        _ = new_bert_trainer_model.to_json()

        # If the serialization was successful, the new config should match the old.
        self.assertAllEqual(bert_trainer_model.get_config(),
                            new_bert_trainer_model.get_config())

    def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
                               use_customized_masked_lm,
                               has_masked_lm_positions):
        """Validate that the Keras object can be created."""
        # Build a transformer network to use within the BERT trainer.
        vocab_size = 100
        sequence_length = 512
        hidden_size = 48
        num_layers = 2
        test_network = networks.BertEncoderV2(
            vocab_size=vocab_size,
            num_layers=num_layers,
            hidden_size=hidden_size,
            max_sequence_length=sequence_length)
        # Build the network once so its embedding table is created and can be
        # shared with a customized MaskedLM head below.
        _ = test_network(test_network.inputs)

        # Create a BERT trainer with the created network.
        if use_customized_masked_lm:
            customized_masked_lm = layers.MaskedLM(
                embedding_table=test_network.get_embedding_table())
        else:
            customized_masked_lm = None

        bert_trainer_model = bert_pretrainer.BertPretrainerV2(
            encoder_network=test_network,
            customized_masked_lm=customized_masked_lm)
        num_token_predictions = 20
        # Create a set of 2-dimensional inputs (the first dimension is implicit).
        inputs = dict(input_word_ids=tf.keras.Input(shape=(sequence_length, ),
                                                    dtype=tf.int32),
                      input_mask=tf.keras.Input(shape=(sequence_length, ),
                                                dtype=tf.int32),
                      input_type_ids=tf.keras.Input(shape=(sequence_length, ),
                                                    dtype=tf.int32))
        if has_masked_lm_positions:
            inputs['masked_lm_positions'] = tf.keras.Input(
                shape=(num_token_predictions, ), dtype=tf.int32)

        # Invoke the trainer model on the inputs. This causes the layer to be built.
        outputs = bert_trainer_model(inputs)

        # BertEncoderV2 always returns dict outputs that include
        # 'encoder_outputs', regardless of `dict_outputs` or
        # `return_all_encoder_outputs`.
        has_encoder_outputs = True
        expected_keys = ['sequence_output', 'pooled_output']
        if has_encoder_outputs:
            expected_keys.append('encoder_outputs')
        if has_masked_lm_positions:
            expected_keys.append('mlm_logits')

        self.assertSameElements(outputs.keys(), expected_keys)
        # Validate that the outputs are of the expected shape.
        expected_lm_shape = [None, num_token_predictions, vocab_size]
        if has_masked_lm_positions:
            self.assertAllEqual(expected_lm_shape,
                                outputs['mlm_logits'].shape.as_list())

        expected_sequence_output_shape = [None, sequence_length, hidden_size]
        self.assertAllEqual(expected_sequence_output_shape,
                            outputs['sequence_output'].shape.as_list())

        expected_pooled_output_shape = [None, hidden_size]
        self.assertAllEqual(expected_pooled_output_shape,
                            outputs['pooled_output'].shape.as_list())
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
    """Instantiate a Transformer encoder network from EncoderConfig.

    Args:
      config: the one-of encoder config, which provides the parameters of the
        chosen encoder.
      embedding_layer: an external embedding layer passed to the encoder.
      encoder_cls: an external encoder class not included in the supported
        encoders, usually configured via gin.
      bypass_config: whether to ignore `config` and create the object directly
        with `encoder_cls`.

    Returns:
      An encoder instance.
    """
    if bypass_config:
        return encoder_cls()
    encoder_type = config.type
    encoder_cfg = config.get()
    if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
        )
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
            dict_outputs=True)
        return encoder_cls(**kwargs)

    if encoder_type == "any":
        encoder = encoder_cfg.BUILDER(encoder_cfg)
        if not isinstance(encoder,
                          (tf.Module, tf.keras.Model, tf.keras.layers.Layer)):
            raise ValueError(
                "The BUILDER returned an unexpected instance. "
                "`build_encoder` should return a tf.Module, "
                "tf.keras.Model, or tf.keras.layers.Layer, but got "
                f"{encoder.__class__}.")
        return encoder

    if encoder_type == "mobilebert":
        return networks.MobileBERTEncoder(
            word_vocab_size=encoder_cfg.word_vocab_size,
            word_embed_size=encoder_cfg.word_embed_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            max_sequence_length=encoder_cfg.max_sequence_length,
            num_blocks=encoder_cfg.num_blocks,
            hidden_size=encoder_cfg.hidden_size,
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_act_fn=encoder_cfg.hidden_activation,
            hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
            attention_probs_dropout_prob=(
                encoder_cfg.attention_probs_dropout_prob),
            intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
            initializer_range=encoder_cfg.initializer_range,
            use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
            key_query_shared_bottleneck=(
                encoder_cfg.key_query_shared_bottleneck),
            num_feedforward_networks=encoder_cfg.num_feedforward_networks,
            normalization_type=encoder_cfg.normalization_type,
            classifier_activation=encoder_cfg.classifier_activation,
            input_mask_dtype=encoder_cfg.input_mask_dtype)

    if encoder_type == "albert":
        return networks.AlbertEncoder(
            vocab_size=encoder_cfg.vocab_size,
            embedding_width=encoder_cfg.embedding_width,
            hidden_size=encoder_cfg.hidden_size,
            num_layers=encoder_cfg.num_layers,
            num_attention_heads=encoder_cfg.num_attention_heads,
            max_sequence_length=encoder_cfg.max_position_embeddings,
            type_vocab_size=encoder_cfg.type_vocab_size,
            intermediate_size=encoder_cfg.intermediate_size,
            activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dict_outputs=True)

    if encoder_type == "bigbird":
        # TODO(frederickliu): Support use_gradient_checkpointing and update
        # experiments to use the EncoderScaffold only.
        if encoder_cfg.use_gradient_checkpointing:
            return bigbird_encoder.BigBirdEncoder(
                vocab_size=encoder_cfg.vocab_size,
                hidden_size=encoder_cfg.hidden_size,
                num_layers=encoder_cfg.num_layers,
                num_attention_heads=encoder_cfg.num_attention_heads,
                intermediate_size=encoder_cfg.intermediate_size,
                activation=tf_utils.get_activation(
                    encoder_cfg.hidden_activation),
                dropout_rate=encoder_cfg.dropout_rate,
                attention_dropout_rate=encoder_cfg.attention_dropout_rate,
                num_rand_blocks=encoder_cfg.num_rand_blocks,
                block_size=encoder_cfg.block_size,
                max_position_embeddings=encoder_cfg.max_position_embeddings,
                type_vocab_size=encoder_cfg.type_vocab_size,
                initializer=tf.keras.initializers.TruncatedNormal(
                    stddev=encoder_cfg.initializer_range),
                embedding_width=encoder_cfg.embedding_width,
                use_gradient_checkpointing=(
                    encoder_cfg.use_gradient_checkpointing))
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        attention_cfg = dict(
            num_heads=encoder_cfg.num_attention_heads,
            key_dim=int(encoder_cfg.hidden_size //
                        encoder_cfg.num_attention_heads),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            max_rand_mask_length=encoder_cfg.max_position_embeddings,
            num_rand_blocks=encoder_cfg.num_rand_blocks,
            from_block_size=encoder_cfg.block_size,
            to_block_size=encoder_cfg.block_size,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            attention_cls=layers.BigBirdAttention,
            attention_cfg=attention_cfg)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.TransformerScaffold,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            mask_cls=layers.BigBirdMasks,
            mask_cfg=dict(block_size=encoder_cfg.block_size),
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            layer_idx_as_attention_seed=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "kernel":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        attention_cfg = dict(
            num_heads=encoder_cfg.num_attention_heads,
            key_dim=int(encoder_cfg.hidden_size //
                        encoder_cfg.num_attention_heads),
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            feature_transform=encoder_cfg.feature_transform,
            num_random_features=encoder_cfg.num_random_features,
            redraw=encoder_cfg.redraw,
            is_short_seq=encoder_cfg.is_short_seq,
            begin_kernel=encoder_cfg.begin_kernel,
            scale=encoder_cfg.scale,
        )
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            intermediate_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            attention_cls=layers.KernelAttention,
            attention_cfg=attention_cfg)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.TransformerScaffold,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            mask_cls=layers.KernelMask,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            layer_idx_as_attention_seed=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "xlnet":
        return networks.XLNetBase(
            vocab_size=encoder_cfg.vocab_size,
            num_layers=encoder_cfg.num_layers,
            hidden_size=encoder_cfg.hidden_size,
            num_attention_heads=encoder_cfg.num_attention_heads,
            head_size=encoder_cfg.head_size,
            inner_size=encoder_cfg.inner_size,
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            attention_type=encoder_cfg.attention_type,
            bi_data=encoder_cfg.bi_data,
            two_stream=encoder_cfg.two_stream,
            tie_attention_biases=encoder_cfg.tie_attention_biases,
            memory_length=encoder_cfg.memory_length,
            clamp_length=encoder_cfg.clamp_length,
            reuse_length=encoder_cfg.reuse_length,
            inner_activation=encoder_cfg.inner_activation,
            use_cls_mask=encoder_cfg.use_cls_mask,
            embedding_width=encoder_cfg.embedding_width,
            initializer=tf.keras.initializers.RandomNormal(
                stddev=encoder_cfg.initializer_range))

    if encoder_type == "reuse":
        embedding_cfg = dict(
            vocab_size=encoder_cfg.vocab_size,
            type_vocab_size=encoder_cfg.type_vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            max_seq_length=encoder_cfg.max_position_embeddings,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            dropout_rate=encoder_cfg.dropout_rate)
        hidden_cfg = dict(
            num_attention_heads=encoder_cfg.num_attention_heads,
            inner_dim=encoder_cfg.intermediate_size,
            inner_activation=tf_utils.get_activation(
                encoder_cfg.hidden_activation),
            output_dropout=encoder_cfg.dropout_rate,
            attention_dropout=encoder_cfg.attention_dropout_rate,
            norm_first=encoder_cfg.norm_first,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            reuse_attention=encoder_cfg.reuse_attention,
            use_relative_pe=encoder_cfg.use_relative_pe,
            pe_max_seq_length=encoder_cfg.pe_max_seq_length,
            max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
        kwargs = dict(
            embedding_cfg=embedding_cfg,
            hidden_cls=layers.ReuseTransformer,
            hidden_cfg=hidden_cfg,
            num_hidden_instances=encoder_cfg.num_layers,
            pooled_output_dim=encoder_cfg.hidden_size,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            return_all_layer_outputs=False,
            dict_outputs=True,
            feed_layer_idx=True,
            recursive=True)
        return networks.EncoderScaffold(**kwargs)

    if encoder_type == "query_bert":
        embedding_layer = layers.FactorizedEmbedding(
            vocab_size=encoder_cfg.vocab_size,
            embedding_width=encoder_cfg.embedding_size,
            output_dim=encoder_cfg.hidden_size,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            name="word_embeddings")
        return networks.BertEncoderV2(
            vocab_size=encoder_cfg.vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            num_layers=encoder_cfg.num_layers,
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            max_sequence_length=encoder_cfg.max_position_embeddings,
            type_vocab_size=encoder_cfg.type_vocab_size,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range),
            output_range=encoder_cfg.output_range,
            embedding_layer=embedding_layer,
            return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
            dict_outputs=True,
            norm_first=encoder_cfg.norm_first)

    bert_encoder_cls = networks.BertEncoder
    if encoder_type == "bert_v2":
        bert_encoder_cls = networks.BertEncoderV2

    # Uses the default BertEncoder configuration schema to create the encoder.
    # If your config does not match it, add a branch above for the new encoder
    # type.
    return bert_encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_size,
        embedding_layer=embedding_layer,
        return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True,
        norm_first=encoder_cfg.norm_first)
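

# A minimal usage sketch, not part of the original module. It assumes
# `EncoderConfig` is the one-of config consumed above (a `type` field selecting
# the sub-config returned by `config.get()`), and that the default "bert"
# branch is exercised; adjust the config fields for your setup.
def _example_build_default_bert_encoder():
    config = EncoderConfig(type="bert")  # falls through to the default branch
    encoder = build_encoder(config)
    # The default branch builds a BertEncoder with dict outputs, so a call
    # returns a dict keyed by 'sequence_output', 'pooled_output', and
    # 'encoder_outputs'.
    outputs = encoder(dict(
        input_word_ids=tf.ones((2, 16), dtype=tf.int32),
        input_mask=tf.ones((2, 16), dtype=tf.int32),
        input_type_ids=tf.zeros((2, 16), dtype=tf.int32)))
    return outputs["sequence_output"]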