Example #1
 def __init__(self,
              dstc8_data_dir,
              train_file_range,
              dev_file_range,
              test_file_range,
              vocab_file,
              do_lower_case,
              max_seq_length=DEFAULT_MAX_SEQ_LENGTH,
              log_data_warnings=False):
     self.dstc8_data_dir = dstc8_data_dir
     self._log_data_warnings = log_data_warnings
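     # Ranges of dialogue file indices that make up each dataset split.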
     self._file_ranges = {
         "train": train_file_range,
         "dev": dev_file_range,
         "test": test_file_range,
     }
     # BERT tokenizer
     self._tokenizer = tokenization.FullTokenizer(
         vocab_file=vocab_file, do_lower_case=do_lower_case)
     self._max_seq_length = max_seq_length
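
A minimal usage sketch for the constructor above. The enclosing class is not shown in this snippet, so the Dstc8DataProcessor class name, the paths, and the file ranges below are assumptions used purely for illustration:

processor = Dstc8DataProcessor(
    dstc8_data_dir="/path/to/dstc8-schema-guided-dialogue",
    train_file_range=range(1, 44),   # dialogue file indices for training
    dev_file_range=range(1, 8),      # dialogue file indices for dev
    test_file_range=range(1, 12),    # dialogue file indices for test
    vocab_file="/path/to/bert_checkpoint/vocab.txt",
    do_lower_case=True,
    max_seq_length=80,
    log_data_warnings=False)
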
Example #2
def _create_schema_embeddings(bert_config, schema_embedding_file,
                              dataset_config):
    """Create schema embeddings and save it into file."""
    if not tf.io.gfile.exists(FLAGS.schema_embedding_dir):
        tf.io.gfile.makedirs(FLAGS.schema_embedding_dir)
    is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
    schema_emb_run_config = tf.compat.v1.estimator.tpu.RunConfig(
        master=FLAGS.master,
        tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

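    # Load the service schemas for the selected dataset split.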
    schema_json_path = os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                                    "schema.json")
    schemas = schema.Schema(schema_json_path)

    # Prepare the BERT model for embedding natural language descriptions.
    bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
    schema_emb_model_fn = extract_schema_embedding.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=bert_init_ckpt,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    schema_emb_estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=schema_emb_model_fn,
        config=schema_emb_run_config,
        predict_batch_size=FLAGS.predict_batch_size)
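    # Build the tokenizer from the vocabulary file shipped with the BERT checkpoint.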
    vocab_file = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    emb_generator = extract_schema_embedding.SchemaEmbeddingGenerator(
        tokenizer, schema_emb_estimator, FLAGS.max_seq_length)
    emb_generator.save_embeddings(schemas, schema_embedding_file,
                                  dataset_config)
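
A hedged sketch of how this helper might be invoked from the driver code. The modeling.BertConfig.from_json_file call, the bert_config.json filename, and the schema_embedding_file naming below are assumptions for illustration, and dataset_config is presumed to be built elsewhere:

bert_config = modeling.BertConfig.from_json_file(
    os.path.join(FLAGS.bert_ckpt_dir, "bert_config.json"))
schema_embedding_file = os.path.join(
    FLAGS.schema_embedding_dir,
    "{}_schema_embedding.npy".format(FLAGS.dataset_split))
# Recompute the schema embeddings only when the output file is missing.
if not tf.io.gfile.exists(schema_embedding_file):
    _create_schema_embeddings(bert_config, schema_embedding_file,
                              dataset_config)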