Example #1: train and validate a QuestionAnsweringTask initialized from a freshly saved BERT pretraining checkpoint.
    def test_task(self):
        # Saves a checkpoint.
        pretrain_cfg = bert.BertPretrainerConfig(encoder=self._encoder_config,
                                                 num_masked_tokens=20,
                                                 cls_heads=[
                                                     bert.ClsHeadConfig(
                                                         inner_dim=10,
                                                         num_classes=3,
                                                         name="next_sentence")
                                                 ])
        pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
        ckpt = tf.train.Checkpoint(model=pretrain_model,
                                   **pretrain_model.checkpoint_items)
        saved_path = ckpt.save(self.get_temp_dir())

        config = question_answering.QuestionAnsweringConfig(
            init_checkpoint=saved_path,
            network=self._encoder_config,
            train_data=self._train_data_config)
        task = question_answering.QuestionAnsweringTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
        # initialize() restores weights from `init_checkpoint` into the model.
        task.initialize(model)
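For reference, `task.initialize(model)` is what consumes the `init_checkpoint` set in the config. A hedged sketch of an equivalent manual restore, mirroring the `tf.train.Checkpoint` construction used on the save side above:

# Hedged sketch of a manual restore roughly equivalent to
# task.initialize(model); it reuses the model's checkpoint_items mapping.
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
ckpt.read(saved_path).expect_partial()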
Example #2: an ExperimentConfig factory for BERT fine-tuning on SQuAD v1/v2.
def bert_squad() -> cfg.ExperimentConfig:
    """BERT Squad V1/V2."""
    config = cfg.ExperimentConfig(
        task=question_answering.QuestionAnsweringConfig(
            train_data=question_answering_dataloader.QADataConfig(),
            validation_data=question_answering_dataloader.QADataConfig()),
        trainer=cfg.TrainerConfig(
            optimizer_config=optimization.OptimizationConfig({
                'optimizer': {
                    'type': 'adamw',
                    'adamw': {
                        'weight_decay_rate': 0.01,
                        'exclude_from_weight_decay':
                            ['LayerNorm', 'layer_norm', 'bias'],
                    }
                },
                'learning_rate': {
                    'type': 'polynomial',
                    'polynomial': {
                        'initial_learning_rate': 8e-5,
                        'end_learning_rate': 0.0,
                    }
                },
                'warmup': {
                    'type': 'polynomial'
                }
            })),
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
    config.task.model.encoder.type = 'bert'
    return config
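The factory above leaves data paths and the training schedule at their defaults. A minimal sketch of filling them in via the config's standard override method; the path and numbers below are hypothetical placeholders, not recommended values:

# Minimal sketch, assuming hyperparams.Config.override; the path and the
# numbers are hypothetical placeholders.
config = bert_squad()
config.override({
    'task': {
        'train_data': {
            'input_path': '/path/to/squad_train.tf_record',  # placeholder
            'global_batch_size': 32,
        },
    },
    'trainer': {
        'train_steps': 10000,
    },
})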
Example #3: run the task from an exported TF Hub module (two config variants).
# Variant passing the encoder via the `network` field:
def test_task_with_hub(self):
    hub_module_url = self._export_bert_tfhub()
    config = question_answering.QuestionAnsweringConfig(
        hub_module_url=hub_module_url,
        network=self._encoder_config,
        train_data=self._train_data_config)
    self._run_task(config)

# Variant passing the encoder via `model=ModelConfig(...)` and adding
# validation data:
def test_task_with_hub(self):
    hub_module_url = self._export_bert_tfhub()
    config = question_answering.QuestionAnsweringConfig(
        hub_module_url=hub_module_url,
        model=question_answering.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config())
    self._run_task(config)
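The _run_task helper is private to the test class and not shown in these snippets. A hypothetical reconstruction, mirroring the explicit steps from Example #1:

def _run_task(self, config):
    # Hypothetical reconstruction of the test helper; the steps mirror
    # Example #1 above.
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    task.initialize(model)
    dataset = task.build_inputs(config.train_data)
    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)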
Example #5: an ExperimentConfig factory for TEAMS fine-tuning on SQuAD v1/v2.
def teams_squad() -> cfg.ExperimentConfig:
  """Teams Squad V1/V2."""
  config = cfg.ExperimentConfig(
      task=question_answering.QuestionAnsweringConfig(
          model=question_answering.ModelConfig(
              encoder=encoders.EncoderConfig(
                  type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
          train_data=question_answering_dataloader.QADataConfig(),
          validation_data=question_answering_dataloader.QADataConfig()),
      trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
      restrictions=[
          "task.train_data.is_training != None",
          "task.validation_data.is_training != None"
      ])
  return config
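Experiment factories like this are normally registered by name so training binaries can look them up. A minimal sketch, assuming the official.core.exp_factory API; the registration name 'teams/squad' is an assumption:

from official.core import exp_factory

# Minimal sketch; 'teams/squad' is an assumed registration name.
exp_factory.register_config_factory('teams/squad')(teams_squad)
config = exp_factory.get_exp_config('teams/squad')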
Example #6: assorted task tests: Keras fit integration, serving signatures, prediction, and checkpoint initialization.
def test_task_with_fit(self):
    config = question_answering.QuestionAnsweringConfig(
        network=self._encoder_config, train_data=self._train_data_config)
    task = question_answering.QuestionAnsweringTask(config)
    model = task.build_model()
    model = task.compile_model(
        model,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
        train_step=task.train_step,
        metrics=[
            tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
        ])
    dataset = task.build_inputs(config.train_data)
    logs = model.fit(dataset, epochs=1, steps_per_epoch=2)
    self.assertIn("loss", logs.history)
    self.assertIn("start_positions_accuracy", logs.history)
    self.assertIn("end_positions_accuracy", logs.history)

    def test_question_answering(self, use_v2_feature_names):
        if use_v2_feature_names:
            input_word_ids_field = "input_word_ids"
            input_type_ids_field = "input_type_ids"
        else:
            input_word_ids_field = "input_ids"
            input_type_ids_field = "segment_ids"

        config = question_answering.QuestionAnsweringConfig(
            model=question_answering.ModelConfig(
                encoder=encoders.EncoderConfig(bert=encoders.BertEncoderConfig(
                    vocab_size=30522, num_layers=1))),
            validation_data=None)
        task = question_answering.QuestionAnsweringTask(config)
        model = task.build_model()
        params = serving_modules.QuestionAnswering.Params(
            parse_sequence_length=10,
            use_v2_feature_names=use_v2_feature_names)
        export_module = serving_modules.QuestionAnswering(params=params,
                                                          model=model)
        functions = export_module.get_inference_signatures({
            "serve": "serving_default",
            "serve_examples": "serving_examples",
        })
        self.assertSameElements(functions.keys(),
                                ["serving_default", "serving_examples"])
        dummy_ids = tf.ones((10, 10), dtype=tf.int32)
        outputs = functions["serving_default"](input_word_ids=dummy_ids,
                                               input_mask=dummy_ids,
                                               input_type_ids=dummy_ids)
        self.assertEqual(outputs["start_logits"].shape, (10, 10))
        self.assertEqual(outputs["end_logits"].shape, (10, 10))
        dummy_ids = tf.ones((10,), dtype=tf.int32)
        examples = _create_fake_serialized_examples({
            input_word_ids_field: dummy_ids,
            "input_mask": dummy_ids,
            input_type_ids_field: dummy_ids,
        })
        outputs = functions["serving_examples"](examples)
        self.assertEqual(outputs["start_logits"].shape, (10, 10))
        self.assertEqual(outputs["end_logits"].shape, (10, 10))
    def test_predict(self, version_2_with_negative):
        validation_data = self._get_validation_data_config(
            version_2_with_negative=version_2_with_negative)

        config = question_answering.QuestionAnsweringConfig(
            model=question_answering.ModelConfig(encoder=self._encoder_config),
            train_data=self._train_data_config,
            validation_data=validation_data)
        task = question_answering.QuestionAnsweringTask(config)
        model = task.build_model()

        all_predictions, all_nbest, scores_diff = question_answering.predict(
            task, validation_data, model)
        self.assertLen(all_predictions, 1)
        self.assertLen(all_nbest, 1)
        if version_2_with_negative:
            self.assertLen(scores_diff, 1)
        else:
            self.assertEmpty(scores_diff)
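The three values returned by question_answering.predict line up with the files the SQuAD evaluation scripts consume. A hedged sketch of writing them out; the output paths are placeholders:

import json

# Hedged sketch; paths are placeholders. The prediction dicts are keyed by
# question id, so they serialize directly to JSON.
with open('/tmp/predictions.json', 'w') as f:
    json.dump(all_predictions, f, indent=2)
if version_2_with_negative:
    with open('/tmp/null_odds.json', 'w') as f:
        json.dump(scores_diff, f, indent=2)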
  def test_task(self, version_2_with_negative, tokenization):
    # Saves a checkpoint.
    pretrain_cfg = bert.PretrainerConfig(
        encoder=self._encoder_config,
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
    pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    saved_path = ckpt.save(self.get_temp_dir())

    config = question_answering.QuestionAnsweringConfig(
        init_checkpoint=saved_path,
        model=question_answering.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config(
            version_2_with_negative))
    self._run_task(config)