Example #1
def _verify_annotations(self, annotations, answer_set):
    should_update = False
    new_annotations = []
    filtered_annotations = set()
    for annotation in annotations:
        # Drop answer annotations (type 0) whose normalized text no longer
        # matches any answer in the gold answer set; keep everything else.
        if (annotation.type == 0 and evaluation.normalize_answer(
                annotation.text) not in answer_set):
            filtered_annotations.add(annotation.text)
            should_update = True
        else:
            new_annotations.append(annotation)
    return should_update, new_annotations, filtered_annotations
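
The snippet keeps only annotations that still match the gold answers after normalization. Below is a minimal, self-contained sketch of the same filtering logic, using a hypothetical Annotation namedtuple and a trivial stand-in for evaluation.normalize_answer (the real normalization is more involved):

import collections

# Hypothetical stand-ins for illustration only.
Annotation = collections.namedtuple('Annotation', ['type', 'text'])

def normalize_answer(text):
  # Placeholder for evaluation.normalize_answer.
  return text.lower().strip()

annotations = [
    Annotation(type=0, text='Paris'),   # answer annotation, kept (matches)
    Annotation(type=0, text='London'),  # answer annotation, filtered out
    Annotation(type=1, text='France'),  # entity annotation, always kept
]
answer_set = {'paris'}

new_annotations = [
    a for a in annotations
    if not (a.type == 0 and normalize_answer(a.text) not in answer_set)
]
filtered = {a.text for a in annotations if a not in new_annotations}
print([a.text for a in new_annotations])  # ['Paris', 'France']
print(filtered)                           # {'London'}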
Example #2
def read_question_answer_json(json_path):
  """Read a CVS file into a list of QuestionAnswer objects."""
  # We skip the first question ID as it corresponds on a padding document.
  question_ids = [None]
  ground_truth = {}
  with tf.io.gfile.GFile(json_path) as f:
    data = json.load(f)
    for datum in data:
      question_id = datum["_id"]
      ground_truth[question_id] = evaluation.normalize_answer(datum["answer"])
      question_ids.append(question_id)
  logging.info("Read %d questions from %s", len(ground_truth), json_path)
  return question_ids, ground_truth
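
For reference, this function expects a JSON array of records with "_id" and "answer" fields. A minimal usage sketch follows; the file path and contents are invented for illustration, and the module-level imports from the original file (json, tensorflow, absl logging, evaluation) are assumed to be present:

import json

import tensorflow as tf

records = [
    {"_id": "q1", "answer": "Paris"},
    {"_id": "q2", "answer": "yes"},
]
with tf.io.gfile.GFile("/tmp/questions.json", "w") as f:
  json.dump(records, f)

question_ids, ground_truth = read_question_answer_json("/tmp/questions.json")
# question_ids == [None, "q1", "q2"]; index 0 is reserved for the padding
# document, so question_ids[block_id] maps a block back to its question.
# ground_truth maps each question ID to its normalized answer string.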
Example #3
    def process(self, question_answer_evidence):
        metrics.Metrics.counter(METRICS_NAMESPACE, 'num_questions').inc()

        if self.generate_answers:
            oracle_answers = []
            answer_set = question_answer_evidence.answer.make_answer_set(
                oracle_answers)
            normalized_answer_set = {
                evaluation.normalize_answer(answer)
                for answer in answer_set
            }

        tokenized_question = self._tokenize_text(
            question_answer_evidence.question.value)

        metrics.Metrics.distribution(METRICS_NAMESPACE,
                                     'question_length').update(
                                         len(tokenized_question))

        filtered_annotations = []
        tf_examples = []
        num_answer_annotations = 0
        num_answer_annotations_tokenized = 0
        num_entity_annotations = 0
        num_entity_annotations_tokenized = 0

        no_answer, yes_answer, yes_no_answer = False, False, False
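        # Yes/no answers are handled separately from span answers; the answer
        # type is written to the 'answer_type' feature below (1 = yes, 2 = no,
        # 0 = span).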
        if question_answer_evidence.answer.values[0] == 'yes':
            metrics.Metrics.counter(METRICS_NAMESPACE,
                                    'num_answer_type.yes').inc()
            yes_no_answer = True
            yes_answer = True
        if question_answer_evidence.answer.values[0] == 'no':
            metrics.Metrics.counter(METRICS_NAMESPACE,
                                    'num_answer_type.no').inc()
            yes_no_answer = True
            no_answer = True
        if yes_no_answer:
            metrics.Metrics.counter(METRICS_NAMESPACE,
                                    'num_answer_type.yes_no').inc()
        else:
            metrics.Metrics.counter(METRICS_NAMESPACE,
                                    'num_answer_type.span').inc()

        for evidence in question_answer_evidence.evidence:
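            # Each evidence paragraph is sentence-split, entity-annotated,
            # tokenized, and converted into its own tf.Example below.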
            sentence = self._split_into_sentences(evidence)
            sentence_obj = self._annotate_entities(sentence)
            metrics.Metrics.counter(METRICS_NAMESPACE, 'nltk_entities').inc(
                sentence_obj.num_annotations(1))

            if self.generate_answers and not yes_no_answer:
                annotations = find_answer_annotations(sentence_obj.text,
                                                      answer_set)
                sentence_obj.annotations.extend(annotations)

            document = data_utils.BertDocument(
                sentences=[sentence_obj],
                document_id=question_answer_evidence.question.id)

            num_entity_annotations += document.num_annotations(1)
            num_answer_annotations += document.num_annotations(0)

            tokenized_document = data_utils.tokenize_document_for_bert(
                document, self.tokenizer)

            metrics.Metrics.distribution(
                METRICS_NAMESPACE,
                'tokenized_doc_length_per_paragraph').update(
                    tokenized_document.num_tokens())

            if self.generate_answers and not yes_no_answer:
                assert len(tokenized_document.sentences) == 1
                (should_update, annotations,
                 current_filtered_annotations) = self._verify_annotations(
                     tokenized_document.sentences[0].annotations,
                     normalized_answer_set)
                if should_update:
                    tokenized_document.sentences[0].annotations = annotations
                    # pylint: disable=g-complex-comprehension
                    filtered_annotations.extend([
                        FilteredAnnotation(
                            question=question_answer_evidence.question,
                            answer=question_answer_evidence.answer,
                            annotation=annotation,
                            sentence=''.join(
                                tokenized_document.sentences[0].tokens))
                        for annotation in current_filtered_annotations
                    ])
                    metrics.Metrics.counter(
                        METRICS_NAMESPACE, 'num_filtered_annotations').inc(
                            len(current_filtered_annotations))

            num_entity_annotations_tokenized += tokenized_document.num_annotations(
                1)
            num_answer_annotations_tokenized += tokenized_document.num_annotations(
                0)

            tf_example = tokenized_document.to_tf_strided_large_example(
                overlap_length=self.block_overlap_length,
                block_length=self.block_length,
                padding_token_id=self.padding_token_id,
                prefix_token_ids=tokenized_question,
                max_num_annotations=self.max_num_annotations_per_block)

            if yes_answer:
                assert yes_no_answer
                assert not no_answer
                tf_example.features.feature[
                    'answer_type'].int64_list.value[:] = [1]
            elif no_answer:
                assert yes_no_answer
                assert not yes_answer
                tf_example.features.feature[
                    'answer_type'].int64_list.value[:] = [2]
            else:
                assert not yes_no_answer
                tf_example.features.feature[
                    'answer_type'].int64_list.value[:] = [0]

            if evidence.is_supporting_fact:
                tf_example.features.feature[
                    'is_supporting_fact'].int64_list.value[:] = [1]
            else:
                tf_example.features.feature[
                    'is_supporting_fact'].int64_list.value[:] = [0]

            tf_examples.append(tf_example)

        metrics.Metrics.distribution(METRICS_NAMESPACE,
                                     'num_paragraphs_per_question').update(
                                         len(tf_examples))
        metrics.Metrics.distribution(
            METRICS_NAMESPACE, 'num_answer_annotations_per_question').update(
                num_answer_annotations)
        metrics.Metrics.distribution(
            METRICS_NAMESPACE, 'num_entity_annotations_per_question').update(
                num_entity_annotations)

        if (self.generate_answers and not yes_no_answer
                and num_answer_annotations == 0):
            metrics.Metrics.counter(METRICS_NAMESPACE,
                                    'make_example_status.no_answer').inc()
            yield beam.pvalue.TaggedOutput(MakeExampleOutput.NO_ANSWER,
                                           question_answer_evidence.to_json())
            return

        metrics.Metrics.distribution(
            METRICS_NAMESPACE,
            'num_answer_tokenize_annotations_per_question').update(
                num_answer_annotations_tokenized)
        metrics.Metrics.distribution(
            METRICS_NAMESPACE,
            'num_entity_tokenize_annotations_per_question').update(
                num_entity_annotations_tokenized)
        metrics.Metrics.distribution(METRICS_NAMESPACE,
                                     'num_filtered_annotations').update(
                                         len(filtered_annotations))

        if (self.generate_answers and not yes_no_answer
                and num_answer_annotations_tokenized == 0):
            metrics.Metrics.counter(
                METRICS_NAMESPACE,
                'make_example_status.no_answer_tokenized_annotations').inc()
            yield beam.pvalue.TaggedOutput(
                MakeExampleOutput.NO_ANSWER_TOKENIZED_FILTERED_ANNOTATIONS,
                filtered_annotations)
            return

        yield beam.pvalue.TaggedOutput(
            MakeExampleOutput.SUCCESS_FILTERED_ANNOTATIONS,
            filtered_annotations)

        if len(tf_examples) != 10:
            metrics.Metrics.counter(
                METRICS_NAMESPACE, 'num_not_10_paragraphs_per_question').inc()

        tf_example = tf_examples[0]
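        # Merge all per-paragraph examples into the first one by concatenating
        # every repeated feature across the remaining examples.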
        for i in range(1, len(tf_examples)):
            for name in tf_example.features.feature:
                repeated_values = get_repeated_values(name, tf_example)
                extension_values = list(
                    get_repeated_values(name, tf_examples[i]))
                repeated_values.extend(extension_values)
        metrics.Metrics.counter(METRICS_NAMESPACE,
                                'make_example_status.success').inc()
        yield tf_example
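
This process() belongs to a Beam DoFn that emits the merged tf.Example on the main output and three tagged side outputs. A rough sketch of how such a DoFn might be wired into a pipeline, assuming a surrounding DoFn class (here called MakeExampleDoFn, a made-up name) and the MakeExampleOutput tags used above:

import apache_beam as beam

def add_make_examples(question_answer_evidence_pcoll):
  outputs = (
      question_answer_evidence_pcoll
      | 'MakeExamples' >> beam.ParDo(
          MakeExampleDoFn()  # constructor arguments omitted in this sketch
      ).with_outputs(
          MakeExampleOutput.NO_ANSWER,
          MakeExampleOutput.NO_ANSWER_TOKENIZED_FILTERED_ANNOTATIONS,
          MakeExampleOutput.SUCCESS_FILTERED_ANNOTATIONS,
          main='tf_examples'))
  # outputs.tf_examples: the merged tf.Examples yielded at the end of process().
  # outputs[MakeExampleOutput.NO_ANSWER]: questions whose answer never matched
  # the evidence text (yielded early above).
  return outputs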
Example #4
def main(_):
  logging.set_verbosity(logging.INFO)

  validate_flags()
  tf.io.gfile.makedirs(FLAGS.output_dir)

  for flag in FLAGS.flags_by_module_dict()[sys.argv[0]]:
    logging.info("  %s = %s", flag.name, flag.value)

  model_config = config.get_model_config(
      model_dir=FLAGS.output_dir,
      source_file=FLAGS.read_it_twice_bert_config_file,
      source_base64=FLAGS.read_it_twice_bert_config_base64,
      write_from_source=FLAGS.do_train)

  if FLAGS.checkpoint is not None:
    assert not FLAGS.do_train
    assert FLAGS.do_eval

  if FLAGS.cross_attention_top_k is not None:
    model_config = dataclasses.replace(
        model_config, cross_attention_top_k=FLAGS.cross_attention_top_k)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  logging.info("*** Input Files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  num_blocks_per_example, block_length = input_utils.get_block_params_from_input_file(
      input_files[0])

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  # TPU input pipeline modes:
  # PER_HOST_V1: iterator.get_next() is called 1 time with per_worker_batch_size
  # PER_HOST_V2: iterator.get_next() is called 8 times with per_core_batch_size
  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      # Keep all checkpoints
      keep_checkpoint_max=None,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          tpu_job_name=FLAGS.tpu_job_name,
          per_host_input_for_training=is_per_host,
          experimental_host_call_every_n_steps=FLAGS.steps_per_summary))

  # TODO(urikz): Is there a better way to compute the number of tasks?
  # the code below doesn't work because `tpu_cluster_resolver.cluster_spec()`
  # returns None. Therefore, I have to pass number of total tasks via CLI arg.
  # num_tpu_tasks = tpu_cluster_resolver.cluster_spec().num_tasks()
  batch_size = (FLAGS.num_tpu_tasks or 1) * num_blocks_per_example

  num_train_examples = input_utils.get_num_examples_in_tf_records(input_files)
  num_train_steps = int(num_train_examples * FLAGS.num_train_epochs)
  num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  logging.info("***** Input configuration *****")
  logging.info("  Number of blocks per example = %d", num_blocks_per_example)
  logging.info("  Block length = %d", block_length)
  logging.info("  Number of TPU tasks = %d", FLAGS.num_tpu_tasks or 1)
  logging.info("  Batch size = %d", batch_size)
  logging.info("  Number of TPU cores = %d", FLAGS.num_tpu_cores or 0)
  logging.info("  Number training steps = %d", num_train_steps)
  logging.info("  Number warmup steps = %d", num_warmup_steps)

  model_fn = model_fn_builder(
      model_config=model_config,
      padding_token_id=FLAGS.padding_token_id,
      enable_side_inputs=FLAGS.enable_side_inputs,
      num_replicas_concat=FLAGS.num_tpu_cores,
      cross_block_attention_mode=FLAGS.cross_block_attention_mode,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step,
      learning_rate_schedule=FLAGS.learning_rate_schedule,
      nbest_logits_for_eval=FLAGS.decode_top_k)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size)

  training_done_path = os.path.join(FLAGS.output_dir, "training_done")

  if FLAGS.do_train:
    logging.info("***** Running training *****")
    train_input_fn = input_fn_builder(input_files=input_files, is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    # Write file to signal training is done.
    with tf.gfile.GFile(training_done_path, "w") as writer:
      writer.write("\n")

  if FLAGS.do_eval:
    logging.info("***** Running evaluation *****")
    eval_input_fn = input_fn_builder(input_files=input_files, is_training=False)
    question_ids, ground_truth = read_question_answer_json(FLAGS.input_json)
    tokenizer = tokenization.FullTokenizer(FLAGS.spm_model_path)
    logging.info("Loaded SentencePiece model from %s", FLAGS.spm_model_path)

    # Writer for TensorBoard.
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.output_dir, "eval_metrics"))

    if not FLAGS.checkpoint:
      # for checkpoint_path in _get_all_checkpoints(FLAGS.output_dir):
      checkpoint_iter = tf.train.checkpoints_iterator(
          FLAGS.output_dir, min_interval_secs=5 * 60, timeout=8 * 60 * 60)
    else:
      checkpoint_iter = [FLAGS.checkpoint]

    for checkpoint_path in checkpoint_iter:
      start_time = time.time()
      global_step = _get_global_step_for_checkpoint(checkpoint_path)
      if global_step == 0:
        continue
      logging.info("Starting eval on step %d on checkpoint: %s", global_step,
                   checkpoint_path)
      try:
        nbest_predictions = collections.OrderedDict()
        yesno_logits, yesno_labels = {}, {}
        supporting_fact_logits, supporting_fact_labels = {}, {}

        for prediction in estimator.predict(
            eval_input_fn,
            checkpoint_path=checkpoint_path,
            yield_single_examples=True):
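          # Each prediction corresponds to one document block; block_ids index
          # into question_ids, where block 0 is the padding document.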
          block_id = prediction["block_ids"]
          if block_id == 0:
            # Padding document
            continue
          question_id = question_ids[block_id]
          if question_id not in nbest_predictions:
            nbest_predictions[question_id] = {}
            yesno_logits[question_id] = []
            yesno_labels[question_id] = []
            supporting_fact_logits[question_id] = []
            supporting_fact_labels[question_id] = []

          yesno_logits[question_id].append(prediction["yesno_logits"].tolist())
          yesno_labels[question_id].append(prediction["answer_type"].tolist())
          supporting_fact_logits[question_id].append(
              prediction["supporting_fact_logits"].tolist())
          supporting_fact_labels[question_id].append(
              prediction["is_supporting_fact"].tolist())

          token_ids = prediction["token_ids"]
          for begin_index, begin_logit in zip(
              prediction["begin_logits_indices"],
              prediction["begin_logits_values"]):
            for end_index, end_logit in zip(prediction["end_logits_indices"],
                                            prediction["end_logits_values"]):
              if begin_index > end_index or end_index - begin_index + 1 > FLAGS.decode_max_size:
                continue
              answer = "".join(
                  tokenizer.convert_ids_to_tokens([
                      int(token_id)
                      for token_id in token_ids[begin_index:end_index + 1]
                  ]))

              answer = answer.replace(tokenization.SPIECE_UNDERLINE,
                                      " ").strip()
              if not answer:
                continue
              normalized_answer = evaluation.normalize_answer(answer)
              if normalized_answer not in nbest_predictions[question_id]:
                nbest_predictions[question_id][normalized_answer] = []
              nbest_predictions[question_id][normalized_answer].append(
                  begin_logit + end_logit)
      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info("Checkpoint %s no longer exists, skipping checkpoint",
                        checkpoint_path)
        continue

      nbest_predictions_probs = _convert_prediction_logits_to_probs(
          nbest_predictions)

      best_predictions_max = _get_best_predictions(nbest_predictions_probs, max)
      for question_id in yesno_logits:
        if question_id in best_predictions_max:
          span_answer = best_predictions_max[question_id]
        else:
          span_answer = None
        best_predictions_max[question_id] = {
            "yesno_logits": yesno_logits[question_id],
            "yesno_labels": yesno_labels[question_id],
            "supporting_fact_logits": supporting_fact_logits[question_id],
            "supporting_fact_labels": supporting_fact_labels[question_id],
        }
        if span_answer is not None:
          best_predictions_max[question_id]["span_answer"] = span_answer

      with tf.gfile.GFile(checkpoint_path + ".best_predictions_max.json",
                          "w") as f:
        json.dump(best_predictions_max, f, indent=2)

      best_predictions_max_results = evaluation.make_predictions_and_eval(
          ground_truth, best_predictions_max)
      write_eval_results(global_step, best_predictions_max_results, "max",
                         summary_writer)

      if tf.io.gfile.exists(training_done_path):
        # Break if the checkpoint we just processed is the last one.
        last_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        if last_checkpoint is None:
          continue
        last_global_step = _get_global_step_for_checkpoint(last_checkpoint)
        if global_step == last_global_step:
          break

      global_step = _get_global_step_for_checkpoint(checkpoint_path)
      logging.info("Finished eval on step %d in %d seconds", global_step,
                   time.time() - start_time)
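
The helpers _convert_prediction_logits_to_probs and _get_best_predictions are not shown in this example. Purely as a sketch of what that aggregation could look like, assuming span logits are softmax-normalized within each question and the aggregation function passed in (max above) collapses each answer's probabilities into a single score; the actual implementation may differ:

import collections

import numpy as np


def convert_logits_to_probs_sketch(nbest_predictions):
  """Illustrative only: softmax all span logits within each question."""
  nbest_probs = collections.OrderedDict()
  for question_id, answers in nbest_predictions.items():
    if not answers:
      nbest_probs[question_id] = {}
      continue
    all_logits = np.concatenate(
        [np.asarray(logits, dtype=np.float64) for logits in answers.values()])
    max_logit = all_logits.max()
    normalizer = np.exp(all_logits - max_logit).sum()
    nbest_probs[question_id] = {
        answer: (np.exp(np.asarray(logits) - max_logit) / normalizer).tolist()
        for answer, logits in answers.items()
    }
  return nbest_probs


def get_best_predictions_sketch(nbest_probs, aggregation_fn=max):
  """Illustrative only: score each answer and keep the best one per question."""
  best_predictions = {}
  for question_id, answers in nbest_probs.items():
    scores = {a: aggregation_fn(probs) for a, probs in answers.items()}
    if scores:
      best_predictions[question_id] = max(scores, key=scores.get)
  return best_predictions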