Ejemplo n.º 1
0
 def build_model(self):
   """Creates the sentence-prediction classifier model.

   The encoder comes from `self._hub_module` when that attribute is truthy;
   otherwise it is built from `self.task_config.model.encoder`.

   Returns:
     A `models.BertClassifier` wrapping the selected encoder network.
   """
   hub_module = self._hub_module
   if not hub_module:
     network = encoders.build_encoder(self.task_config.model.encoder)
   else:
     network = utils.get_encoder_from_hub(hub_module)

   model_cfg = self.task_config.model
   # The classification head reuses the encoder config's initializer range.
   head_initializer = tf.keras.initializers.TruncatedNormal(
       stddev=model_cfg.encoder.get().initializer_range)
   # Currently, we only support bert-style sentence prediction finetuning.
   return models.BertClassifier(
       network=network,
       num_classes=model_cfg.num_classes,
       initializer=head_initializer,
       use_encoder_pooler=model_cfg.use_encoder_pooler)
Ejemplo n.º 2
0
    def build_model(self) -> tf.keras.Model:
        """Creates the classifier model described by the task config.

        The encoder is fetched from `task_config.hub_module_url` when it is
        set; otherwise it is built from `task_config.model.encoder`. The
        number of classes is the length of `task_config.class_names`.

        Returns:
            A `models.BertClassifier` wrapping the selected encoder network.

        Raises:
            ValueError: If both `hub_module_url` and `init_checkpoint` are
                set on the task config.
        """
        config = self.task_config
        hub_url = config.hub_module_url
        if hub_url and config.init_checkpoint:
            raise ValueError(
                'At most one of `hub_module_url` and '
                '`init_checkpoint` can be specified.')

        if hub_url:
            network = utils.get_encoder_from_hub(hub_url)
        else:
            network = encoders.build_encoder(config.model.encoder)

        num_classes = len(config.class_names)
        head_initializer = tf.keras.initializers.TruncatedNormal(
            stddev=config.model.head_initializer_range)
        return models.BertClassifier(
            network=network,
            num_classes=num_classes,
            initializer=head_initializer,
            dropout_rate=config.model.head_dropout)
Ejemplo n.º 3
0
 def build_model(self):
     """Creates the sentence-prediction classifier model.

     The encoder is fetched from `task_config.hub_module_url` when it is set;
     otherwise it is built from `task_config.model.encoder`.

     Returns:
         A `models.BertClassifier` wrapping the selected encoder network.

     Raises:
         ValueError: If both `hub_module_url` and `init_checkpoint` are set
             on the task config.
     """
     task_cfg = self.task_config
     hub_url = task_cfg.hub_module_url
     if hub_url and task_cfg.init_checkpoint:
         raise ValueError(
             'At most one of `hub_module_url` and '
             '`init_checkpoint` can be specified.')

     network = (
         utils.get_encoder_from_hub(hub_url)
         if hub_url else encoders.build_encoder(task_cfg.model.encoder))

     model_cfg = task_cfg.model
     # The classification head reuses the encoder config's initializer range.
     head_initializer = tf.keras.initializers.TruncatedNormal(
         stddev=model_cfg.encoder.get().initializer_range)
     # Currently, we only support bert-style sentence prediction finetuning.
     return models.BertClassifier(
         network=network,
         num_classes=model_cfg.num_classes,
         initializer=head_initializer,
         use_encoder_pooler=model_cfg.use_encoder_pooler)
Ejemplo n.º 4
0
 def build_model(self):
   """Creates the classifier model matching the configured encoder type.

   The encoder is fetched from `task_config.hub_module_url` when it is set;
   otherwise it is built from `task_config.model.encoder`. An encoder type of
   'xlnet' yields a `models.XLNetClassifier`; any other type yields a
   `models.BertClassifier`.

   Returns:
     The classifier model wrapping the selected encoder network.

   Raises:
     ValueError: If both `hub_module_url` and `init_checkpoint` are set on
       the task config.
   """
   cfg = self.task_config
   if cfg.hub_module_url and cfg.init_checkpoint:
     raise ValueError(
         'At most one of `hub_module_url` and '
         '`init_checkpoint` can be specified.')

   if cfg.hub_module_url:
     network = utils.get_encoder_from_hub(cfg.hub_module_url)
   else:
     network = encoders.build_encoder(cfg.model.encoder)

   model_cfg = cfg.model
   # Both classifier heads take their stddev from the encoder config.
   init_range = model_cfg.encoder.get().initializer_range

   if model_cfg.encoder.type == 'xlnet':
     return models.XLNetClassifier(
         network=network,
         num_classes=model_cfg.num_classes,
         initializer=tf.keras.initializers.RandomNormal(stddev=init_range))

   return models.BertClassifier(
       network=network,
       num_classes=model_cfg.num_classes,
       initializer=tf.keras.initializers.TruncatedNormal(stddev=init_range),
       use_encoder_pooler=model_cfg.use_encoder_pooler)
Ejemplo n.º 5
0
def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """Builds a BERT classification model using the Keras functional API.

  Without `hub_module_url`, an encoder is created with
  `get_transformer_encoder` and wrapped in `models.BertClassifier`. With
  `hub_module_url`, the TF-Hub layer's pooled output is passed through
  dropout and a `num_labels`-unit dense layer named 'output'.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defining the core BERT
      or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. When None,
      a TruncatedNormal initializer with `bert_config.initializer_range` as
      its stddev is used.
    hub_module_url: TF-Hub path/url to the Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of:
      Combined prediction model (words, mask, type) -> (`num_labels` outputs)
      BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  initializer = final_layer_initializer
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if hub_module_url:
    # Created in the order word ids, mask, type ids; the same dict is later
    # handed to the Keras model as its named inputs.
    model_inputs = {
        input_name: tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name=input_name)
        for input_name in ('input_word_ids', 'input_mask', 'input_type_ids')
    }
    bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
    pooled_output, _ = bert_model([
        model_inputs['input_word_ids'],
        model_inputs['input_mask'],
        model_inputs['input_type_ids'],
    ])

    dropout_layer = tf.keras.layers.Dropout(
        rate=bert_config.hidden_dropout_prob)
    dropped = dropout_layer(pooled_output)

    output_layer = tf.keras.layers.Dense(
        num_labels, kernel_initializer=initializer, name='output')
    predictions = output_layer(dropped)

    hub_classifier = tf.keras.Model(inputs=model_inputs, outputs=predictions)
    return hub_classifier, bert_model

  bert_encoder = get_transformer_encoder(
      bert_config, max_seq_length, output_range=1)
  classifier = models.BertClassifier(
      bert_encoder,
      num_classes=num_labels,
      dropout_rate=bert_config.hidden_dropout_prob,
      initializer=initializer)
  return classifier, bert_encoder