Ejemplo n.º 1
0
  def run(self) -> None:
    """Main entry point for this class.

    Validates the configuration, creates `self.output_dir`, disables
    TensorFlow 2.x behavior, builds the estimator, and then trains (when
    `self.do_train` is set) and/or predicts (when `self.do_predict` is set).

    When training, the number of training features is read from
    `self.record_count_file` (a single integer stored as text) and used to
    derive the total number of training steps and the number of warmup steps.
    """
    self.validate_flags_or_throw()

    tf.gfile.MakeDirs(self.output_dir)

    tf.disable_v2_behavior()

    # Both stay `None` when not training; the model_fn is built either way.
    num_train_steps = None
    num_warmup_steps = None
    if self.do_train:
      with tf.gfile.Open(self.record_count_file, "r") as f:
        num_train_features = int(f.read().strip())
      # `int()` truncates the (possibly fractional) product toward zero.
      num_train_steps = int(num_train_features / self.train_batch_size *
                            self.num_train_epochs)
      logging.info("record_count_file: %s", self.record_count_file)
      logging.info("num_records (features): %d", num_train_features)
      # `%s` rather than `%d`: if the epoch count is fractional, `%d` would
      # silently truncate the logged value (e.g. 2.5 would appear as 2).
      logging.info("num_train_epochs: %s", self.num_train_epochs)
      logging.info("train_batch_size: %d", self.train_batch_size)
      logging.info("num_train_steps: %d", num_train_steps)

      num_warmup_steps = int(num_train_steps * self.warmup_proportion)

    model_fn = self.get_model_fn(num_train_steps, num_warmup_steps)
    estimator = self.get_estimator(model_fn)

    if self.do_train:
      logging.info("Running training on precomputed features")
      logging.info("  Num split examples = %d", num_train_features)
      logging.info("  Batch size = %d", self.train_batch_size)
      logging.info("  Num steps = %d", num_train_steps)
      train_filenames = tf.gfile.Glob(self.train_records_file)
      train_input_fn = tf_io.input_fn_builder(
          input_file=train_filenames,
          seq_length=self.max_seq_length,
          is_training=True,
          drop_remainder=True)
      estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if self.do_predict:
      self.predict(estimator)
Ejemplo n.º 2
0
  def predict(self, estimator):
    """Run prediction.

    Runs `estimator` over each shard of prediction features and converts each
    shard's raw logits with `postproc.compute_pred_dict`. The per-shard
    results are merged and written to `self.output_prediction_file`, one JSON
    object per line.

    If `self.precomputed_predict_file` is empty, examples are read from
    `self.predict_file` and written out as feature shards by
    `self.write_tf_feature_files`. Otherwise `self.precomputed_predict_file`
    (a file glob) is used directly as a single shard whose example/feature
    counts are unknown.

    Args:
      estimator: Estimator whose `predict` yields dicts with the keys
        "unique_ids", "start_logits", "end_logits" and "answer_type_logits".
    """
    # Example/feature counts are only known when we generate the shards
    # ourselves. They feed logging and the missing-prediction check below.
    counts_are_known = not self.precomputed_predict_file
    if counts_are_known:
      predict_examples_iter = preproc.read_tydi_examples(
          input_file=self.predict_file,
          tokenizer=self.get_tokenizer(),
          is_training=False,
          max_passages=self.max_passages,
          max_position=self.max_position,
          fail_on_invalid=self.fail_on_invalid,
          open_fn=tf_io.gopen)
      shards_iter = self.write_tf_feature_files(predict_examples_iter)
    else:
      # Uses zeros for example and feature counts since they're unknown.
      # They are only logged; the missing-prediction check is skipped.
      shards_iter = [(1, (self.precomputed_predict_file, 0, 0))]

    # The candidates come from the (entire) prediction input file, so they are
    # the same for every shard. Load them once, on the first shard, instead of
    # re-reading the whole file for each shard.
    # NOTE(review): this reuse assumes postproc.compute_pred_dict does not
    # mutate candidates_dict; confirm if that function changes.
    candidates_dict = None

    # Accumulates all of the prediction results to be written to the output.
    full_tydi_pred_dict = {}
    total_num_examples = 0
    for shard_num, (shard_filename_glob, shard_num_examples,
                    shard_num_features) in shards_iter:
      total_num_examples += shard_num_examples
      logging.info(
          "Shard %d: Running prediction for %s; %d examples, %d features.",
          shard_num, shard_filename_glob, shard_num_examples,
          shard_num_features)

      # Resolve the shard's files once; they are used both for inference and
      # for reloading the features during post-processing.
      shard_filenames = tf.gfile.Glob(shard_filename_glob)

      # Runs the model on the shard and stores the individual results.
      # If running predict on TPU, you will need to specify the number of steps.
      predict_input_fn = tf_io.input_fn_builder(
          input_file=shard_filenames,
          seq_length=self.max_seq_length,
          is_training=False,
          drop_remainder=False)
      all_results = []
      for result in estimator.predict(
          predict_input_fn, yield_single_examples=True):
        if len(all_results) % 10000 == 0:
          logging.info("Shard %d: Predicting for feature %d/%s", shard_num,
                       len(all_results), shard_num_features)
        unique_id = int(result["unique_ids"])
        start_logits = [float(x) for x in result["start_logits"].flat]
        end_logits = [float(x) for x in result["end_logits"].flat]
        answer_type_logits = [
            float(x) for x in result["answer_type_logits"].flat
        ]
        all_results.append(
            tydi_modeling.RawResult(
                unique_id=unique_id,
                start_logits=start_logits,
                end_logits=end_logits,
                answer_type_logits=answer_type_logits))

        # Allow `None` or `0` to disable this behavior.
        if self.max_to_predict and len(all_results) == self.max_to_predict:
          logging.info(
              "WARNING: Stopping predictions early since "
              "max_to_predict == %d", self.max_to_predict)
          break

      # Lazily reads the prediction candidates from the (entire) prediction
      # input file, exactly once across all shards.
      if candidates_dict is None:
        candidates_dict = self.read_candidates(self.predict_file)
      predict_features = []
      for shard_filename in shard_filenames:
        for r in tf.python_io.tf_record_iterator(shard_filename):
          predict_features.append(tf.train.Example.FromString(r))
      logging.info("Shard %d: Post-processing predictions.", shard_num)
      logging.info("  Num candidate examples loaded (includes all shards): %d",
                   len(candidates_dict))
      logging.info("  Num candidate features loaded: %d", len(predict_features))
      logging.info("  Num prediction result features: %d", len(all_results))
      logging.info("  Num shard features: %d", shard_num_features)

      tydi_pred_dict = postproc.compute_pred_dict(
          candidates_dict,
          predict_features, [r._asdict() for r in all_results],
          candidate_beam=self.candidate_beam,
          max_answer_length=self.max_answer_length)

      logging.info("Shard %d: Post-processed predictions.", shard_num)
      logging.info("  Num shard examples: %d", shard_num_examples)
      logging.info("  Num post-processed results: %d", len(tydi_pred_dict))
      # With precomputed features the example count is a placeholder 0, which
      # would otherwise report a bogus negative number of missing predictions.
      if counts_are_known and shard_num_examples != len(tydi_pred_dict):
        logging.warning("  Num missing predictions: %d",
                        shard_num_examples - len(tydi_pred_dict))
      for key, value in tydi_pred_dict.items():
        if key in full_tydi_pred_dict:
          logging.warning("ERROR: '%s' already in full_tydi_pred_dict!", key)
        full_tydi_pred_dict[key] = value

    logging.info("Prediction finished for all shards.")
    logging.info("  Total input examples: %d", total_num_examples)
    logging.info("  Total output predictions: %d", len(full_tydi_pred_dict))

    with tf.gfile.Open(self.output_prediction_file, "w") as output_file:
      for prediction in full_tydi_pred_dict.values():
        output_file.write((json.dumps(prediction) + "\n").encode())