Example #1
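This and the following snippets are test and configuration code from the TensorFlow Model Garden (official/nlp). They assume imports along these lines (a sketch; exact module paths vary across Model Garden versions, and a few example-specific modules such as export_savedmodel and the teams/roformer project configs are omitted):

import os
import tensorflow as tf
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.nlp.configs import encoders
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.serving import serving_modules
from official.nlp.tasks import sentence_prediction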
  def test_sentence_prediction(self):
    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_path = ckpt.save(self.get_temp_dir())
    export_module_cls = export_savedmodel.lookup_export_module(task)
    serving_params = {"inputs_only": False}
    params = export_module_cls.Params(**serving_params)
    export_module = export_module_cls(params=params, model=model)
    export_dir = export_savedmodel_util.export(
        export_module,
        function_keys=["serve"],
        checkpoint_path=ckpt_path,
        export_savedmodel_dir=self.get_temp_dir())
    imported = tf.saved_model.load(export_dir)
    serving_fn = imported.signatures["serving_default"]

    dummy_ids = tf.ones((1, 5), dtype=tf.int32)
    inputs = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    ref_outputs = model(inputs)
    outputs = serving_fn(**inputs)
    self.assertAllClose(ref_outputs, outputs["outputs"])
    self.assertEqual(outputs["outputs"].shape, (1, 2))
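The round trip above verifies that the exported serving_default signature reproduces the in-memory model's logits. A quick way to see what an exported signature expects and returns (a sketch using standard tf.saved_model APIs; export_dir is the directory returned by export above):

imported = tf.saved_model.load(export_dir)
fn = imported.signatures["serving_default"]
# With inputs_only=False the input spec includes input_word_ids,
# input_mask and input_type_ids; with inputs_only=True, just the word ids
# (compare the two calls in Example #6 below).
print(fn.structured_input_signature)
print(fn.structured_outputs)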
Example #2
  def test_sentence_prediction_text(self, inputs_only):
    vocab_file_path = os.path.join(self.get_temp_dir(), "vocab.txt")
    _create_fake_vocab_file(vocab_file_path)
    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(
                    vocab_size=30522, num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    params = serving_modules.SentencePrediction.Params(
        inputs_only=inputs_only,
        parse_sequence_length=10,
        text_fields=["foo", "bar"],
        vocab_file=vocab_file_path)
    export_module = serving_modules.SentencePrediction(
        params=params, model=model)
    examples = _create_fake_serialized_examples({
        "foo": b"hello world",
        "bar": b"hello world"
    })
    functions = export_module.get_inference_signatures(
        {"serve_text_examples": "serving_default"})
    outputs = functions["serving_default"](examples)
    self.assertEqual(outputs["outputs"].shape, (10, 2))
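The helpers _create_fake_vocab_file and _create_fake_serialized_examples are not shown in these snippets. A minimal sketch consistent with how they are used here and in Example #6 (the batch size of 10 matches the shape assertions):

def _create_fake_vocab_file(vocab_file_path):
  # A tiny WordPiece-style vocabulary: BERT special tokens plus the
  # words that appear in the test inputs.
  tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"]
  with tf.io.gfile.GFile(vocab_file_path, "w") as f:
    f.write("\n".join(tokens) + "\n")


def _create_fake_serialized_examples(feature_dict):
  """Serializes a batch of 10 identical tf.train.Example protos."""
  features = {}
  for name, value in feature_dict.items():
    if isinstance(value, bytes):
      features[name] = tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[value]))
    else:  # an int tensor, e.g. tf.ones((10,), dtype=tf.int32)
      features[name] = tf.train.Feature(
          int64_list=tf.train.Int64List(value=value.numpy()))
  example = tf.train.Example(features=tf.train.Features(feature=features))
  return tf.constant([example.SerializeToString()] * 10)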
Example #3
def teams_sentence_prediction() -> cfg.ExperimentConfig:
  r"""Teams GLUE."""
  config = cfg.ExperimentConfig(
      task=sentence_prediction.SentencePredictionConfig(
          model=sentence_prediction.ModelConfig(
              encoder=encoders.EncoderConfig(
                  type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
          train_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(),
          validation_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
      restrictions=[
          "task.train_data.is_training != None",
          "task.validation_data.is_training != None"
      ])
  return config
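In the Model Garden, factory functions like this are usually registered with exp_factory so an experiment can be selected by name from the training driver; a sketch of the standard pattern (the registry name 'teams/sentence_prediction' is illustrative):

from official.core import exp_factory

@exp_factory.register_config_factory('teams/sentence_prediction')
def teams_sentence_prediction() -> cfg.ExperimentConfig:
  ...  # body as above

# Later, e.g. when building the trainer:
config = exp_factory.get_exp_config('teams/sentence_prediction')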
Example #4
def roformer_glue() -> cfg.ExperimentConfig:
  r"""BigBird GLUE."""
  config = cfg.ExperimentConfig(
      task=sentence_prediction.SentencePredictionConfig(
          model=sentence_prediction.ModelConfig(
              encoder=encoders.EncoderConfig(
                  type='any', any=roformer.RoformerEncoderConfig())),
          train_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(),
          validation_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.01,
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 3e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
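Configs built by these factories can be adjusted after construction with override, which ExperimentConfig inherits from the Model Garden's hyperparams Config base class; a minimal sketch (the 2e-5 value is illustrative):

config = roformer_glue()
# Nested keys follow the structure of the config above.
config.override({
    'trainer': {
        'optimizer_config': {
            'learning_rate': {
                'polynomial': {'initial_learning_rate': 2e-5}
            }
        }
    }
})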
Example #5
  def get_model_config(self, num_classes):
    return sentence_prediction.ModelConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
        num_classes=num_classes)
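Helpers like this build the ModelConfig shared across the tests above; a typical call site (a sketch mirroring Examples #1 and #2; num_classes=3 is illustrative):

config = sentence_prediction.SentencePredictionConfig(
    model=self.get_model_config(num_classes=3))
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()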
Example #6
  def get_model_config(self, num_classes):
    return sentence_prediction.ModelConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=30522, num_layers=1),
        num_classes=num_classes)

  def test_sentence_prediction(self, use_v2_feature_names):
    if use_v2_feature_names:
      input_word_ids_field = "input_word_ids"
      input_type_ids_field = "input_type_ids"
    else:
      input_word_ids_field = "input_ids"
      input_type_ids_field = "segment_ids"

    config = sentence_prediction.SentencePredictionConfig(
        model=sentence_prediction.ModelConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(
                    vocab_size=30522, num_layers=1)),
            num_classes=2))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    params = serving_modules.SentencePrediction.Params(
        inputs_only=True,
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.SentencePrediction(
        params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    self.assertSameElements(functions.keys(),
                            ["serving_default", "serving_examples"])
    dummy_ids = tf.ones((10, 10), dtype=tf.int32)
    outputs = functions["serving_default"](dummy_ids)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    params = serving_modules.SentencePrediction.Params(
        inputs_only=False,
        parse_sequence_length=10,
        use_v2_feature_names=use_v2_feature_names)
    export_module = serving_modules.SentencePrediction(
        params=params, model=model)
    functions = export_module.get_inference_signatures({
        "serve": "serving_default",
        "serve_examples": "serving_examples"
    })
    outputs = functions["serving_default"](
        input_word_ids=dummy_ids, input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    dummy_ids = tf.ones((10,), dtype=tf.int32)
    examples = _create_fake_serialized_examples({
        input_word_ids_field: dummy_ids,
        "input_mask": dummy_ids,
        input_type_ids_field: dummy_ids
    })
    outputs = functions["serving_examples"](examples)
    self.assertEqual(outputs["outputs"].shape, (10, 2))

    with self.assertRaises(ValueError):
      _ = export_module.get_inference_signatures({"foo": None})
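The use_v2_feature_names parameter exercises both the current (input_word_ids/input_type_ids) and legacy (input_ids/segment_ids) tf.train.Example feature keys, and the final assertRaises documents the error contract: get_inference_signatures raises ValueError for a function key ("foo" here) that the serving module does not define.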