def preprocess():
  """Preprocesses SQUAD data."""
  sp_model = spm.SentencePieceProcessor()
  sp_model.Load(FLAGS.spiece_model_file)
  spm_basename = os.path.basename(FLAGS.spiece_model_file)
  if FLAGS.create_train_data:
    train_rec_file = os.path.join(
        FLAGS.output_dir,
        "{}.{}.slen-{}.qlen-{}.train.tf_record".format(spm_basename,
                                                       FLAGS.proc_id,
                                                       FLAGS.max_seq_length,
                                                       FLAGS.max_query_length))

    logging.info("Read examples from %s", FLAGS.train_file)
    train_examples = squad_utils.read_squad_examples(
        FLAGS.train_file, is_training=True)
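    # Shard the examples round-robin across `num_proc` preprocessing
    # processes; this process keeps every `num_proc`-th example starting
    # at index `proc_id`.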
    train_examples = train_examples[FLAGS.proc_id::FLAGS.num_proc]

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    random.shuffle(train_examples)
    logging.info("Write to %s", train_rec_file)
    train_writer = squad_utils.FeatureWriter(
        filename=train_rec_file, is_training=True)
    squad_utils.convert_examples_to_features(
        examples=train_examples,
        sp_model=sp_model,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature,
        uncased=FLAGS.uncased)
    train_writer.close()
  if FLAGS.create_eval_data:
    eval_examples = squad_utils.read_squad_examples(
        FLAGS.predict_file, is_training=False)
    squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
                                 FLAGS.max_seq_length, FLAGS.max_query_length,
                                 FLAGS.doc_stride, FLAGS.uncased,
                                 FLAGS.output_dir)
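The example above assumes the original script's module-level imports and flag definitions. A minimal sketch of that setup, with purely illustrative defaults (`squad_utils` is the project's own SQuAD helper module and is not shown here):

# Minimal sketch of the module-level setup the examples here assume. The
# flag names follow the usage above; the default values are illustrative only.
# `os`, `random`, `logging` and `spm` are referenced directly by preprocess();
# `squad_utils` comes from the surrounding package.
import os
import random

from absl import flags
from absl import logging
import sentencepiece as spm

flags.DEFINE_string("spiece_model_file", None, "SentencePiece model path.")
flags.DEFINE_string("output_dir", None, "Directory for output TFRecords.")
flags.DEFINE_string("train_file", None, "SQuAD train JSON file.")
flags.DEFINE_string("predict_file", None, "SQuAD dev JSON file.")
flags.DEFINE_bool("create_train_data", True, "Whether to build training data.")
flags.DEFINE_bool("create_eval_data", True, "Whether to build eval data.")
flags.DEFINE_integer("proc_id", 0, "Index of this preprocessing process.")
flags.DEFINE_integer("num_proc", 1, "Total number of preprocessing processes.")
flags.DEFINE_integer("max_seq_length", 512, "Maximum total sequence length.")
flags.DEFINE_integer("max_query_length", 64, "Maximum query length.")
flags.DEFINE_integer("doc_stride", 128, "Document stride for long contexts.")
flags.DEFINE_bool("uncased", False, "Whether the vocabulary is uncased.")

FLAGS = flags.FLAGS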
Example #2
def run_evaluation(strategy,
                   test_input_fn,
                   eval_steps,
                   input_meta_data,
                   model,
                   step,
                   eval_summary_writer=None):
    """Run evaluation for SQUAD task.

  Args:
    strategy: distribution strategy.
    test_input_fn: input function for evaluation data.
    eval_steps: total number of evaluation steps.
    input_meta_data: input meta data.
    model: keras model object.
    step: current training step.
    eval_summary_writer: summary writer used to record evaluation metrics.

  """
    def _test_step_fn(inputs):
        """Replicated validation step."""

        inputs["mems"] = None
        res = model(inputs, training=False)
        return res, inputs["unique_ids"]

    @tf.function
    def _run_evaluation(test_iterator):
        """Runs validation steps."""
        res, unique_ids = strategy.experimental_run_v2(
            _test_step_fn, args=(next(test_iterator), ))
        return res, unique_ids

    # pylint: disable=protected-access
    test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
    # pylint: enable=protected-access
    cur_results = []
    eval_examples = squad_utils.read_squad_examples(
        input_meta_data["predict_file"], is_training=False)
    with tf.io.gfile.GFile(input_meta_data["predict_file"]) as f:
        orig_data = json.load(f)["data"]

    for _ in range(eval_steps):
        results, unique_ids = _run_evaluation(test_iterator)
        unique_ids = strategy.experimental_local_results(unique_ids)

        for result_key in results:
            results[result_key] = (strategy.experimental_local_results(
                results[result_key]))
        for core_i in range(strategy.num_replicas_in_sync):
            bsz = int(input_meta_data["test_batch_size"] /
                      strategy.num_replicas_in_sync)
            for j in range(bsz):
                result = {}
                for result_key in results:
                    result[result_key] = results[result_key][core_i].numpy()[j]
                result["unique_ids"] = unique_ids[core_i].numpy()[j]
                # A fake example was appended to the dev set so its size is
                # divisible by test_batch_size; skip it during evaluation.
                if result["unique_ids"] == 1000012047:
                    continue
                unique_id = int(result["unique_ids"])

                start_top_log_probs = [
                    float(x) for x in result["start_top_log_probs"].flat
                ]
                start_top_index = [
                    int(x) for x in result["start_top_index"].flat
                ]
                end_top_log_probs = [
                    float(x) for x in result["end_top_log_probs"].flat
                ]
                end_top_index = [int(x) for x in result["end_top_index"].flat]

                cls_logits = float(result["cls_logits"].flat[0])
                cur_results.append(
                    squad_utils.RawResult(
                        unique_id=unique_id,
                        start_top_log_probs=start_top_log_probs,
                        start_top_index=start_top_index,
                        end_top_log_probs=end_top_log_probs,
                        end_top_index=end_top_index,
                        cls_logits=cls_logits))
                if len(cur_results) % 1000 == 0:
                    logging.info("Processing example: %d", len(cur_results))

    output_prediction_file = os.path.join(input_meta_data["predict_dir"],
                                          "predictions.json")
    output_nbest_file = os.path.join(input_meta_data["predict_dir"],
                                     "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
                                             "null_odds.json")

    ret = squad_utils.write_predictions(
        eval_examples, input_meta_data["eval_features"], cur_results,
        input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
        output_prediction_file, output_nbest_file, output_null_log_odds_file,
        orig_data, input_meta_data["start_n_top"],
        input_meta_data["end_n_top"])

    # Log current result

    log_str = "Result | "
    for key, val in ret.items():
        log_str += "{} {} | ".format(key, val)
    logging.info(log_str)
    if eval_summary_writer:
        with eval_summary_writer.as_default():
            tf.summary.scalar("best_f1", ret["best_f1"], step=step)
            tf.summary.scalar("best_exact", ret["best_exact"], step=step)
            eval_summary_writer.flush()
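The nested loops above unpack the distributed outputs into one result dict per example. A self-contained sketch of that pattern, using `Strategy.run` (the current name for `experimental_run_v2`) and a toy step function that stands in for the model:

# Self-contained sketch of the per-replica unpacking used above. The toy
# `step_fn` and inputs are illustrative only.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one or more local devices


@tf.function
def step_fn(features):
    # Stand-in for the model call: return a dict keyed like `results` above.
    return {"logits": features * 2.0}


per_replica = strategy.run(step_fn, args=(tf.constant([[1.0, 2.0]]),))

# `experimental_local_results` turns a distributed value into a tuple with
# one tensor per replica, which the loops above then index per example.
local_logits = strategy.experimental_local_results(per_replica["logits"])
for core_i, tensor in enumerate(local_logits):
    print("replica", core_i, tensor.numpy())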
Example #3
def main(unused_argv):
    del unused_argv
    if FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    elif FLAGS.strategy_type == "tpu":
        cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            "The distribution strategy type is not supported: %s" %
            FLAGS.strategy_type)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)
    train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                       FLAGS.train_batch_size, FLAGS.seq_len,
                                       FLAGS.query_len, strategy, True,
                                       FLAGS.train_tfrecord_path)

    test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                      FLAGS.test_batch_size, FLAGS.seq_len,
                                      FLAGS.query_len, strategy, False,
                                      FLAGS.test_tfrecord_path)

    total_training_steps = FLAGS.train_steps
    steps_per_loop = FLAGS.iterations
    eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

    optimizer, learning_rate_fn = optimization.create_optimizer(
        FLAGS.learning_rate,
        total_training_steps,
        FLAGS.warmup_steps,
        adam_epsilon=FLAGS.adam_epsilon)
    model_config = xlnet_config.XLNetConfig(FLAGS)
    run_config = xlnet_config.create_run_config(True, False, FLAGS)
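    # `input_meta_data` bundles the evaluation/prediction hyperparameters
    # that the training loop and `run_evaluation` read by key.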
    input_meta_data = {}
    input_meta_data["start_n_top"] = FLAGS.start_n_top
    input_meta_data["end_n_top"] = FLAGS.end_n_top
    input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
    input_meta_data["predict_dir"] = FLAGS.predict_dir
    input_meta_data["n_best_size"] = FLAGS.n_best_size
    input_meta_data["max_answer_length"] = FLAGS.max_answer_length
    input_meta_data["test_batch_size"] = FLAGS.test_batch_size
    input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                 strategy.num_replicas_in_sync)
    input_meta_data["mem_len"] = FLAGS.mem_len
    model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                 FLAGS.start_n_top, FLAGS.end_n_top)
    eval_examples = squad_utils.read_squad_examples(FLAGS.predict_file,
                                                    is_training=False)
    if FLAGS.test_feature_path:
        logging.info("start reading pickle file...")
        with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
            eval_features = pickle.load(f)
        logging.info("finishing reading pickle file...")
    else:
        sp_model = spm.SentencePieceProcessor()
        sp_model.LoadFromSerializedProto(
            tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
        spm_basename = os.path.basename(FLAGS.spiece_model_file)
        eval_features = squad_utils.create_eval_data(
            spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
            FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)

    with tf.io.gfile.GFile(FLAGS.predict_file) as f:
        original_data = json.load(f)["data"]
    eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                                eval_examples, eval_features, original_data,
                                eval_steps, input_meta_data)

    training_utils.train(strategy=strategy,
                         model_fn=model_fn,
                         input_meta_data=input_meta_data,
                         eval_fn=eval_fn,
                         metric_fn=None,
                         train_input_fn=train_input_fn,
                         init_checkpoint=FLAGS.init_checkpoint,
                         init_from_transformerxl=FLAGS.init_from_transformerxl,
                         total_training_steps=total_training_steps,
                         steps_per_loop=steps_per_loop,
                         optimizer=optimizer,
                         learning_rate_fn=learning_rate_fn,
                         model_dir=FLAGS.model_dir,
                         save_steps=FLAGS.save_steps)
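A script like this is normally launched through absl's `app.run`. A hedged sketch of that entry point; which flags are marked as required is an assumption, not taken from the original script:

# Hedged sketch of the usual absl entry point for this script. The choice of
# required flags below is illustrative only.
from absl import app
from absl import flags

if __name__ == "__main__":
    flags.mark_flag_as_required("spiece_model_file")
    flags.mark_flag_as_required("model_dir")
    flags.mark_flag_as_required("predict_file")
    app.run(main)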