def classifier_model(
    model_config: xlnet_config.XLNetConfig,
    run_config: xlnet_config.RunConfig,
    num_labels: int,
    final_layer_initializer: tf.keras.initializers.Initializer = None,
) -> 'tuple[tf.keras.Model, tf.keras.Model]':
    """Builds a TF2 Keras XLNet classifier and the XLNet base it wraps.

    The base network is created with `get_xlnet_base` using bidirectional
    attention (`attention_type='bi'`), a single stream (`two_stream=False`)
    and no CLS mask (`use_cls_mask=False`). It is then wrapped in
    `models.XLNetClassifier` with `summary_type='last'` and a dropout rate
    taken from `run_config.dropout`.

    Args:
      model_config: the config that defines the core XLNet model.
      run_config: separate runtime configuration with extra parameters.
      num_labels: integer, the number of classes.
      final_layer_initializer: Initializer for the classifier passed to
        `models.XLNetClassifier`. If `None`, a fixed
        `RandomNormal(mean=0., stddev=.02)` initializer is used; it is NOT
        read from `run_config`.

    Returns:
      A tuple `(classifier, xlnet_base)`: the `models.XLNetClassifier`
      predicting `num_labels` classes, and the XLNet base network it was
      built on.
    """
    # NOTE: the return annotation is quoted so that describing the tuple
    # return does not require any additional import in this module.
    if final_layer_initializer is None:
        # Hard-coded fallback, independent of `run_config`.
        initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=.02)
    else:
        initializer = final_layer_initializer

    xlnet_base = get_xlnet_base(
        model_config=model_config,
        run_config=run_config,
        attention_type='bi',
        two_stream=False,
        use_cls_mask=False)

    classifier = models.XLNetClassifier(
        network=xlnet_base,
        num_classes=num_labels,
        dropout_rate=run_config.dropout,
        summary_type='last',
        initializer=initializer)
    # Callers receive both the classifier and the underlying base network.
    return classifier, xlnet_base
def build_model(self):
    """Creates the classifier model described by `self.task_config`.

    The encoder is loaded from `task_config.hub_module_url` when that is
    set; otherwise it is built from `task_config.model.encoder`. The
    classifier head is chosen from the configured encoder type: `'xlnet'`
    yields a `models.XLNetClassifier` (RandomNormal initializer), anything
    else yields a `models.BertClassifier` (TruncatedNormal initializer,
    honoring `task_config.model.use_encoder_pooler`).

    Returns:
      The constructed classifier model.

    Raises:
      ValueError: if both `hub_module_url` and `init_checkpoint` are set.
    """
    task_cfg = self.task_config
    hub_url = task_cfg.hub_module_url

    if hub_url and task_cfg.init_checkpoint:
        raise ValueError('At most one of `hub_module_url` and '
                         '`init_checkpoint` can be specified.')

    encoder_network = (
        utils.get_encoder_from_hub(hub_url)
        if hub_url
        else encoders.build_encoder(task_cfg.model.encoder))

    # The encoder type and initializer range always come from the config,
    # even when the encoder network itself was loaded from a hub module.
    model_cfg = task_cfg.model
    encoder_cfg = model_cfg.encoder.get()

    if model_cfg.encoder.type == 'xlnet':
        return models.XLNetClassifier(
            network=encoder_network,
            num_classes=model_cfg.num_classes,
            initializer=tf.keras.initializers.RandomNormal(
                stddev=encoder_cfg.initializer_range))

    return models.BertClassifier(
        network=encoder_network,
        num_classes=model_cfg.num_classes,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        use_encoder_pooler=model_cfg.use_encoder_pooler)