def preprocess():
  """Preprocesses SQUAD data.

  Loads the SentencePiece model named by `FLAGS.spiece_model_file` and then,
  depending on flags:

  * `FLAGS.create_train_data`: reads the examples in `FLAGS.train_file`, keeps
    every `FLAGS.num_proc`-th example starting at index `FLAGS.proc_id`,
    shuffles them, converts them to features and writes the features through
    a `squad_utils.FeatureWriter` to a `*.train.tf_record` file under
    `FLAGS.output_dir`.
  * `FLAGS.create_eval_data`: reads the examples in `FLAGS.predict_file` and
    passes them to `squad_utils.create_eval_data` together with
    `FLAGS.output_dir`.
  """
  sp_model = spm.SentencePieceProcessor()
  sp_model.Load(FLAGS.spiece_model_file)
  spm_basename = os.path.basename(FLAGS.spiece_model_file)

  if FLAGS.create_train_data:
    train_rec_file = os.path.join(
        FLAGS.output_dir,
        "{}.{}.slen-{}.qlen-{}.train.tf_record".format(
            spm_basename, FLAGS.proc_id, FLAGS.max_seq_length,
            FLAGS.max_query_length))

    logging.info("Read examples from %s", FLAGS.train_file)
    train_examples = squad_utils.read_squad_examples(
        FLAGS.train_file, is_training=True)
    # Each preprocessing process only handles its own strided shard.
    train_examples = train_examples[FLAGS.proc_id::FLAGS.num_proc]

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    random.shuffle(train_examples)

    logging.info("Write to %s", train_rec_file)
    train_writer = squad_utils.FeatureWriter(
        filename=train_rec_file, is_training=True)
    try:
      squad_utils.convert_examples_to_features(
          examples=train_examples,
          sp_model=sp_model,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature,
          uncased=FLAGS.uncased)
    finally:
      # Always release the writer, even if feature conversion fails midway.
      train_writer.close()

  if FLAGS.create_eval_data:
    eval_examples = squad_utils.read_squad_examples(
        FLAGS.predict_file, is_training=False)
    squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
                                 FLAGS.max_seq_length, FLAGS.max_query_length,
                                 FLAGS.doc_stride, FLAGS.uncased,
                                 FLAGS.output_dir)
def run_evaluation(strategy,
                   test_input_fn,
                   eval_steps,
                   input_meta_data,
                   model,
                   step,
                   eval_summary_writer=None):
  """Run evaluation for SQUAD task.

  Runs `eval_steps` batches through `model`, collects one
  `squad_utils.RawResult` per example, writes the prediction files via
  `squad_utils.write_predictions`, logs the returned metrics and, when a
  summary writer is given, records "best_f1" and "best_exact" summaries.

  Args:
    strategy: distribution strategy.
    test_input_fn: input function for evaluation data.
    eval_steps: total number of evaluation steps.
    input_meta_data: input meta data. The keys read here are "predict_file",
      "test_batch_size", "predict_dir", "eval_features", "n_best_size",
      "max_answer_length", "start_n_top" and "end_n_top".
    model: keras model object.
    step: current training step.
    eval_summary_writer: summary writer used to record evaluation metrics.
      Summaries are skipped when it is falsy.
  """
  # Unique id of the padding example appended to the dev set so that its size
  # is divisible by the test batch size.
  fake_example_unique_id = 1000012047

  def _test_step_fn(inputs):
    """Replicated validation step."""
    inputs["mems"] = None
    res = model(inputs, training=False)
    return res, inputs["unique_ids"]

  @tf.function
  def _run_evaluation(test_iterator):
    """Runs validation steps."""
    res, unique_ids = strategy.experimental_run_v2(
        _test_step_fn, args=(next(test_iterator),))
    return res, unique_ids

  # pylint: disable=protected-access
  test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
  # pylint: enable=protected-access

  cur_results = []
  eval_examples = squad_utils.read_squad_examples(
      input_meta_data["predict_file"], is_training=False)
  with tf.io.gfile.GFile(input_meta_data["predict_file"]) as f:
    orig_data = json.load(f)["data"]

  # Both values are constant for the whole evaluation, so compute them once.
  num_replicas = strategy.num_replicas_in_sync
  per_replica_batch_size = int(
      input_meta_data["test_batch_size"] / num_replicas)

  for _ in range(eval_steps):
    results, unique_ids = _run_evaluation(test_iterator)
    unique_ids = strategy.experimental_local_results(unique_ids)
    local_results = {
        key: strategy.experimental_local_results(value)
        for key, value in results.items()
    }

    for core_i in range(num_replicas):
      # Convert every per-replica output to numpy once per replica rather
      # than once per example.
      core_unique_ids = unique_ids[core_i].numpy()
      core_results = {
          key: values[core_i].numpy() for key, values in local_results.items()
      }

      for j in range(per_replica_batch_size):
        raw_unique_id = core_unique_ids[j]
        # We appended a fake example into the dev set so that the data size
        # is divisible by test_batch_size. Ignore this fake example during
        # evaluation.
        if raw_unique_id == fake_example_unique_id:
          continue

        cur_results.append(
            squad_utils.RawResult(
                unique_id=int(raw_unique_id),
                start_top_log_probs=[
                    float(x) for x in core_results["start_top_log_probs"][j].flat
                ],
                start_top_index=[
                    int(x) for x in core_results["start_top_index"][j].flat
                ],
                end_top_log_probs=[
                    float(x) for x in core_results["end_top_log_probs"][j].flat
                ],
                end_top_index=[
                    int(x) for x in core_results["end_top_index"][j].flat
                ],
                cls_logits=float(core_results["cls_logits"][j].flat[0])))
        if len(cur_results) % 1000 == 0:
          logging.info("Processing example: %d", len(cur_results))

  predict_dir = input_meta_data["predict_dir"]
  output_prediction_file = os.path.join(predict_dir, "predictions.json")
  output_nbest_file = os.path.join(predict_dir, "nbest_predictions.json")
  output_null_log_odds_file = os.path.join(predict_dir, "null_odds.json")

  ret = squad_utils.write_predictions(
      eval_examples, input_meta_data["eval_features"], cur_results,
      input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
      output_prediction_file, output_nbest_file, output_null_log_odds_file,
      orig_data, input_meta_data["start_n_top"], input_meta_data["end_n_top"])

  # Log current result
  log_str = "Result | " + "".join(
      "{} {} | ".format(key, val) for key, val in ret.items())
  logging.info(log_str)

  if eval_summary_writer:
    with eval_summary_writer.as_default():
      tf.summary.scalar("best_f1", ret["best_f1"], step=step)
      tf.summary.scalar("best_exact", ret["best_exact"], step=step)
      eval_summary_writer.flush()
def main(unused_argv):
  """Sets up strategy, inputs, model and evaluation, then runs training.

  Builds the distribution strategy selected by `FLAGS.strategy_type`, the
  train/test input functions, the optimizer and the XLNet configs, prepares
  the evaluation features and finally calls `training_utils.train`.

  Args:
    unused_argv: ignored positional command-line arguments.

  Raises:
    ValueError: if `FLAGS.strategy_type` is neither "mirror" nor "tpu".
  """
  del unused_argv
  if FLAGS.strategy_type == "mirror":
    strategy = tf.distribute.MirroredStrategy()
  elif FLAGS.strategy_type == "tpu":
    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
  else:
    raise ValueError(
        "The distribution strategy type is not supported: %s" %
        FLAGS.strategy_type)
  logging.info("***** Number of cores used : %d",
               strategy.num_replicas_in_sync)

  train_input_fn = functools.partial(data_utils.get_squad_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     FLAGS.query_len, strategy, True,
                                     FLAGS.train_tfrecord_path)

  test_input_fn = functools.partial(data_utils.get_squad_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    FLAGS.query_len, strategy, False,
                                    FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)

  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)

  input_meta_data = {
      "start_n_top": FLAGS.start_n_top,
      "end_n_top": FLAGS.end_n_top,
      "lr_layer_decay_rate": FLAGS.lr_layer_decay_rate,
      "predict_dir": FLAGS.predict_dir,
      # `run_evaluation` re-reads the examples and the raw JSON from this path.
      "predict_file": FLAGS.predict_file,
      "n_best_size": FLAGS.n_best_size,
      "max_answer_length": FLAGS.max_answer_length,
      "test_batch_size": FLAGS.test_batch_size,
      "batch_size_per_core": int(FLAGS.train_batch_size /
                                 strategy.num_replicas_in_sync),
      "mem_len": FLAGS.mem_len,
  }
  model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                               FLAGS.start_n_top, FLAGS.end_n_top)

  if FLAGS.test_feature_path:
    logging.info("start reading pickle file...")
    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file; only point `test_feature_path` at trusted, locally produced data.
    with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
      eval_features = pickle.load(f)
    logging.info("finishing reading pickle file...")
  else:
    # The examples are only needed here, to build features from scratch.
    eval_examples = squad_utils.read_squad_examples(
        FLAGS.predict_file, is_training=False)
    sp_model = spm.SentencePieceProcessor()
    # Read through GFile inside a context manager so the handle is closed.
    with tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb") as f:
      sp_model.LoadFromSerializedProto(f.read())
    spm_basename = os.path.basename(FLAGS.spiece_model_file)
    eval_features = squad_utils.create_eval_data(
        spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)

  # `run_evaluation` looks the features up in the meta data dictionary.
  input_meta_data["eval_features"] = eval_features

  # Bind exactly the leading parameters of
  # run_evaluation(strategy, test_input_fn, eval_steps, input_meta_data, ...);
  # the remaining (model, step, eval_summary_writer) are presumably supplied
  # by `training_utils.train` -- confirm against its call site.
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps, input_meta_data)

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=None,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)