def setup(self):
  """Registers the NLTK data path and builds the tokenizers used later.

  Reads `self.nltk_data_path` and `self.spm_model_path`, and sets
  `self.tokenizer`, `self.nltk_tokenizer` and `self.nltk_pos_types`.
  """
  # Make the configured NLTK data directory part of NLTK's search path.
  nltk_search_path = nltk.data.path
  nltk_search_path.append(self.nltk_data_path)

  spm_model_path = self.spm_model_path
  self.tokenizer = tokenization.FullTokenizer(spm_model_path)

  treebank_tokenizer = nltk.TreebankWordTokenizer()
  self.nltk_tokenizer = treebank_tokenizer

  # Label strings stored as a set for membership tests.
  label_names = ('PERSON', 'ORGANIZATION', 'FACILITY', 'GPE', 'GSP')
  self.nltk_pos_types = set(label_names)
def main(_):
  """Trains and/or evaluates the model as selected by command-line flags.

  Builds the model config, the TPU run config and a `TPUEstimator`. With
  `--do_train` it trains for `num_train_steps` and then writes a
  `training_done` marker file into the output directory. With `--do_eval` it
  decodes answer spans for each checkpoint (either the single `--checkpoint`
  or every checkpoint appearing in the output directory), writes
  `<checkpoint>.best_predictions_max.json` and reports metrics through
  `write_eval_results`. Checkpoints whose global step is 0 are skipped.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If `--checkpoint` is combined with `--do_train` or given
      without `--do_eval`, or if `--input_file` matches no files.
  """
  logging.set_verbosity(logging.INFO)

  validate_flags()
  tf.io.gfile.makedirs(FLAGS.output_dir)

  for flag in FLAGS.flags_by_module_dict()[sys.argv[0]]:
    logging.info(" %s = %s", flag.name, flag.value)

  model_config = config.get_model_config(
      model_dir=FLAGS.output_dir,
      source_file=FLAGS.read_it_twice_bert_config_file,
      source_base64=FLAGS.read_it_twice_bert_config_base64,
      write_from_source=FLAGS.do_train)

  if FLAGS.checkpoint is not None:
    # Explicit raises rather than `assert`, which is stripped under
    # `python -O` and would let an invalid flag combination through.
    if FLAGS.do_train:
      raise ValueError("`checkpoint` cannot be combined with `do_train`.")
    if not FLAGS.do_eval:
      raise ValueError("`checkpoint` requires `do_eval`.")

  if FLAGS.cross_attention_top_k is not None:
    model_config = dataclasses.replace(
        model_config, cross_attention_top_k=FLAGS.cross_attention_top_k)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))
  if not input_files:
    # Fail with a clear message instead of an IndexError on input_files[0].
    raise ValueError(
        "No input files match `input_file`: %s" % FLAGS.input_file)

  logging.info("*** Input Files ***")
  for input_file in input_files:
    logging.info(" %s", input_file)

  # Block parameters are read from the first input file only.
  num_blocks_per_example, block_length = (
      input_utils.get_block_params_from_input_file(input_files[0]))

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  # PER_HOST_V1: iterator.get_next() is called 1 time with
  #   per_worker_batch_size
  # PER_HOST_V2: iterator.get_next() is called 8 times with
  #   per_core_batch_size
  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      # Keep all checkpoints
      keep_checkpoint_max=None,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
          per_host_input_for_training=is_per_host,
          experimental_host_call_every_n_steps=FLAGS.steps_per_summary))

  # TODO(urikz): Is there a better way to compute the number of tasks?
  # The code below doesn't work because
  # `tpu_cluster_resolver.cluster_spec()` returns None. Therefore, the number
  # of total tasks has to be passed via a CLI arg.
  # num_tpu_tasks = tpu_cluster_resolver.cluster_spec().num_tasks()
  batch_size = (FLAGS.num_tpu_tasks or 1) * num_blocks_per_example

  num_train_examples = input_utils.get_num_examples_in_tf_records(input_files)
  num_train_steps = int(num_train_examples * FLAGS.num_train_epochs)
  num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  logging.info("***** Input configuration *****")
  logging.info(" Number of blocks per example = %d", num_blocks_per_example)
  logging.info(" Block length = %d", block_length)
  logging.info(" Number of TPU tasks = %d", FLAGS.num_tpu_tasks or 1)
  logging.info(" Batch size = %d", batch_size)
  logging.info(" Number of TPU cores = %d", FLAGS.num_tpu_cores or 0)
  logging.info(" Number training steps = %d", num_train_steps)
  logging.info(" Number warmup steps = %d", num_warmup_steps)

  model_fn = model_fn_builder(
      model_config=model_config,
      padding_token_id=FLAGS.padding_token_id,
      enable_side_inputs=FLAGS.enable_side_inputs,
      num_replicas_concat=FLAGS.num_tpu_cores,
      cross_block_attention_mode=FLAGS.cross_block_attention_mode,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step,
      learning_rate_schedule=FLAGS.learning_rate_schedule,
      nbest_logits_for_eval=FLAGS.decode_top_k)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)

  training_done_path = os.path.join(FLAGS.output_dir, "training_done")

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    train_input_fn = input_fn_builder(input_files=input_files, is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    # Write file to signal training is done.
    with tf.io.gfile.GFile(training_done_path, "w") as writer:
      writer.write("\n")

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    eval_input_fn = input_fn_builder(input_files=input_files, is_training=False)
    question_ids, ground_truth = read_question_answer_json(FLAGS.input_json)
    tokenizer = tokenization.FullTokenizer(FLAGS.spm_model_path)
    logging.info("Loaded SentencePiece model from %s", FLAGS.spm_model_path)

    # Writer for TensorBoard.
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.output_dir, "eval_metrics"))

    if not FLAGS.checkpoint:
      checkpoint_iter = tf.train.checkpoints_iterator(
          FLAGS.output_dir, min_interval_secs=5 * 60, timeout=8 * 60 * 60)
    else:
      checkpoint_iter = [FLAGS.checkpoint]

    try:
      for checkpoint_path in checkpoint_iter:
        start_time = time.time()
        global_step = _get_global_step_for_checkpoint(checkpoint_path)
        if global_step == 0:
          continue
        logging.info("Starting eval on step %d on checkpoint: %s", global_step,
                     checkpoint_path)
        try:
          # question_id -> {normalized answer -> [begin + end logit sums]}
          nbest_predictions = collections.OrderedDict()
          yesno_logits, yesno_labels = {}, {}
          supporting_fact_logits, supporting_fact_labels = {}, {}
          for prediction in estimator.predict(
              eval_input_fn,
              checkpoint_path=checkpoint_path,
              yield_single_examples=True):
            block_id = prediction["block_ids"]
            if block_id == 0:
              # Padding document
              continue
            question_id = question_ids[block_id]
            if question_id not in nbest_predictions:
              nbest_predictions[question_id] = {}
              yesno_logits[question_id] = []
              yesno_labels[question_id] = []
              supporting_fact_logits[question_id] = []
              supporting_fact_labels[question_id] = []

            yesno_logits[question_id].append(
                prediction["yesno_logits"].tolist())
            yesno_labels[question_id].append(
                prediction["answer_type"].tolist())
            supporting_fact_logits[question_id].append(
                prediction["supporting_fact_logits"].tolist())
            supporting_fact_labels[question_id].append(
                prediction["is_supporting_fact"].tolist())

            token_ids = prediction["token_ids"]
            # Score every (begin, end) pair among the candidate positions.
            for begin_index, begin_logit in zip(
                prediction["begin_logits_indices"],
                prediction["begin_logits_values"]):
              for end_index, end_logit in zip(
                  prediction["end_logits_indices"],
                  prediction["end_logits_values"]):
                span_length = end_index - begin_index + 1
                if (begin_index > end_index or
                    span_length > FLAGS.decode_max_size):
                  continue
                answer = "".join(
                    tokenizer.convert_ids_to_tokens([
                        int(token_id)
                        for token_id in token_ids[begin_index:end_index + 1]
                    ]))
                answer = answer.replace(tokenization.SPIECE_UNDERLINE,
                                        " ").strip()
                if not answer:
                  continue
                normalized_answer = evaluation.normalize_answer(answer)
                nbest_predictions[question_id].setdefault(
                    normalized_answer, []).append(begin_logit + end_logit)
        except tf.errors.NotFoundError:
          # Since the coordinator is on a different job than the TPU worker,
          # sometimes the TPU worker does not finish initializing until long
          # after the CPU job tells it to start evaluating. In this case, the
          # checkpoint file could have been deleted already.
          logging.info("Checkpoint %s no longer exists, skipping checkpoint",
                       checkpoint_path)
          continue

        nbest_predictions_probs = _convert_prediction_logits_to_probs(
            nbest_predictions)
        best_predictions_max = _get_best_predictions(nbest_predictions_probs,
                                                     max)

        # Replace each entry with a record holding the auxiliary logits and
        # labels, keeping the span answer (when one was found) under
        # "span_answer".
        for question_id in yesno_logits:
          span_answer = best_predictions_max.get(question_id)
          best_predictions_max[question_id] = {
              "yesno_logits": yesno_logits[question_id],
              "yesno_labels": yesno_labels[question_id],
              "supporting_fact_logits": supporting_fact_logits[question_id],
              "supporting_fact_labels": supporting_fact_labels[question_id],
          }
          if span_answer is not None:
            best_predictions_max[question_id]["span_answer"] = span_answer

        with tf.io.gfile.GFile(checkpoint_path + ".best_predictions_max.json",
                               "w") as f:
          json.dump(best_predictions_max, f, indent=2)

        best_predictions_max_results = evaluation.make_predictions_and_eval(
            ground_truth, best_predictions_max)
        write_eval_results(global_step, best_predictions_max_results, "max",
                           summary_writer)

        # Log before the termination check so the final checkpoint is
        # reported as well.
        logging.info("Finished eval on step %d in %d seconds", global_step,
                     time.time() - start_time)

        if tf.io.gfile.exists(training_done_path):
          # Break if the checkpoint we just processed is the last one.
          last_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
          if (last_checkpoint is not None and
              global_step == _get_global_step_for_checkpoint(last_checkpoint)):
            break
    finally:
      # Flush pending summaries and release the event file.
      summary_writer.close()
def setup(self):
  """Builds `self.tokenizer` from the instance's tokenizer settings.

  Reads `self.spm_model_path`, `self.vocab_path` and `self.do_lower_case`.
  """
  spm_model_path = self.spm_model_path
  vocab_path = self.vocab_path
  lower_case = self.do_lower_case
  self.tokenizer = tokenization.FullTokenizer(
      spm_model_path, vocab_path, do_lower_case=lower_case)
def main(_):
  """Trains and/or evaluates the model as selected by command-line flags.

  Builds the model config, the TPU run config and a `TPUEstimator`. With
  `--do_train` it trains for `num_train_steps` and then writes a
  `training_done` marker file into the output directory. With `--do_eval` it
  decodes answer spans for each checkpoint and writes
  `<checkpoint>.<eval_data_split>.best_predictions_max.json`. Metrics are
  computed and reported through `write_eval_results` only when
  `--eval_data_split` is "valid"; for "test" only predictions are written.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If neither `--do_train` nor `--do_eval` is set, if eval is
      requested without `--eval_json_path` or with an unknown
      `--eval_data_split`, or if `--input_file` matches no files.
  """
  logging.set_verbosity(logging.INFO)

  for flag in FLAGS.flags_by_module_dict()[sys.argv[0]]:
    logging.info(" %s = %s", flag.name, flag.value)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError(
        "At least one of `do_train` or `do_eval` must be True.")

  # Validate the eval flags up front so that a misconfigured run fails before
  # training starts rather than after it finishes.
  compute_metrics = False
  if FLAGS.do_eval:
    if FLAGS.eval_json_path is None:
      raise ValueError("Must specify `eval_json_path` for eval.")
    if FLAGS.eval_data_split == "test":
      compute_metrics = False
    elif FLAGS.eval_data_split == "valid":
      compute_metrics = True
    else:
      raise ValueError("Unknown mode: %s" % FLAGS.eval_data_split)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  model_config = config.get_model_config(
      model_dir=FLAGS.output_dir,
      source_file=FLAGS.read_it_twice_bert_config_file,
      source_base64=None,
      write_from_source=False)

  if FLAGS.cross_attention_top_k is not None:
    model_config = dataclasses.replace(
        model_config, cross_attention_top_k=FLAGS.cross_attention_top_k)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))
  if not input_files:
    # Fail with a clear message instead of an IndexError on input_files[0].
    raise ValueError(
        "No input files match `input_file`: %s" % FLAGS.input_file)

  logging.info("*** Input Files ***")
  for input_file in input_files:
    logging.info(" %s", input_file)

  # Block parameters are read from the first input file only.
  num_blocks_per_example, block_length = (
      input_utils.get_block_params_from_input_file(input_files[0]))

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  # PER_HOST_V1: iterator.get_next() is called 1 time with
  #   per_worker_batch_size
  # PER_HOST_V2: iterator.get_next() is called 8 times with
  #   per_core_batch_size
  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
          per_host_input_for_training=is_per_host,
          experimental_host_call_every_n_steps=FLAGS.steps_per_summary))

  # TODO(urikz): Is there a better way to compute the number of tasks?
  # The code below doesn't work because
  # `tpu_cluster_resolver.cluster_spec()` returns None. Therefore, the number
  # of total tasks has to be passed via a CLI arg.
  # num_tpu_tasks = tpu_cluster_resolver.cluster_spec().num_tasks()
  batch_size = (FLAGS.num_tpu_tasks or 1) * num_blocks_per_example

  num_train_examples = input_utils.get_num_examples_in_tf_records(input_files)
  num_train_steps = int(num_train_examples * FLAGS.num_train_epochs)
  num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  logging.info("***** Input configuration *****")
  logging.info(" Number of blocks per example = %d", num_blocks_per_example)
  logging.info(" Block length = %d", block_length)
  logging.info(" Number of TPU tasks = %d", FLAGS.num_tpu_tasks or 1)
  logging.info(" Batch size = %d", batch_size)
  logging.info(" Number of TPU cores = %d", FLAGS.num_tpu_cores or 0)
  logging.info(" Number training steps = %d", num_train_steps)
  logging.info(" Number warmup steps = %d", num_warmup_steps)

  model_fn = model_fn_builder(
      model_config=model_config,
      padding_token_id=FLAGS.padding_token_id,
      enable_side_inputs=FLAGS.enable_side_inputs,
      num_replicas_concat=FLAGS.num_tpu_cores,
      cross_block_attention_mode=FLAGS.cross_block_attention_mode,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step,
      learning_rate_schedule=FLAGS.learning_rate_schedule,
      nbest_logits_for_eval=FLAGS.decode_top_k)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)

  training_done_path = os.path.join(FLAGS.output_dir, "training_done")

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    train_input_fn = input_fn_builder(input_files=input_files, is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    # Write file to signal training is done.
    with tf.io.gfile.GFile(training_done_path, "w") as writer:
      writer.write("\n")

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    eval_input_fn = input_fn_builder(input_files=input_files, is_training=False)

    with tf.io.gfile.GFile(FLAGS.eval_json_path) as f:
      data = json.load(f)["Data"]
    # We skip the first question ID as it corresponds to a padding document.
    question_ids = [None] + [datum["QuestionId"] for datum in data]
    if compute_metrics:
      ground_truth = {datum["QuestionId"]: datum["Answer"] for datum in data}
      logging.info("Loaded %d questions for evaluation from %s",
                   len(ground_truth), FLAGS.eval_json_path)

    tokenizer = tokenization.FullTokenizer(FLAGS.spm_model_path)
    logging.info("Loaded SentencePiece model from %s", FLAGS.spm_model_path)

    # Writer for TensorBoard.
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.output_dir, "eval_metrics"))

    if _is_checkpoint_in_the_model_dir():
      logging.info("FINAL EVALUATION: checkpoint = %s, data split = %s",
                   FLAGS.init_checkpoint, FLAGS.eval_data_split)
      checkpoints = [FLAGS.init_checkpoint]
    else:
      checkpoints = tf.train.checkpoints_iterator(
          FLAGS.output_dir, min_interval_secs=5 * 60, timeout=8 * 60 * 60)

    try:
      for checkpoint_path in checkpoints:
        start_time = time.time()
        global_step = _get_global_step_for_checkpoint(checkpoint_path)
        logging.info("Starting eval on step %d on checkpoint: %s", global_step,
                     checkpoint_path)

        # question_id -> {normalized answer -> [begin + end logit sums]}
        nbest_predictions = collections.OrderedDict()

        try:
          for prediction_index, prediction in enumerate(
              estimator.predict(
                  eval_input_fn,
                  checkpoint_path=checkpoint_path,
                  yield_single_examples=True)):
            if prediction_index % 100000 == 0:
              logging.info("Processing example: %d", prediction_index)
            block_id = prediction["block_ids"]
            if block_id == 0:
              # Padding document
              continue
            question_id = question_ids[block_id]
            if question_id not in nbest_predictions:
              nbest_predictions[question_id] = {}

            token_ids = prediction["token_ids"]
            # Score every (begin, end) pair among the candidate positions.
            for begin_index, begin_logit in zip(
                prediction["begin_logits_indices"],
                prediction["begin_logits_values"]):
              for end_index, end_logit in zip(
                  prediction["end_logits_indices"],
                  prediction["end_logits_values"]):
                span_length = end_index - begin_index + 1
                if (begin_index > end_index or
                    span_length > FLAGS.decode_max_size):
                  continue
                answer = "".join(
                    tokenizer.convert_ids_to_tokens([
                        int(token_id)
                        for token_id in token_ids[begin_index:end_index + 1]
                    ]))
                answer = evaluation.normalize_answer(
                    answer.replace(tokenization.SPIECE_UNDERLINE, " "))
                if not answer:
                  continue
                nbest_predictions[question_id].setdefault(answer, []).append(
                    begin_logit + end_logit)

          nbest_predictions_probs = _convert_prediction_logits_to_probs(
              nbest_predictions)
          best_predictions_max = _get_best_predictions(nbest_predictions_probs,
                                                       max)

          with tf.io.gfile.GFile(
              checkpoint_path +
              ".%s.best_predictions_max.json" % FLAGS.eval_data_split,
              "w") as f:
            json.dump(best_predictions_max, f, indent=2)

          if compute_metrics:
            best_predictions_max_results = evaluation.evaluate_triviaqa(
                ground_truth, best_predictions_max, mute=False)
            write_eval_results(global_step, best_predictions_max_results,
                               "max", summary_writer)
            summary_writer.flush()
        except tf.errors.NotFoundError:
          # Since the coordinator is on a different job than the TPU worker,
          # sometimes the TPU worker does not finish initializing until long
          # after the CPU job tells it to start evaluating. In this case, the
          # checkpoint file could have been deleted already.
          logging.info("Checkpoint %s no longer exists, skipping checkpoint",
                       checkpoint_path)
          continue

        # Log before the termination check so the final checkpoint is
        # reported as well.
        logging.info("Finished eval on step %d in %d seconds", global_step,
                     time.time() - start_time)

        if tf.io.gfile.exists(training_done_path):
          # Break if the checkpoint we just processed is the last one.
          last_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
          if (last_checkpoint is not None and
              global_step == _get_global_step_for_checkpoint(last_checkpoint)):
            break
    finally:
      # Flush pending summaries and release the event file.
      summary_writer.close()