def __init__(self,
             dstc8_data_dir,
             train_file_range,
             dev_file_range,
             test_file_range,
             vocab_file,
             do_lower_case,
             max_seq_length=DEFAULT_MAX_SEQ_LENGTH,
             log_data_warnings=False):
  """Set up the data processor.

  Args:
    dstc8_data_dir: Directory containing the DSTC8 dialogue data.
    train_file_range: Range of dialogue file indices for the train split.
    dev_file_range: Range of dialogue file indices for the dev split.
    test_file_range: Range of dialogue file indices for the test split.
    vocab_file: Path to the BERT wordpiece vocabulary file.
    do_lower_case: Whether the tokenizer lowercases input text.
    max_seq_length: Maximum token sequence length used downstream.
    log_data_warnings: If True, data-related warnings are logged.
  """
  self.dstc8_data_dir = dstc8_data_dir
  self._log_data_warnings = log_data_warnings
  self._max_seq_length = max_seq_length
  # Each dataset split maps to the range of dialogue files it covers.
  self._file_ranges = {
      "train": train_file_range,
      "dev": dev_file_range,
      "test": test_file_range,
  }
  # Wordpiece tokenizer backed by the provided BERT vocabulary.
  self._tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
def _create_schema_embeddings(bert_config, schema_embedding_file, dataset_config):
  """Create schema embeddings and save it into file.

  Runs a BERT model over the natural-language descriptions in the dataset's
  schema file and writes the resulting embeddings to `schema_embedding_file`.

  Args:
    bert_config: BERT configuration used to build the embedding model.
    schema_embedding_file: Destination file for the saved embeddings.
    dataset_config: Dataset-specific configuration forwarded to the
      embedding generator.
  """
  # Make sure the output directory exists before any embeddings are written.
  if not tf.io.gfile.exists(FLAGS.schema_embedding_dir):
    tf.io.gfile.makedirs(FLAGS.schema_embedding_dir)

  # Load the schema definitions for the requested dataset split.
  schemas = schema.Schema(
      os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split, "schema.json"))

  # TPU run configuration; PER_HOST_V2 shards the input pipeline per host.
  per_host_mode = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.compat.v1.estimator.tpu.RunConfig(
      master=FLAGS.master,
      tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=per_host_mode))

  # Build the BERT model_fn initialized from the pretrained checkpoint.
  init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  model_fn = extract_schema_embedding.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=init_ckpt,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)

  # TPUEstimator falls back to a regular Estimator on CPU/GPU when
  # use_tpu is False.
  estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      predict_batch_size=FLAGS.predict_batch_size)

  # Tokenizer must use the vocabulary shipped with the BERT checkpoint.
  vocab_path = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_path, do_lower_case=FLAGS.do_lower_case)

  generator = extract_schema_embedding.SchemaEmbeddingGenerator(
      tokenizer, estimator, FLAGS.max_seq_length)
  generator.save_embeddings(schemas, schema_embedding_file, dataset_config)