Example #1
  def setUp(self):
    self._processor = data_utils.Dstc8DataProcessor(
        dstc8_data_dir=_TEST_DATA_DIR,
        train_file_range=range(1),
        dev_file_range=None,
        test_file_range=None,
        vocab_file=_VOCAB_FILE,
        do_lower_case=_DO_LOWER_CASE)
    super(Dstc8DataProcessorTest, self).setUp()
Example #2

  def setUp(self):
    self._processor = data_utils.Dstc8DataProcessor(
        dstc8_data_dir=_TEST_DATA_DIR,
        dataset_config=config.DatasetConfig(
            file_ranges={
                'train': range(1),
                'dev': None,
                'test': None
            },
            max_num_cat_slot=6,
            max_num_noncat_slot=6,
            max_num_value_per_cat_slot=4,
            max_num_intent=2),
        vocab_file=_VOCAB_FILE,
        do_lower_case=_DO_LOWER_CASE)
    super(Dstc8DataProcessorTest, self).setUp()
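
Both setUp variants assume module-level scaffolding that the excerpt omits. A minimal sketch of that context, where the import paths, placeholder constants and test-data locations are assumptions rather than the original module's values:

import tensorflow as tf

import config      # assumed local modules; real import paths may differ
import data_utils

_TEST_DATA_DIR = 'testdata'          # placeholder test-data directory
_VOCAB_FILE = 'testdata/vocab.txt'   # placeholder BERT vocabulary file
_DO_LOWER_CASE = True

class Dstc8DataProcessorTest(tf.test.TestCase):

  def setUp(self):
    ...  # either variant shown above

if __name__ == '__main__':
  tf.test.main()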
Example #3
def main(_):
  vocab_file = os.path.join(FLAGS.bert_ckpt_dir, "vocab.txt")
  task_name = FLAGS.task_name.lower()
  if task_name not in data_utils.FILE_RANGES:
    raise ValueError("Task not found: %s" % (task_name))

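  # FILE_RANGES (defined in data_utils) maps a task name to per-split ranges
  # of dialogue-file ids; they select which dialogues_XXX.json files the
  # processor below reads for each split.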
  processor = data_utils.Dstc8DataProcessor(
      FLAGS.dstc8_data_dir,
      train_file_range=data_utils.FILE_RANGES[task_name]["train"],
      dev_file_range=data_utils.FILE_RANGES[task_name]["dev"],
      test_file_range=data_utils.FILE_RANGES[task_name]["test"],
      vocab_file=vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      max_seq_length=FLAGS.max_seq_length,
      log_data_warnings=FLAGS.log_data_warnings)

  # Generate the dialogue examples if the file is missing or an overwrite is
  # requested via --overwrite_dial_file.
  dial_file_name = "{}_{}_examples.tf_record".format(task_name,
                                                     FLAGS.dataset_split)
  dial_file = os.path.join(FLAGS.dialogues_example_dir, dial_file_name)
  if not tf.io.gfile.exists(dial_file) or FLAGS.overwrite_dial_file:
    tf.compat.v1.logging.info("Start generating the dialogue examples.")
    _create_dialog_examples(processor, dial_file)
    tf.compat.v1.logging.info("Finish generating the dialogue examples.")

  # Generate the schema embeddings if the file is missing or an overwrite is
  # requested via --overwrite_schema_emb_file.
  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
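  # validate_case_matches_checkpoint raises a ValueError if do_lower_case is
  # inconsistent with the cased/uncased variant of this BERT checkpoint.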
  tokenization.validate_case_matches_checkpoint(
      do_lower_case=FLAGS.do_lower_case, init_checkpoint=bert_init_ckpt)

  bert_config = modeling.BertConfig.from_json_file(
      os.path.join(FLAGS.bert_ckpt_dir, "bert_config.json"))
  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  schema_embedding_file = os.path.join(
      FLAGS.schema_embedding_dir,
      "{}_pretrained_schema_embedding.npy".format(FLAGS.dataset_split))
  if (not tf.io.gfile.exists(schema_embedding_file) or
      FLAGS.overwrite_schema_emb_file):
    tf.compat.v1.logging.info("Start generating the schema embeddings.")
    _create_schema_embeddings(bert_config, schema_embedding_file)
    tf.compat.v1.logging.info("Finish generating the schema embeddings.")

  # Create estimator for training or inference.
  tf.io.gfile.makedirs(FLAGS.output_dir)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=None,
      tpu_config=tf.contrib.tpu.TPUConfig(
          # Recommended value is number of global steps for next checkpoint.
          iterations_per_loop=FLAGS.save_checkpoints_steps,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.run_mode == "train":
    num_train_examples = processor.get_num_dialog_examples(FLAGS.dataset_split)
    num_train_steps = int(num_train_examples / FLAGS.train_batch_size *
                          FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
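    # Example: 10,000 examples, batch size 32 and 3 epochs give
    # int(10000 / 32 * 3) = 937 train steps; with warmup_proportion = 0.1,
    # int(937 * 0.1) = 93 warmup steps.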

  bert_init_ckpt = os.path.join(FLAGS.bert_ckpt_dir, "bert_model.ckpt")
  model_fn = _model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=bert_init_ckpt,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.run_mode == "train":
    # Train the model.
    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info("  Num dial examples = %d", num_train_examples)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = _file_based_input_fn_builder(
        input_dial_file=dial_file,
        schema_embedding_file=schema_embedding_file,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
  elif FLAGS.run_mode == "predict":
    # Run inference to obtain model predictions.
    num_actual_predict_examples = processor.get_num_dialog_examples(
        FLAGS.dataset_split)

    tf.compat.v1.logging.info("***** Running prediction *****")
    tf.compat.v1.logging.info("  Num actual examples = %d",
                              num_actual_predict_examples)
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = _file_based_input_fn_builder(
        input_dial_file=dial_file,
        schema_embedding_file=schema_embedding_file,
        is_training=False,
        drop_remainder=FLAGS.use_tpu)

    input_json_files = [
        os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                     "dialogues_{:03d}.json".format(fid))
        for fid in data_utils.FILE_RANGES[task_name][FLAGS.dataset_split]
    ]
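    # Each split stores its dialogues as dialogues_000.json, dialogues_001.json,
    # ..., alongside a single schema.json describing the split's services.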
    schema_json_file = os.path.join(FLAGS.dstc8_data_dir, FLAGS.dataset_split,
                                    "schema.json")

    ckpt_nums = [num for num in FLAGS.eval_ckpt.split(",") if num]
    if not ckpt_nums:
      raise ValueError("No checkpoints assigned for prediction.")
    for ckpt_num in ckpt_nums:
      tf.compat.v1.logging.info("***** Predict results for %s set *****",
                                FLAGS.dataset_split)

      predictions = estimator.predict(
          input_fn=predict_input_fn,
          checkpoint_path=os.path.join(FLAGS.output_dir,
                                       "model.ckpt-%s" % ckpt_num))

      # Write predictions to file in DSTC8 format.
      dataset_mark = os.path.basename(FLAGS.dstc8_data_dir)
      prediction_dir = os.path.join(
          FLAGS.output_dir, "pred_res_{}_{}_{}_{}".format(
              int(ckpt_num), FLAGS.dataset_split, task_name, dataset_mark))
      if not tf.io.gfile.exists(prediction_dir):
        tf.io.gfile.makedirs(prediction_dir)
      pred_utils.write_predictions_to_file(predictions, input_json_files,
                                           schema_json_file, prediction_dir)
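
main() guards both expensive artifacts, the dialogue-examples TFRecord and the schema-embedding file, with the same generate-if-missing pattern. A minimal sketch of that pattern factored into a helper; ensure_artifact and build_fn are hypothetical names, not part of the original script:

import tensorflow as tf

def ensure_artifact(path, build_fn, overwrite=False):
  # Rebuild the file at `path` only when it is absent or a rebuild is forced.
  if overwrite or not tf.io.gfile.exists(path):
    tf.compat.v1.logging.info("Start generating %s.", path)
    build_fn(path)
    tf.compat.v1.logging.info("Finish generating %s.", path)

Usage mirroring main(), e.g. ensure_artifact(dial_file, lambda f: _create_dialog_examples(processor, f), overwrite=FLAGS.overwrite_dial_file).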