def create_mobilebert_pretrainer(bert_config):
  """Creates a BertPretrainerV2 that wraps a MobileBERTEncoder model."""
  mobilebert_encoder = networks.MobileBERTEncoder(
      word_vocab_size=bert_config.vocab_size,
      word_embed_size=bert_config.embedding_size,
      type_vocab_size=bert_config.type_vocab_size,
      max_sequence_length=bert_config.max_position_embeddings,
      num_blocks=bert_config.num_hidden_layers,
      hidden_size=bert_config.hidden_size,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      intermediate_act_fn=tf_utils.get_activation(bert_config.hidden_act),
      hidden_dropout_prob=bert_config.hidden_dropout_prob,
      attention_probs_dropout_prob=bert_config.attention_probs_dropout_prob,
      intra_bottleneck_size=bert_config.intra_bottleneck_size,
      initializer_range=bert_config.initializer_range,
      use_bottleneck_attention=bert_config.use_bottleneck_attention,
      key_query_shared_bottleneck=bert_config.key_query_shared_bottleneck,
      num_feedforward_networks=bert_config.num_feedforward_networks,
      normalization_type=bert_config.normalization_type,
      classifier_activation=bert_config.classifier_activation)

  # Masked-LM head that shares the encoder's embedding table.
  masked_lm = layers.MobileBertMaskedLM(
      embedding_table=mobilebert_encoder.get_embedding_table(),
      activation=tf_utils.get_activation(bert_config.hidden_act),
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      name="cls/predictions")

  pretrainer = models.BertPretrainerV2(
      encoder_network=mobilebert_encoder, customized_masked_lm=masked_lm)
  # Makes sure the pretrainer variables are created.
  _ = pretrainer(pretrainer.inputs)
  return pretrainer
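# A minimal usage sketch (illustrative, not part of this module): assuming
# `mobilebert_config` is a MobileBERT-style config object exposing the
# attributes read above (vocab_size, embedding_size, num_hidden_layers,
# intra_bottleneck_size, ...), the pretrainer consumes the standard BERT
# pretraining features plus masked-LM positions. Tensor shapes below are
# examples only.
#
#   pretrainer = create_mobilebert_pretrainer(mobilebert_config)
#   outputs = pretrainer({
#       "input_word_ids": tf.ones((2, 128), tf.int32),
#       "input_mask": tf.ones((2, 128), tf.int32),
#       "input_type_ids": tf.zeros((2, 128), tf.int32),
#       "masked_lm_positions": tf.zeros((2, 20), tf.int32),
#   })
#   # `outputs` is a dict; the masked-LM logits come from the
#   # `cls/predictions` head attached above.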
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder.
    encoder_cls: an external encoder cls not included in the supported
      encoders, usually used by gin.configurable.
    bypass_config: whether to ignore config instance to create the object with
      `encoder_cls`.

  Returns:
    An encoder instance.
  """
  if bypass_config:
    return encoder_cls()
  encoder_type = config.type
  encoder_cfg = config.get()
  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return encoder_cls(**kwargs)

  if encoder_type == "mobilebert":
    return networks.MobileBERTEncoder(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.hidden_activation,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        initializer_range=encoder_cfg.initializer_range,
        use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation,
        input_mask_dtype=encoder_cfg.input_mask_dtype)

  if encoder_type == "albert":
    return networks.AlbertEncoder(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_width,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dict_outputs=True)

  if encoder_type == "bigbird":
    # TODO(frederickliu): Support use_gradient_checkpointing and update
    # experiments to use the EncoderScaffold only.
    if encoder_cfg.use_gradient_checkpointing:
      return bigbird_encoder.BigBirdEncoder(
          vocab_size=encoder_cfg.vocab_size,
          hidden_size=encoder_cfg.hidden_size,
          num_layers=encoder_cfg.num_layers,
          num_attention_heads=encoder_cfg.num_attention_heads,
          intermediate_size=encoder_cfg.intermediate_size,
          activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
          dropout_rate=encoder_cfg.dropout_rate,
          attention_dropout_rate=encoder_cfg.attention_dropout_rate,
          num_rand_blocks=encoder_cfg.num_rand_blocks,
          block_size=encoder_cfg.block_size,
          max_position_embeddings=encoder_cfg.max_position_embeddings,
          type_vocab_size=encoder_cfg.type_vocab_size,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=encoder_cfg.initializer_range),
          embedding_width=encoder_cfg.embedding_width,
          use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        max_rand_mask_length=encoder_cfg.max_position_embeddings,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        from_block_size=encoder_cfg.block_size,
        to_block_size=encoder_cfg.block_size,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.BigBirdAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.BigBirdMasks,
        mask_cfg=dict(block_size=encoder_cfg.block_size),
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "kernel":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        feature_transform=encoder_cfg.feature_transform,
        num_random_features=encoder_cfg.num_random_features,
        redraw=encoder_cfg.redraw,
        is_short_seq=encoder_cfg.is_short_seq,
        begin_kernel=encoder_cfg.begin_kernel,
        scale=encoder_cfg.scale,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.KernelAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.KernelMask,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)
== "xlnet": return networks.XLNetBase( vocab_size=encoder_cfg.vocab_size, num_layers=encoder_cfg.num_layers, hidden_size=encoder_cfg.hidden_size, num_attention_heads=encoder_cfg.num_attention_heads, head_size=encoder_cfg.head_size, inner_size=encoder_cfg.inner_size, dropout_rate=encoder_cfg.dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate, attention_type=encoder_cfg.attention_type, bi_data=encoder_cfg.bi_data, two_stream=encoder_cfg.two_stream, tie_attention_biases=encoder_cfg.tie_attention_biases, memory_length=encoder_cfg.memory_length, clamp_length=encoder_cfg.clamp_length, reuse_length=encoder_cfg.reuse_length, inner_activation=encoder_cfg.inner_activation, use_cls_mask=encoder_cfg.use_cls_mask, embedding_width=encoder_cfg.embedding_width, initializer=tf.keras.initializers.RandomNormal( stddev=encoder_cfg.initializer_range)) if encoder_type == "teams": embedding_cfg = dict( vocab_size=encoder_cfg.vocab_size, type_vocab_size=encoder_cfg.type_vocab_size, hidden_size=encoder_cfg.hidden_size, embedding_width=encoder_cfg.embedding_size, max_seq_length=encoder_cfg.max_position_embeddings, initializer=tf.keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), dropout_rate=encoder_cfg.dropout_rate, ) embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg) hidden_cfg = dict( num_attention_heads=encoder_cfg.num_attention_heads, intermediate_size=encoder_cfg.intermediate_size, intermediate_activation=tf_utils.get_activation( encoder_cfg.hidden_activation), dropout_rate=encoder_cfg.dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate, kernel_initializer=tf.keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), ) kwargs = dict( embedding_cfg=embedding_cfg, embedding_cls=embedding_network, hidden_cfg=hidden_cfg, num_hidden_instances=encoder_cfg.num_layers, pooled_output_dim=encoder_cfg.hidden_size, pooler_layer_initializer=tf.keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs, dict_outputs=True) return networks.EncoderScaffold(**kwargs) # Uses the default BERTEncoder configuration schema to create the encoder. # If it does not match, please add a switch branch by the encoder type. return networks.BertEncoder( vocab_size=encoder_cfg.vocab_size, hidden_size=encoder_cfg.hidden_size, num_layers=encoder_cfg.num_layers, num_attention_heads=encoder_cfg.num_attention_heads, intermediate_size=encoder_cfg.intermediate_size, activation=tf_utils.get_activation(encoder_cfg.hidden_activation), dropout_rate=encoder_cfg.dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate, max_sequence_length=encoder_cfg.max_position_embeddings, type_vocab_size=encoder_cfg.type_vocab_size, initializer=tf.keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), output_range=encoder_cfg.output_range, embedding_width=encoder_cfg.embedding_size, embedding_layer=embedding_layer, return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs, dict_outputs=True, norm_first=encoder_cfg.norm_first)