def __init__(self, xlnet_config, run_config, start_n_top, end_n_top, **kwargs):
  """Builds an XLNet encoder with a question-answering loss head.

  Constructs a single-stream, bidirectional-attention `XLNetBase` encoder
  and a `QALossLayer` on top of it.

  Args:
    xlnet_config: config carrying the core XLNet architecture
      hyperparameters (n_token, n_layer, d_model, ...); deep-copied so
      later mutation by the caller cannot affect this model.
    run_config: runtime configuration (dropout rates, memory/reuse
      lengths, bi_data, clamp_len, initializer settings).
    start_n_top: number of top start-position candidates kept by the QA
      loss layer.
    end_n_top: number of top end-position candidates kept by the QA loss
      layer.
    **kwargs: forwarded to the Keras base-class initializer.
  """
  super(QAXLNetModel, self).__init__(**kwargs)
  self.run_config = run_config
  self.initializer = _get_initializer(run_config)
  # Private copy: callers may reuse/mutate their config object afterwards.
  self.xlnet_config = copy.deepcopy(xlnet_config)

  cfg = self.xlnet_config
  run = self.run_config
  # QA fine-tuning uses single-stream ("two_stream=False") bidirectional
  # attention and no CLS masking.
  encoder_kwargs = dict(
      vocab_size=cfg.n_token,
      initializer=self.initializer,
      attention_type="bi",
      num_layers=cfg.n_layer,
      hidden_size=cfg.d_model,
      num_attention_heads=cfg.n_head,
      head_size=cfg.d_head,
      inner_size=cfg.d_inner,
      tie_attention_biases=not cfg.untie_r,
      inner_activation=cfg.ff_activation,
      dropout_rate=run.dropout,
      attention_dropout_rate=run.dropout_att,
      two_stream=False,
      memory_length=run.mem_len,
      reuse_length=run.reuse_len,
      bi_data=run.bi_data,
      clamp_length=run.clamp_len,
      use_cls_mask=False,
      name="xlnet_model")
  self.xlnet_model = networks.XLNetBase(**encoder_kwargs)

  self.qa_loss_layer = QALossLayer(
      hidden_size=cfg.d_model,
      start_n_top=start_n_top,
      end_n_top=end_n_top,
      initializer=self.initializer,
      dropout_rate=run.dropout,
      name="qa_loss_layer")
def __init__(self, xlnet_config, run_config, n_class, summary_type, use_legacy_mask=True, **kwargs):
  """Builds an XLNet encoder with a sequence-classification head.

  Deprecated: emits a `DeprecationWarning` pointing callers at
  `XLNetClassifier`. Constructs a single-stream bidirectional `XLNetBase`
  encoder, a `Summarization` layer that pools the sequence, and a
  `ClassificationLossLayer`.

  Args:
    xlnet_config: config carrying the core XLNet architecture
      hyperparameters; deep-copied so later mutation by the caller cannot
      affect this model.
    run_config: runtime configuration (dropout rates, memory/reuse
      lengths, bi_data, clamp_len, initializer settings).
    n_class: number of output classes for the classification loss layer.
    summary_type: pooling strategy forwarded to the `Summarization` layer.
    use_legacy_mask: stored on the instance as `_use_legacy_mask`;
      presumably controls input-mask convention downstream — confirm at
      the call site.
    **kwargs: forwarded to the Keras base-class initializer.
  """
  super(ClassificationXLNetModel, self).__init__(**kwargs)
  # Bug fix: the two adjacent string literals previously concatenated to
  # "...use `XLNetClassifier`instead." (missing space before "instead").
  warnings.warn(
      "`ClassificationXLNetModel` is deprecated, please use `XLNetClassifier` "
      "instead.",
      DeprecationWarning,
      stacklevel=2)
  self.run_config = run_config
  self.initializer = _get_initializer(run_config)
  # Private copy: callers may reuse/mutate their config object afterwards.
  self.xlnet_config = copy.deepcopy(xlnet_config)
  self._use_legacy_mask = use_legacy_mask
  # Classification uses single-stream bidirectional attention and no CLS
  # masking.
  self.xlnet_model = networks.XLNetBase(
      vocab_size=self.xlnet_config.n_token,
      initializer=self.initializer,
      attention_type="bi",
      num_layers=self.xlnet_config.n_layer,
      hidden_size=self.xlnet_config.d_model,
      num_attention_heads=self.xlnet_config.n_head,
      head_size=self.xlnet_config.d_head,
      inner_size=self.xlnet_config.d_inner,
      two_stream=False,
      tie_attention_biases=not self.xlnet_config.untie_r,
      inner_activation=self.xlnet_config.ff_activation,
      dropout_rate=self.run_config.dropout,
      attention_dropout_rate=self.run_config.dropout_att,
      memory_length=self.run_config.mem_len,
      reuse_length=self.run_config.reuse_len,
      bi_data=self.run_config.bi_data,
      clamp_length=self.run_config.clamp_len,
      use_cls_mask=False,
      name="xlnet_model")
  # Pools the encoder's sequence output into a single summary vector.
  self.summarization_layer = Summarization(
      hidden_size=self.xlnet_config.d_model,
      num_attention_heads=self.xlnet_config.n_head,
      head_size=self.xlnet_config.d_head,
      dropout_rate=self.run_config.dropout,
      attention_dropout_rate=self.run_config.dropout_att,
      initializer=self.initializer,
      use_proj=True,
      summary_type=summary_type,
      name="sequence_summary")
  self.cl_loss_layer = ClassificationLossLayer(
      n_class=n_class,
      initializer=self.initializer,
      name="classification")
def _get_xlnet_base() -> tf.keras.layers.Layer:
  """Returns a trivial base XLNet model."""
  # Tiny fixed configuration: every size is kept minimal so the network is
  # cheap to build in tests.
  tiny_config = {
      "vocab_size": 100,
      "num_layers": 2,
      "hidden_size": 4,
      "num_attention_heads": 2,
      "head_size": 2,
      "inner_size": 2,
      "dropout_rate": 0.,
      "attention_dropout_rate": 0.,
      "attention_type": 'bi',
      "bi_data": True,
      "initializer": tf.keras.initializers.RandomNormal(stddev=0.1),
      "two_stream": False,
      "tie_attention_biases": True,
      "reuse_length": 0,
      "inner_activation": 'relu',
  }
  return networks.XLNetBase(**tiny_config)
def get_xlnet_base(model_config: xlnet_config.XLNetConfig,
                   run_config: xlnet_config.RunConfig,
                   attention_type: str,
                   two_stream: bool,
                   use_cls_mask: bool) -> tf.keras.Model:
  """Gets an `XLNetBase` object.

  Args:
    model_config: the config that defines the core XLNet model.
    run_config: separate runtime configuration with extra parameters.
    attention_type: the attention type for the base XLNet model, "uni" or
      "bi".
    two_stream: whether or not to use two-stream attention.
    use_cls_mask: whether or not cls mask is included in the input
      sequences.

  Returns:
    An XLNetBase object.
  """
  initializer = _get_initializer(
      initialization_method=run_config.init_method,
      initialization_range=run_config.init_range,
      initialization_std=run_config.init_std)
  # Architecture sizes come from `model_config`; regularization and
  # sequence-handling knobs come from `run_config`.
  return networks.XLNetBase(
      vocab_size=model_config.n_token,
      num_layers=model_config.n_layer,
      hidden_size=model_config.d_model,
      num_attention_heads=model_config.n_head,
      head_size=model_config.d_head,
      inner_size=model_config.d_inner,
      dropout_rate=run_config.dropout,
      attention_dropout_rate=run_config.dropout_att,
      attention_type=attention_type,
      bi_data=run_config.bi_data,
      initializer=initializer,
      two_stream=two_stream,
      tie_attention_biases=not model_config.untie_r,
      memory_length=run_config.mem_len,
      clamp_length=run_config.clamp_len,
      reuse_length=run_config.reuse_len,
      inner_activation=model_config.ff_activation,
      use_cls_mask=use_cls_mask)
def __init__(self, use_proj, xlnet_config, run_config, use_legacy_mask=True, **kwargs):
  """Builds an XLNet encoder with a language-modeling pretraining head.

  Constructs a two-stream `XLNetBase` encoder and an `LMLossLayer` whose
  output projection is tied to the input embeddings.

  Args:
    use_proj: forwarded to `LMLossLayer` as `use_proj`.
    xlnet_config: config carrying the core XLNet architecture
      hyperparameters; deep-copied so later mutation by the caller cannot
      affect this model.
    run_config: runtime configuration (dropout rates, memory/reuse
      lengths, bi_data, clamp_len, use_cls_mask, use_tpu).
    use_legacy_mask: stored on the instance as `_use_legacy_mask`;
      presumably controls input-mask convention downstream — confirm at
      the call site.
    **kwargs: forwarded to the Keras base-class initializer.
  """
  super(PretrainingXLNetModel, self).__init__(**kwargs)
  self.run_config = run_config
  self.initializer = _get_initializer(run_config)
  # Private copy: callers may reuse/mutate their config object afterwards.
  self.xlnet_config = copy.deepcopy(xlnet_config)
  self._use_legacy_mask = use_legacy_mask

  cfg = self.xlnet_config
  run = self.run_config
  # Pretraining uses two-stream attention; CLS masking follows the run
  # config rather than being hard-coded off.
  encoder_kwargs = dict(
      vocab_size=cfg.n_token,
      initializer=self.initializer,
      attention_type="bi",
      num_layers=cfg.n_layer,
      hidden_size=cfg.d_model,
      num_attention_heads=cfg.n_head,
      head_size=cfg.d_head,
      inner_size=cfg.d_inner,
      two_stream=True,
      tie_attention_biases=not cfg.untie_r,
      inner_activation=cfg.ff_activation,
      dropout_rate=run.dropout,
      attention_dropout_rate=run.dropout_att,
      memory_length=run.mem_len,
      reuse_length=run.reuse_len,
      bi_data=run.bi_data,
      clamp_length=run.clamp_len,
      use_cls_mask=run.use_cls_mask,
      name="xlnet_model")
  self.xlnet_model = networks.XLNetBase(**encoder_kwargs)

  self.lmloss_layer = LMLossLayer(
      vocab_size=cfg.n_token,
      hidden_size=cfg.d_model,
      initializer=self.initializer,
      tie_weight=True,
      bi_data=run.bi_data,
      use_one_hot=run.use_tpu,
      use_proj=use_proj,
      name="lm_loss")
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Dispatches on `config.type` ("mobilebert", "albert", "bigbird", "kernel",
  "xlnet", "teams"), falling back to a standard `BertEncoder` for any other
  type. An externally supplied `encoder_cls` named "EncoderScaffold" takes
  precedence over the type dispatch.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder (only
      used by the default BertEncoder branch).
    encoder_cls: an external encoder cls not included in the supported
      encoders, usually used by gin.configurable.
    bypass_config: whether to ignore config instance to create the object
      with `encoder_cls`.

  Returns:
    An encoder instance.
  """
  # Escape hatch: construct encoder_cls with its own defaults and ignore
  # `config` entirely.
  if bypass_config:
    return encoder_cls()
  encoder_type = config.type
  encoder_cfg = config.get()
  # Externally injected EncoderScaffold: translate the generic config into
  # the scaffold's embedding/hidden cfg dicts and instantiate it directly.
  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return encoder_cls(**kwargs)

  if encoder_type == "mobilebert":
    return networks.MobileBERTEncoder(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.hidden_activation,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        initializer_range=encoder_cfg.initializer_range,
        use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation,
        input_mask_dtype=encoder_cfg.input_mask_dtype)

  if encoder_type == "albert":
    return networks.AlbertEncoder(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_width,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dict_outputs=True)

  if encoder_type == "bigbird":
    # TODO(frederickliu): Support use_gradient_checkpointing and update
    # experiments to use the EncoderScaffold only.
    # Gradient checkpointing is only available through the dedicated
    # BigBirdEncoder network; otherwise BigBird is assembled from an
    # EncoderScaffold with BigBirdAttention plugged into each layer.
    if encoder_cfg.use_gradient_checkpointing:
      return bigbird_encoder.BigBirdEncoder(
          vocab_size=encoder_cfg.vocab_size,
          hidden_size=encoder_cfg.hidden_size,
          num_layers=encoder_cfg.num_layers,
          num_attention_heads=encoder_cfg.num_attention_heads,
          intermediate_size=encoder_cfg.intermediate_size,
          activation=tf_utils.get_activation(
              encoder_cfg.hidden_activation),
          dropout_rate=encoder_cfg.dropout_rate,
          attention_dropout_rate=encoder_cfg.attention_dropout_rate,
          num_rand_blocks=encoder_cfg.num_rand_blocks,
          block_size=encoder_cfg.block_size,
          max_position_embeddings=encoder_cfg.max_position_embeddings,
          type_vocab_size=encoder_cfg.type_vocab_size,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=encoder_cfg.initializer_range),
          embedding_width=encoder_cfg.embedding_width,
          use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        max_rand_mask_length=encoder_cfg.max_position_embeddings,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        from_block_size=encoder_cfg.block_size,
        to_block_size=encoder_cfg.block_size,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.BigBirdAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.BigBirdMasks,
        mask_cfg=dict(block_size=encoder_cfg.block_size),
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "kernel":
    # Kernel (performer-style) attention: an EncoderScaffold with
    # KernelAttention plugged into each TransformerScaffold layer.
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        feature_transform=encoder_cfg.feature_transform,
        num_random_features=encoder_cfg.num_random_features,
        redraw=encoder_cfg.redraw,
        is_short_seq=encoder_cfg.is_short_seq,
        begin_kernel=encoder_cfg.begin_kernel,
        scale=encoder_cfg.scale,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.KernelAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.KernelMask,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "xlnet":
    # Note: XLNet uses a RandomNormal initializer here, unlike the
    # TruncatedNormal used by the other branches.
    return networks.XLNetBase(
        vocab_size=encoder_cfg.vocab_size,
        num_layers=encoder_cfg.num_layers,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        head_size=encoder_cfg.head_size,
        inner_size=encoder_cfg.inner_size,
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        attention_type=encoder_cfg.attention_type,
        bi_data=encoder_cfg.bi_data,
        two_stream=encoder_cfg.two_stream,
        tie_attention_biases=encoder_cfg.tie_attention_biases,
        memory_length=encoder_cfg.memory_length,
        clamp_length=encoder_cfg.clamp_length,
        reuse_length=encoder_cfg.reuse_length,
        inner_activation=encoder_cfg.inner_activation,
        use_cls_mask=encoder_cfg.use_cls_mask,
        embedding_width=encoder_cfg.embedding_width,
        initializer=tf.keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))

  if encoder_type == "teams":
    # TEAMS: an EncoderScaffold whose embedding network is a concrete
    # PackedSequenceEmbedding instance (passed via embedding_cls).
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        embedding_width=encoder_cfg.embedding_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        embedding_cls=embedding_network,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return networks.EncoderScaffold(**kwargs)

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return networks.BertEncoder(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer,
      return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
      dict_outputs=True,
      norm_first=encoder_cfg.norm_first)