def test_multiple_cls_outputs(self):
  """Validate that the Keras object can be created."""
  # Build a transformer network to use within the BERT trainer.
  vocab_size = 100
  sequence_length = 512
  hidden_size = 48
  num_layers = 2
  test_network = networks.BertEncoderV2(
      vocab_size=vocab_size,
      num_layers=num_layers,
      hidden_size=hidden_size,
      max_sequence_length=sequence_length)

  bert_trainer_model = bert_pretrainer.BertPretrainerV2(
      encoder_network=test_network,
      classification_heads=[
          layers.MultiClsHeads(
              inner_dim=5, cls_list=[('foo', 2), ('bar', 3)])
      ])

  num_token_predictions = 20
  # Create a set of 2-dimensional inputs (the first dimension is implicit).
  inputs = dict(
      input_word_ids=tf.keras.Input(
          shape=(sequence_length,), dtype=tf.int32),
      input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
      input_type_ids=tf.keras.Input(
          shape=(sequence_length,), dtype=tf.int32),
      masked_lm_positions=tf.keras.Input(
          shape=(num_token_predictions,), dtype=tf.int32))

  # Invoke the trainer model on the inputs. This causes the layer to be built.
  outputs = bert_trainer_model(inputs)
  self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
  self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])
def test_v2_serialize_deserialize(self):
  """Validate that the BERT trainer can be serialized and deserialized."""
  # Build a transformer network to use within the BERT trainer.
  test_network = networks.BertEncoderV2(vocab_size=100, num_layers=2)

  # Create a BERT trainer with the created network. (Note that all the args
  # are different, so we can catch any serialization mismatches.)
  bert_trainer_model = bert_pretrainer.BertPretrainerV2(
      encoder_network=test_network)

  # Create another BERT trainer via serialization and deserialization.
  config = bert_trainer_model.get_config()
  new_bert_trainer_model = bert_pretrainer.BertPretrainerV2.from_config(
      config)

  # Validate that the config can be forced to JSON.
  _ = new_bert_trainer_model.to_json()

  # If the serialization was successful, the new config should match the old.
  self.assertAllEqual(bert_trainer_model.get_config(),
                      new_bert_trainer_model.get_config())
def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
                           use_customized_masked_lm, has_masked_lm_positions):
  """Validate that the Keras object can be created."""
  # Build a transformer network to use within the BERT trainer.
  vocab_size = 100
  sequence_length = 512
  hidden_size = 48
  num_layers = 2
  test_network = networks.BertEncoderV2(
      vocab_size=vocab_size,
      num_layers=num_layers,
      hidden_size=hidden_size,
      max_sequence_length=sequence_length)
  _ = test_network(test_network.inputs)

  # Create a BERT trainer with the created network.
  if use_customized_masked_lm:
    customized_masked_lm = layers.MaskedLM(
        embedding_table=test_network.get_embedding_table())
  else:
    customized_masked_lm = None

  bert_trainer_model = bert_pretrainer.BertPretrainerV2(
      encoder_network=test_network,
      customized_masked_lm=customized_masked_lm)
  num_token_predictions = 20
  # Create a set of 2-dimensional inputs (the first dimension is implicit).
  inputs = dict(
      input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
      input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
      input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32))
  if has_masked_lm_positions:
    inputs['masked_lm_positions'] = tf.keras.Input(
        shape=(num_token_predictions,), dtype=tf.int32)

  # Invoke the trainer model on the inputs. This causes the layer to be built.
  outputs = bert_trainer_model(inputs)
  # BertEncoderV2 always emits dict outputs, so encoder outputs are present
  # regardless of `dict_outputs`/`return_all_encoder_outputs`.
  has_encoder_outputs = True  # dict_outputs or return_all_encoder_outputs
  expected_keys = ['sequence_output', 'pooled_output']
  if has_encoder_outputs:
    expected_keys.append('encoder_outputs')
  if has_masked_lm_positions:
    expected_keys.append('mlm_logits')
  self.assertSameElements(outputs.keys(), expected_keys)

  # Validate that the outputs are of the expected shape.
  expected_lm_shape = [None, num_token_predictions, vocab_size]
  if has_masked_lm_positions:
    self.assertAllEqual(expected_lm_shape,
                        outputs['mlm_logits'].shape.as_list())
  expected_sequence_output_shape = [None, sequence_length, hidden_size]
  self.assertAllEqual(expected_sequence_output_shape,
                      outputs['sequence_output'].shape.as_list())
  expected_pooled_output_shape = [None, hidden_size]
  self.assertAllEqual(expected_pooled_output_shape,
                      outputs['pooled_output'].shape.as_list())
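# The extra arguments to `test_bert_pretrainerv2` above are presumably supplied
# by a `parameterized` decorator that is not part of this excerpt. A
# hypothetical sketch of such a decoration (the grouping is illustrative only):
#
#   @parameterized.parameters(
#       itertools.product([True, False], repeat=4))
#   def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
#                              use_customized_masked_lm,
#                              has_masked_lm_positions):
#     ...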
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiates a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder.
    encoder_cls: an external encoder cls not included in the supported
      encoders, usually used by gin.configurable.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.
  """
  if bypass_config:
    return encoder_cls()
  encoder_type = config.type
  encoder_cfg = config.get()
  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return encoder_cls(**kwargs)

  if encoder_type == "any":
    encoder = encoder_cfg.BUILDER(encoder_cfg)
    if not isinstance(encoder,
                      (tf.Module, tf.keras.Model, tf.keras.layers.Layer)):
      raise ValueError("The BUILDER returns an unexpected instance. The "
                       "`build_encoder` should return a tf.Module, "
                       "tf.keras.Model or tf.keras.layers.Layer. However, "
                       f"we get {encoder.__class__}")
    return encoder

  if encoder_type == "mobilebert":
    return networks.MobileBERTEncoder(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.hidden_activation,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        initializer_range=encoder_cfg.initializer_range,
        use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation,
        input_mask_dtype=encoder_cfg.input_mask_dtype)

  if encoder_type == "albert":
    return networks.AlbertEncoder(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_width,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dict_outputs=True)

  if encoder_type == "bigbird":
    # TODO(frederickliu): Support use_gradient_checkpointing and update
    # experiments to use the EncoderScaffold only.
    if encoder_cfg.use_gradient_checkpointing:
      return bigbird_encoder.BigBirdEncoder(
          vocab_size=encoder_cfg.vocab_size,
          hidden_size=encoder_cfg.hidden_size,
          num_layers=encoder_cfg.num_layers,
          num_attention_heads=encoder_cfg.num_attention_heads,
          intermediate_size=encoder_cfg.intermediate_size,
          activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
          dropout_rate=encoder_cfg.dropout_rate,
          attention_dropout_rate=encoder_cfg.attention_dropout_rate,
          num_rand_blocks=encoder_cfg.num_rand_blocks,
          block_size=encoder_cfg.block_size,
          max_position_embeddings=encoder_cfg.max_position_embeddings,
          type_vocab_size=encoder_cfg.type_vocab_size,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=encoder_cfg.initializer_range),
          embedding_width=encoder_cfg.embedding_width,
          use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        max_rand_mask_length=encoder_cfg.max_position_embeddings,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        from_block_size=encoder_cfg.block_size,
        to_block_size=encoder_cfg.block_size,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.BigBirdAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.BigBirdMasks,
        mask_cfg=dict(block_size=encoder_cfg.block_size),
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "kernel":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        feature_transform=encoder_cfg.feature_transform,
        num_random_features=encoder_cfg.num_random_features,
        redraw=encoder_cfg.redraw,
        is_short_seq=encoder_cfg.is_short_seq,
        begin_kernel=encoder_cfg.begin_kernel,
        scale=encoder_cfg.scale,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.KernelAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.KernelMask,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "xlnet":
    return networks.XLNetBase(
        vocab_size=encoder_cfg.vocab_size,
        num_layers=encoder_cfg.num_layers,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        head_size=encoder_cfg.head_size,
        inner_size=encoder_cfg.inner_size,
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        attention_type=encoder_cfg.attention_type,
        bi_data=encoder_cfg.bi_data,
        two_stream=encoder_cfg.two_stream,
        tie_attention_biases=encoder_cfg.tie_attention_biases,
        memory_length=encoder_cfg.memory_length,
        clamp_length=encoder_cfg.clamp_length,
        reuse_length=encoder_cfg.reuse_length,
        inner_activation=encoder_cfg.inner_activation,
        use_cls_mask=encoder_cfg.use_cls_mask,
        embedding_width=encoder_cfg.embedding_width,
        initializer=tf.keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))

  if encoder_type == "reuse":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        inner_dim=encoder_cfg.intermediate_size,
        inner_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        output_dropout=encoder_cfg.dropout_rate,
        attention_dropout=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        reuse_attention=encoder_cfg.reuse_attention,
        use_relative_pe=encoder_cfg.use_relative_pe,
        pe_max_seq_length=encoder_cfg.pe_max_seq_length,
        max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.ReuseTransformer,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        feed_layer_idx=True,
        recursive=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "query_bert":
    embedding_layer = layers.FactorizedEmbedding(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_size,
        output_dim=encoder_cfg.hidden_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        name="word_embeddings")
    return networks.BertEncoderV2(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_layer=embedding_layer,
        return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True,
        norm_first=encoder_cfg.norm_first)

  bert_encoder_cls = networks.BertEncoder
  if encoder_type == "bert_v2":
    bert_encoder_cls = networks.BertEncoderV2

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return bert_encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer,
      return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
      dict_outputs=True,
      norm_first=encoder_cfg.norm_first)
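# A minimal usage sketch, not taken from the original module: it assumes that
# `EncoderConfig.type` defaults to "bert", so calling `build_encoder` with a
# default config exercises the fallback `BertEncoder` branch above. The name
# below is hypothetical and the hyperparameters are whatever the BERT schema
# defaults provide.
def _example_build_default_bert_encoder():
  """Builds an encoder from a default one-of config (illustrative sketch)."""
  config = EncoderConfig()
  encoder = build_encoder(config)
  # The returned network produces dict outputs such as 'sequence_output',
  # 'pooled_output' and 'encoder_outputs'.
  return encoder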