Exemplo n.º 1
0
    def test_network_invocation(self):
        """Builds pretrainers with zero, one and two same-named cls heads.

        The last configuration uses two classification heads that are both
        named "next_sentence"; building it is expected to raise ValueError.
        """

        def make_encoder_cfg():
            # Tiny encoder (vocab 10, one layer) used by every case below.
            return encoders.TransformerEncoderConfig(vocab_size=10,
                                                     num_layers=1)

        def make_nsp_head_cfg():
            return bert.ClsHeadConfig(inner_dim=10,
                                      num_classes=2,
                                      name="next_sentence")

        # Case 1: encoder only, no classification heads.
        bare_cfg = bert.BertPretrainerConfig(encoder=make_encoder_cfg())
        _ = bert.instantiate_pretrainer_from_cfg(bare_cfg)

        # Case 2: a single classification head.
        single_head_cfg = bert.BertPretrainerConfig(
            encoder=make_encoder_cfg(),
            cls_heads=[make_nsp_head_cfg()])
        _ = bert.instantiate_pretrainer_from_cfg(single_head_cfg)

        # Case 3: two heads sharing a name. Both the config construction and
        # the instantiation stay inside the context manager, as originally.
        with self.assertRaises(ValueError):
            duplicate_cfg = bert.BertPretrainerConfig(
                encoder=make_encoder_cfg(),
                cls_heads=[make_nsp_head_cfg(), make_nsp_head_cfg()])
            _ = bert.instantiate_pretrainer_from_cfg(duplicate_cfg)
Exemplo n.º 2
0
  def test_task(self):
    """Runs one train and one validation step, then initializes from a ckpt.

    A BERT pretrainer is saved with prefix ``config.init_checkpoint`` and
    ``task.initialize(model)`` is then called on the task's model.
    """
    config = sentence_prediction.SentencePredictionConfig(
        init_checkpoint=self.get_temp_dir(),
        model=self.get_model_config(2),
        train_data=self._train_data_config)
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    # `learning_rate` is the documented keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras optimizer implementations.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

    # Saves a checkpoint.
    # NOTE(review): vocab_size=30522 presumably has to match the encoder built
    # by `get_model_config` so `initialize` can restore it -- confirm.
    pretrain_cfg = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=30522, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    ckpt.save(config.init_checkpoint)
    task.initialize(model)
Exemplo n.º 3
0
 def test_checkpoint_items(self):
     """Checks the keys of the pretrainer's ``checkpoint_items`` mapping.

     With a single classification head named "next_sentence" the expected
     keys are "encoder", "masked_lm" and "next_sentence.pooler_dense".
     """
     encoder_cfg = encoders.TransformerEncoderConfig(vocab_size=10,
                                                     num_layers=1)
     nsp_head_cfg = bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence")
     pretrainer_cfg = bert.BertPretrainerConfig(encoder=encoder_cfg,
                                                cls_heads=[nsp_head_cfg])
     # The instantiated object is a pretrainer (not just an encoder).
     pretrainer = bert.instantiate_pretrainer_from_cfg(pretrainer_cfg)
     expected_keys = ["encoder", "masked_lm", "next_sentence.pooler_dense"]
     self.assertSameElements(pretrainer.checkpoint_items.keys(),
                             expected_keys)
Exemplo n.º 4
0
  def test_task(self, version_2_with_negative, tokenization):
    """Runs the question-answering task initialized from a saved pretrainer.

    Args:
      version_2_with_negative: forwarded to `_get_validation_data_config`.
      tokenization: not referenced in this body (presumably consumed by the
        test parameterization -- confirm).
    """
    # Save a pretrainer checkpoint built with the same encoder config as the
    # QA model, then point the task's init_checkpoint at the saved path.
    nsp_head_cfg = bert.ClsHeadConfig(
        inner_dim=10, num_classes=3, name="next_sentence")
    pretrainer_cfg = bert.BertPretrainerConfig(
        encoder=self._encoder_config, cls_heads=[nsp_head_cfg])
    pretrainer = bert.instantiate_pretrainer_from_cfg(pretrainer_cfg)
    checkpoint = tf.train.Checkpoint(
        model=pretrainer, **pretrainer.checkpoint_items)
    checkpoint_path = checkpoint.save(self.get_temp_dir())

    qa_config = question_answering.QuestionAnsweringConfig(
        init_checkpoint=checkpoint_path,
        model=question_answering.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config(
            version_2_with_negative))
    self._run_task(qa_config)
Exemplo n.º 5
0
 def build_model(self, params=None):
   """Builds a BERT pretrainer via `bert.instantiate_pretrainer_from_cfg`.

   Args:
     params: optional model config. When it is falsy (e.g. None), the task's
       own `self.task_config.model` is used instead.

   Returns:
     Whatever `bert.instantiate_pretrainer_from_cfg` returns for that config.
   """
   model_cfg = params if params else self.task_config.model
   return bert.instantiate_pretrainer_from_cfg(model_cfg)
Exemplo n.º 6
0
 def build_model(self):
     """Builds a BERT pretrainer from `self.task_config.model`."""
     model_cfg = self.task_config.model
     return bert.instantiate_pretrainer_from_cfg(model_cfg)