コード例 #1
0
def classifier_model(
    model_config: xlnet_config.XLNetConfig,
    run_config: xlnet_config.RunConfig,
    num_labels: int,
    final_layer_initializer: "tf.keras.initializers.Initializer | None" = None
) -> "tuple[tf.keras.Model, tf.keras.Model]":
    """Returns a TF2 Keras XLNet classifier model and its XLNet base network.

    Builds an XLNet base network via `get_xlnet_base` (bidirectional
    attention, single stream, no CLS mask) and wraps it in a
    `models.XLNetClassifier` that predicts `num_labels` classes.

    Args:
      model_config: the config that defines the core XLNet model.
      run_config: separate runtime configuration with extra parameters. Its
        `dropout` value is also used as the classifier's dropout rate.
      num_labels: integer, the number of classes.
      final_layer_initializer: Initializer for the final dense layer. If
        `None`, `tf.keras.initializers.RandomNormal(mean=0., stddev=.02)` is
        used; no initializer setting is read from `run_config`.

    Returns:
      A tuple `(classifier, xlnet_base)`: `classifier` is the
      `models.XLNetClassifier` built with `network=xlnet_base`, and
      `xlnet_base` is the network returned by `get_xlnet_base`. The exact
      input signature is determined by those two components.
    """
    if final_layer_initializer is not None:
        initializer = final_layer_initializer
    else:
        # Fixed fallback; intentionally independent of `run_config`.
        initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=.02)

    # Single-stream, bidirectional-attention base network without a CLS mask.
    xlnet_base = get_xlnet_base(model_config=model_config,
                                run_config=run_config,
                                attention_type='bi',
                                two_stream=False,
                                use_cls_mask=False)

    # summary_type='last' is forwarded to XLNetClassifier (presumably it
    # summarizes the sequence by its last position — confirm in models).
    classifier = models.XLNetClassifier(network=xlnet_base,
                                        num_classes=num_labels,
                                        dropout_rate=run_config.dropout,
                                        summary_type='last',
                                        initializer=initializer)
    return classifier, xlnet_base
コード例 #2
0
 def build_model(self):
   """Builds the classification model described by `self.task_config`.

   The encoder is loaded from `task_config.hub_module_url` when that is set,
   otherwise it is built from `task_config.model.encoder`. An encoder of type
   'xlnet' is wrapped in `models.XLNetClassifier` with a `RandomNormal` head
   initializer; any other type is wrapped in `models.BertClassifier` with a
   `TruncatedNormal` head initializer. Both initializers take their stddev
   from the encoder config's `initializer_range`.

   Returns:
     A `models.XLNetClassifier` or a `models.BertClassifier`.

   Raises:
     ValueError: if both `hub_module_url` and `init_checkpoint` are set.
   """
   task_cfg = self.task_config
   hub_url = task_cfg.hub_module_url
   if hub_url and task_cfg.init_checkpoint:
     raise ValueError('At most one of `hub_module_url` and '
                      '`init_checkpoint` can be specified.')

   encoder_network = (
       utils.get_encoder_from_hub(hub_url) if hub_url
       else encoders.build_encoder(task_cfg.model.encoder))

   model_cfg = task_cfg.model
   # The encoder config supplies the head initializer's stddev even when the
   # encoder itself came from a hub module.
   head_stddev = model_cfg.encoder.get().initializer_range

   if model_cfg.encoder.type != 'xlnet':
     return models.BertClassifier(
         network=encoder_network,
         num_classes=model_cfg.num_classes,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=head_stddev),
         use_encoder_pooler=model_cfg.use_encoder_pooler)

   return models.XLNetClassifier(
       network=encoder_network,
       num_classes=model_cfg.num_classes,
       initializer=tf.keras.initializers.RandomNormal(stddev=head_stddev))