Example #1
def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
    ) -> electra_pretrainer.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  if generator_network is None:
    generator_network = encoders.instantiate_encoder_from_cfg(
        generator_encoder_cfg)
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(
        discriminator_encoder_cfg)
  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=config.generator_encoder.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      last_hidden_dim=config.generator_encoder.hidden_size,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads))
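A minimal usage sketch for the factory above, assuming ELECTRAPretrainerConfig is a dataclass-style config whose remaining fields (generator_encoder, discriminator_encoder, cls_heads) have usable defaults; the values shown are illustrative:

# Hypothetical usage; every field name comes from the snippet above, but the
# default-constructed encoder configs are an assumption.
config = ELECTRAPretrainerConfig(
    num_classes=2,
    sequence_length=128,
    num_masked_tokens=20)
pretrainer = instantiate_pretrainer_from_cfg(config)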
Example #2
def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
    ) -> electra_pretrainer.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  # Copy discriminator's embeddings to generator for easier model serialization.
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(
        discriminator_encoder_cfg)
  if generator_network is None:
    if config.tie_embeddings:
      embedding_layer = discriminator_network.get_embedding_layer()
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg)

  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=config.generator_encoder.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads),
      disallow_correct=config.disallow_correct)
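Note the construction order in this variant: the discriminator is built first so that, when config.tie_embeddings is set, its embedding layer can be handed to the generator encoder. A sketch of supplying a pre-built discriminator through the same API, using only names that appear above:

# The supplied encoder must expose get_embedding_layer() when
# config.tie_embeddings is True, since the function above calls it.
discriminator = encoders.instantiate_encoder_from_cfg(
    config.discriminator_encoder)
pretrainer = instantiate_pretrainer_from_cfg(
    config, discriminator_network=discriminator)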
Example #3
    def build_model(self):
        if self._hub_module:
            # TODO(lehou): maybe add the hub_module building logic to a util function.
            input_word_ids = tf.keras.layers.Input(shape=(None, ),
                                                   dtype=tf.int32,
                                                   name='input_word_ids')
            input_mask = tf.keras.layers.Input(shape=(None, ),
                                               dtype=tf.int32,
                                               name='input_mask')
            input_type_ids = tf.keras.layers.Input(shape=(None, ),
                                                   dtype=tf.int32,
                                                   name='input_type_ids')
            bert_model = hub.KerasLayer(self._hub_module, trainable=True)
            pooled_output, sequence_output = bert_model(
                [input_word_ids, input_mask, input_type_ids])
            encoder_network = tf.keras.Model(
                inputs=[input_word_ids, input_mask, input_type_ids],
                outputs=[sequence_output, pooled_output])
        else:
            encoder_network = encoders.instantiate_encoder_from_cfg(
                self.task_config.network)

        return models.BertSpanLabeler(
            network=encoder_network,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.task_config.network.initializer_range))
Example #4
  def build_model(self):
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          self.task_config.network)

    return models.BertSpanLabeler(
        network=encoder_network,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.network.initializer_range))
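Example #4 condenses the inline hub branch of Example #3 into utils.get_encoder_from_hub. A plausible sketch of such a helper, assuming the same imports (tensorflow as tf, tensorflow_hub as hub) and the same three-input, two-output hub signature; the body is just Example #3's hub branch refactored:

def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
  """Wraps a TF-Hub BERT module as an encoder-style Keras model."""
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_type_ids')
  hub_layer = hub.KerasLayer(hub_module, trainable=True)
  pooled_output, sequence_output = hub_layer(
      [input_word_ids, input_mask, input_type_ids])
  # Match the [sequence_output, pooled_output] ordering produced by the
  # config-built encoders.
  return tf.keras.Model(
      inputs=[input_word_ids, input_mask, input_type_ids],
      outputs=[sequence_output, pooled_output])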
Example #5
  def build_model(self):
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          self.task_config.model.encoder)
    # Currently, we only support bert-style question answering finetuning.
    return models.BertSpanLabeler(
        network=encoder_network,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.model.encoder.initializer_range))
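Examples #4 and #5 differ only in where the encoder config lives (self.task_config.network versus self.task_config.model.encoder); the construction logic is otherwise identical.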
Example #6
    def build_model(self):
        if self._hub_module:
            encoder_network = utils.get_encoder_from_hub(self._hub_module)
        else:
            encoder_network = encoders.instantiate_encoder_from_cfg(
                self.task_config.model)

        return models.BertTokenClassifier(
            network=encoder_network,
            num_classes=len(self.task_config.class_names),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.task_config.model.initializer_range),
            dropout_rate=self.task_config.model.dropout_rate,
            output='logits')
Example #7
  def build_model(self):
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          self.task_config.model.encoder)

    # Currently, we only support bert-style sentence prediction finetuning.
    return models.BertClassifier(
        network=encoder_network,
        num_classes=self.task_config.model.num_classes,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.model.encoder.initializer_range),
        use_encoder_pooler=self.task_config.model.use_encoder_pooler)
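The use_encoder_pooler flag is forwarded straight from the config; judging from the name, it chooses between reusing the encoder's own pooled output and having models.BertClassifier pool the sequence output itself.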
Example #8
def instantiate_pretrainer_from_cfg(
    config: BertPretrainerConfig,
    encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
  """Instantiates a BertPretrainer from the config."""
  encoder_cfg = config.encoder
  if encoder_network is None:
    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
  return bert_pretrainer.BertPretrainerV2(
      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads))
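As with the ELECTRA variant, a minimal usage sketch, assuming BertPretrainerConfig can be default-constructed:

# Hypothetical usage; a default-constructible config is an assumption.
config = BertPretrainerConfig()
pretrainer = instantiate_pretrainer_from_cfg(config)

# Or hand in a pre-built encoder so the factory only does the wiring:
encoder = encoders.instantiate_encoder_from_cfg(config.encoder)
pretrainer = instantiate_pretrainer_from_cfg(config, encoder_network=encoder)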
Example #9
    def test_task(self):
        # Saves a checkpoint.
        encoder = encoders.instantiate_encoder_from_cfg(self._encoder_config)
        ckpt = tf.train.Checkpoint(encoder=encoder)
        saved_path = ckpt.save(self.get_temp_dir())

        config = tagging.TaggingConfig(init_checkpoint=saved_path,
                                       model=self._encoder_config,
                                       train_data=self._train_data_config,
                                       class_names=["O", "B-PER", "I-PER"])
        task = tagging.TaggingTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
        task.initialize(model)
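The final task.initialize(model) call is the step that consumes init_checkpoint, restoring the encoder weights saved at the top of the test into the freshly built model.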