예제 #1
0
 def _do_export(self,
                vocab,
                do_lower_case,
                default_seq_length=128,
                tokenize_with_offsets=True,
                use_sp_model=False,
                experimental_disable_assert=False):
     """Exports a preprocessing SavedModel into a fresh temporary directory.

     Args:
       vocab: Vocabulary passed to `self._make_sp_model_file` when
         `use_sp_model` is true, otherwise to `self._make_vocab_file`.
       do_lower_case: Forwarded to `export_tfhub_lib.export_preprocessing`.
       default_seq_length: Forwarded to `export_preprocessing`.
       tokenize_with_offsets: Forwarded to `export_preprocessing`.
       use_sp_model: If true, export from a SentencePiece model file instead
         of a vocab file.
       experimental_disable_assert: Forwarded to `export_preprocessing`.

     Returns:
       The path of the temporary directory (created under
       `self.get_temp_dir()`) that holds the exported SavedModel.
     """
     export_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
     if use_sp_model:
         vocab_file, sp_model_file = None, self._make_sp_model_file(vocab)
     else:
         vocab_file, sp_model_file = self._make_vocab_file(vocab), None
     export_tfhub_lib.export_preprocessing(
         export_dir,
         vocab_file=vocab_file,
         sp_model_file=sp_model_file,
         do_lower_case=do_lower_case,
         default_seq_length=default_seq_length,
         tokenize_with_offsets=tokenize_with_offsets,
         experimental_disable_assert=experimental_disable_assert)
     # Delete the original input file so that any later load can only
     # succeed by reading what was exported into the SavedModel.
     source_file = sp_model_file or vocab_file
     tf.io.gfile.remove(source_file)
     return export_dir
예제 #2
0
def _model_configs_from_flags():
    """Loads the model config selected by the config-file flags.

    Exactly one of `--bert_config_file` and `--encoder_config_file` must be
    set.

    Returns:
      A `(bert_config, encoder_config)` tuple with exactly one non-None entry:
      either a `configs.BertConfig` read with `from_json_file`, or an
      `encoders.EncoderConfig` whose defaults were overridden from the
      encoder config file.

    Raises:
      ValueError: if both or neither of the config-file flags are set.
    """
    bert_path = FLAGS.bert_config_file
    encoder_path = FLAGS.encoder_config_file
    if bool(bert_path) == bool(encoder_path):
        raise ValueError(
            "Exactly one of `bert_config_file` and `encoder_config_file` can "
            "be specified, but got %s and %s." % (bert_path, encoder_path))
    if not bert_path:
        defaults = encoders.EncoderConfig()
        # NOTE(review): is_strict=True presumably rejects override keys that
        # EncoderConfig does not define; confirm against hyperparams.
        overridden = hyperparams.override_params_dict(
            defaults, encoder_path, is_strict=True)
        return None, overridden
    return configs.BertConfig.from_json_file(bert_path), None


def main(argv):
    """Exports a TF-Hub SavedModel of the kind chosen by `--export_type`.

    Supported export types are "model" and "model_with_mlm" (which call
    `export_tfhub_lib.export_model`, with `with_mlm` set only for the
    latter) and "preprocessing" (which calls
    `export_tfhub_lib.export_preprocessing`).

    Args:
      argv: Positional command-line arguments; anything beyond the first
        element is rejected.

    Raises:
      app.UsageError: on extra positional arguments or an unrecognized
        `--export_type`.
      ValueError: unless exactly one of `--vocab_file` and `--sp_model_file`
        is set, or, for the model export types, unless exactly one of
        `--bert_config_file` and `--encoder_config_file` is set.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

    vocab_file = FLAGS.vocab_file
    sp_model_file = FLAGS.sp_model_file
    exactly_one = bool(vocab_file) != bool(sp_model_file)
    if not exactly_one:
        raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                         "can be specified, but got %s and %s." %
                         (vocab_file, sp_model_file))
    do_lower_case = export_tfhub_lib.get_do_lower_case(
        FLAGS.do_lower_case, vocab_file, sp_model_file)

    export_type = FLAGS.export_type
    if export_type == "preprocessing":
        disable_assert = FLAGS.experimental_disable_assert_in_preprocessing
        export_tfhub_lib.export_preprocessing(
            FLAGS.export_path,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=do_lower_case,
            default_seq_length=FLAGS.default_seq_length,
            tokenize_with_offsets=FLAGS.tokenize_with_offsets,
            experimental_disable_assert=disable_assert)
    elif export_type in ("model", "model_with_mlm"):
        bert_config, encoder_config = _model_configs_from_flags()
        export_tfhub_lib.export_model(
            FLAGS.export_path,
            bert_config=bert_config,
            encoder_config=encoder_config,
            model_checkpoint_path=FLAGS.model_checkpoint_path,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=do_lower_case,
            with_mlm=(export_type == "model_with_mlm"),
            copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)
    else:
        raise app.UsageError("Unknown value '%s' for flag --export_type" %
                             export_type)