Exemplo n.º 1
0
  def test_sentence_prediction(self):
    """Round-trips a one-layer sentence-prediction model through SavedModel.

    Builds the task model, checkpoints it, and exports the `serve` function.
    It then reloads the SavedModel and checks two things about the
    `serving_default` signature: it reproduces the in-memory model's output,
    and that output has shape (1, num_classes).
    """
    num_classes = 2
    bert_cfg = encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)
    model_cfg = sentence_prediction.ModelConfig(
        encoder=encoders.EncoderConfig(bert=bert_cfg),
        num_classes=num_classes)
    task = sentence_prediction.SentencePredictionTask(
        sentence_prediction.SentencePredictionConfig(model=model_cfg))
    model = task.build_model()

    # Checkpoint the freshly built weights; the path is handed to the exporter.
    saved_ckpt = tf.train.Checkpoint(model=model).save(self.get_temp_dir())

    module_cls = export_savedmodel.lookup_export_module(task)
    # inputs_only=False: the serving call below supplies the mask and type ids
    # alongside the word ids.
    module = module_cls(params=module_cls.Params(inputs_only=False),
                        model=model)
    saved_model_dir = export_savedmodel_util.export(
        module,
        function_keys=["serve"],
        checkpoint_path=saved_ckpt,
        export_savedmodel_dir=self.get_temp_dir())
    loaded = tf.saved_model.load(saved_model_dir)
    serve = loaded.signatures["serving_default"]

    ones = tf.ones((1, 5), dtype=tf.int32)
    features = {
        "input_word_ids": ones,
        "input_mask": ones,
        "input_type_ids": ones,
    }
    expected = model(features)
    served = serve(**features)
    self.assertAllClose(expected, served["outputs"])
    self.assertEqual(served["outputs"].shape, (1, num_classes))
Exemplo n.º 2
0
def main(_):
    """Exports the configured task as a SavedModel; optionally converts to TPU.

    Serving parameters come from the `serving_params` flag. That string is
    first passed through `hyperparams.nested_csv_str_to_json_str` and then
    parsed with YAML. When `convert_tpu` is set, the exported model is
    rewritten by the (v1) Cloud TPU inference converter into
    `<export_dir>/tpu`.

    Args:
      _: Unused positional argument (presumably the argv list supplied by
        the app runner — confirm).
    """
    params_text = hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params)
    # NOTE(review): FullLoader is applied to a flag-provided string; if only
    # plain maps are ever passed, yaml.safe_load would be sufficient — confirm.
    serving_params = yaml.load(params_text, Loader=yaml.FullLoader)
    module = create_export_module(
        task_name=FLAGS.task_name,
        config_file=FLAGS.config_file,
        serving_params=serving_params)
    export_dir = export_savedmodel_util.export(
        module,
        function_keys=[FLAGS.function_keys],
        checkpoint_path=FLAGS.checkpoint_path,
        export_savedmodel_dir=FLAGS.export_savedmodel_dir,
        module_key=FLAGS.module_key)

    if not FLAGS.convert_tpu:
        return

    # pylint: disable=g-import-not-at-top
    from cloud_tpu.inference_converter import converter_cli
    from cloud_tpu.inference_converter import converter_options_pb2

    tpu_dir = os.path.join(export_dir, "tpu")
    options = converter_options_pb2.ConverterOptions()
    if FLAGS.allowed_batch_size is not None:
        # The largest allowed size doubles as the maximum batch size.
        batch_sizes = sorted(FLAGS.allowed_batch_size)
        batching = options.batch_options
        batching.num_batch_threads = 4
        batching.max_batch_size = batch_sizes[-1]
        batching.batch_timeout_micros = 100000
        batching.allowed_batch_sizes[:] = batch_sizes
        batching.max_enqueued_batches = 1000
    converter_cli.ConvertSavedModel(
        export_dir,
        tpu_dir,
        function_alias="tpu_candidate",
        options=options,
        graph_rewrite_only=True)
Exemplo n.º 3
0
  def test_tagging(self, output_encoder_outputs):
    """Exports a one-layer tagging model under two serving signatures.

    Args:
      output_encoder_outputs: Forwarded to the export module's `Params`. When
        truthy, the test also expects an `encoder_outputs` tensor of shape
        (1, 5, hidden_size). Presumably supplied by a parameterized-test
        decorator outside this view.
    """
    hidden_size = 768
    num_classes = 3
    bert_cfg = encoders.BertEncoderConfig(
        hidden_size=hidden_size, num_layers=1)
    task_cfg = tagging.TaggingConfig(
        model=tagging.ModelConfig(
            encoder=encoders.EncoderConfig(bert=bert_cfg)),
        class_names=[f"class_{i}" for i in range(num_classes)])
    task = tagging.TaggingTask(task_cfg)
    model = task.build_model()
    checkpoint_file = tf.train.Checkpoint(model=model).save(
        self.get_temp_dir())

    module_cls = export_savedmodel.lookup_export_module(task)
    module_params = module_cls.Params(
        parse_sequence_length=10,
        output_encoder_outputs=output_encoder_outputs)
    module = module_cls(params=module_params, model=model)
    saved_model_dir = export_savedmodel_util.export(
        module,
        function_keys={
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        },
        checkpoint_path=checkpoint_file,
        export_savedmodel_dir=self.get_temp_dir())
    loaded = tf.saved_model.load(saved_model_dir)
    self.assertCountEqual(loaded.signatures.keys(),
                          ["serving_default", "serving_examples"])

    serve = loaded.signatures["serving_default"]
    ones = tf.ones((1, 5), dtype=tf.int32)
    served = serve(input_word_ids=ones, input_mask=ones, input_type_ids=ones)
    # Per-token logits: (batch, sequence, classes).
    self.assertEqual(served["logits"].shape, (1, 5, num_classes))
    if output_encoder_outputs:
      self.assertEqual(served["encoder_outputs"].shape, (1, 5, hidden_size))
Exemplo n.º 4
0
 def test_masked_lm(self):
   """Exports a masked-LM pretrainer and checks its classification head.

   The pretrainer has a single classification head named "foo". The export
   module is pointed at that head via `cls_head_name`, and the served
   `classification` output is expected to have shape (1, 2).
   """
   seq_length = 10
   num_predictions = 5
   pretrainer_cfg = bert.PretrainerConfig(
       encoder=encoders.EncoderConfig(
           bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
       cls_heads=[
           bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="foo")
       ])
   task = masked_lm.MaskedLMTask(
       masked_lm.MaskedLMConfig(model=pretrainer_cfg))
   model = task.build_model()
   checkpoint_file = tf.train.Checkpoint(model=model).save(
       self.get_temp_dir())

   module_cls = export_savedmodel.lookup_export_module(task)
   module = module_cls(
       params=module_cls.Params(
           cls_head_name="foo",
           parse_sequence_length=seq_length,
           max_predictions_per_seq=num_predictions),
       model=model)
   saved_model_dir = export_savedmodel_util.export(
       module,
       function_keys={
           "serve": "serving_default",
           "serve_examples": "serving_examples"
       },
       checkpoint_path=checkpoint_file,
       export_savedmodel_dir=self.get_temp_dir())
   loaded = tf.saved_model.load(saved_model_dir)
   self.assertSameElements(loaded.signatures.keys(),
                           ["serving_default", "serving_examples"])
   serve = loaded.signatures["serving_default"]

   # Token tensors are sized to parse_sequence_length; the position tensor
   # is sized to max_predictions_per_seq.
   token_ids = tf.ones((1, seq_length), dtype=tf.int32)
   positions = tf.ones((1, num_predictions), dtype=tf.int32)
   served = serve(
       input_word_ids=token_ids,
       input_mask=token_ids,
       input_type_ids=token_ids,
       masked_lm_positions=positions)
   self.assertEqual(served["classification"].shape, (1, 2))
Exemplo n.º 5
0
def main(_):
    """Exports the configured task as a SavedModel; optionally converts to TPU.

    Serving parameters come from the `serving_params` flag. That string is
    first passed through `hyperparams.nested_csv_str_to_json_str` and then
    parsed with YAML. When `convert_tpu` is set, the exported model is
    converted with the Cloud TPU inference converter v2 into
    `<export_dir>/tpu`. Batching settings are taken from the
    `num_batch_threads`, `batch_timeout_micros` and `max_enqueued_batches`
    flags.

    Args:
      _: Unused positional argument (presumably the argv list supplied by
        the app runner — confirm).
    """
    params_text = hyperparams.nested_csv_str_to_json_str(FLAGS.serving_params)
    # NOTE(review): FullLoader is applied to a flag-provided string; if only
    # plain maps are ever passed, yaml.safe_load would be sufficient — confirm.
    serving_params = yaml.load(params_text, Loader=yaml.FullLoader)
    module = create_export_module(
        task_name=FLAGS.task_name,
        config_file=FLAGS.config_file,
        serving_params=serving_params)
    export_dir = export_savedmodel_util.export(
        module,
        function_keys=[FLAGS.function_keys],
        checkpoint_path=FLAGS.checkpoint_path,
        export_savedmodel_dir=FLAGS.export_savedmodel_dir,
        module_key=FLAGS.module_key)

    if not FLAGS.convert_tpu:
        return

    # pylint: disable=g-import-not-at-top
    from cloud_tpu.inference_converter_v2 import converter_options_v2_pb2
    from cloud_tpu.inference_converter_v2.python import converter

    tpu_dir = os.path.join(export_dir, "tpu")
    if FLAGS.allowed_batch_size is None:
        batch_options = []
    else:
        # The largest allowed size doubles as the maximum batch size.
        batch_sizes = sorted(FLAGS.allowed_batch_size)
        batch_options = [
            converter_options_v2_pb2.BatchOptionsV2(
                num_batch_threads=FLAGS.num_batch_threads,
                max_batch_size=batch_sizes[-1],
                batch_timeout_micros=FLAGS.batch_timeout_micros,
                allowed_batch_sizes=batch_sizes,
                max_enqueued_batches=FLAGS.max_enqueued_batches)
        ]

    tpu_function = converter_options_v2_pb2.TpuFunction(
        function_alias="tpu_candidate")
    converter_options = converter_options_v2_pb2.ConverterOptionsV2(
        tpu_functions=[tpu_function],
        batch_options=batch_options,
    )
    converter.ConvertSavedModel(export_dir, tpu_dir, converter_options)