コード例 #1
0
def _decode_inputs(features, inputs_re, decode_fn):
  """Decodes every feature whose key matches `inputs_re`, in sorted key order.

  Args:
    features: mapping from feature name to encoded value.
    inputs_re: compiled regular expression; a key is treated as an input when
      `inputs_re.match(key)` succeeds (i.e. matches from the start of the key).
    decode_fn: id-to-text function forwarded to `decode_matrix`.

  Returns:
    A flat list of decoded inputs. When `decode_matrix` returns a list its
    elements are spliced in; any other return value is appended as one item.
  """
  inputs_list = []
  for key in sorted(features):
    if not inputs_re.match(key):
      continue
    single_inputs = decode_matrix(decode_fn, features[key])
    if isinstance(single_inputs, list):
      inputs_list.extend(single_inputs)
    else:
      inputs_list.append(single_inputs)
  return inputs_list


def text_eval(encoder,
              features_iter,
              model_dir,
              global_step,
              eval_tag,
              enable_logging,
              inputs_pattern="^inputs[0-9]*$",
              targets_key="targets",
              predictions_key="outputs",
              additional_keys=(),
              num_reserved=None):
  """Evaluates a set of text targets/predictions.

  Each element of `features_iter` is decoded to text, logged through
  `LogWriter`, and its targets/predictions are scored with ROUGE, BLEU,
  repetition and length scorers. Bootstrap aggregates of all scores (plus the
  length histograms) are then written via `_write_aggregates` and
  `_write_aggregate_summaries`.

  Args:
    encoder: encoder passed to `ids2str` to turn encoded sequences into text.
    features_iter: iterable of feature mappings, one per example.
    model_dir: output directory forwarded to the log and aggregate writers.
    global_step: step number forwarded to the log and aggregate writers.
    eval_tag: tag forwarded to the log and aggregate writers.
    enable_logging: forwarded to `LogWriter`.
    inputs_pattern: regex (string or compiled pattern); feature keys it
      matches from the start are decoded and logged as inputs, in sorted key
      order.
    targets_key: feature key holding the encoded targets.
    predictions_key: feature key holding the encoded predictions.
    additional_keys: extra feature keys to decode and log. "selected_ids" is
      decoded with `decode_selected_indices`; all others with `decode_matrix`.
    num_reserved: forwarded to `ids2str`.
  """
  decode_fn = lambda x: ids2str(encoder, x, num_reserved)
  scorers_dict = {}
  scorers_dict[_ROUGE_METRIC] = rouge_scorer.RougeScorer(
      ["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=True)
  scorers_dict[_BLEU_METRIC] = bleu_scorer.BleuScorer()
  scorers_dict[_REPETITION_METRIC] = repetition_scorer.RepetitionScorer(
      ["regs1", "regs2", "regs3", "regsTCR"])
  scorers_dict[_LENGTH_METRIC] = length_scorer.LengthScorer(["word", "char"])
  aggregators_dict = {k: scoring.BootstrapAggregator() for k in scorers_dict}
  # Compile once instead of resolving the pattern for every feature key of
  # every example; re.compile returns an already-compiled pattern unchanged.
  inputs_re = re.compile(inputs_pattern)

  with LogWriter(additional_keys, model_dir, global_step, eval_tag,
                 enable_logging) as log_writer:
    for i, features in enumerate(features_iter):
      # Inputs are only logged (as a list); scoring uses targets/predictions.
      inputs_list = _decode_inputs(features, inputs_re, decode_fn)
      targets = decode_fn(features[targets_key])
      preds = decode_fn(features[predictions_key])
      text_dict = {
          "inputs": inputs_list,
          "targets": targets,
          "predictions": preds
      }

      for key in additional_keys:
        if key == "selected_ids":
          text_dict[key] = decode_selected_indices(decode_fn, features)
        else:
          text_dict[key] = decode_matrix(decode_fn, features[key])

      log_writer.write(text_dict, i)

      for key, scorer in scorers_dict.items():
        aggregators_dict[key].add_scores(scorer.score(targets, preds))

  aggregates_dict = {k: v.aggregate() for k, v in aggregators_dict.items()}
  length_histograms = scorers_dict[_LENGTH_METRIC].histograms(as_string=True)
  _write_aggregates(model_dir, global_step, eval_tag, aggregates_dict,
                    length_histograms)
  _write_aggregate_summaries(model_dir, global_step, eval_tag, aggregates_dict)
コード例 #2
0
def text_eval(preds_file,
              model_dir,
              global_step: int = 0,
              eval_tag: str = "",
              enable_logging: bool = True):
    """Evaluates text targets/predictions read from a CSV file.

    The CSV must have a header row with "prompt", "targets" and "predictions"
    columns. Every row is logged through `LogWriter` and its
    targets/predictions are scored with ROUGE, BLEU, repetition and length
    scorers; bootstrap aggregates (plus the length histograms) are then written
    via `_write_aggregates` and `_write_aggregate_summaries`.

    Args:
      preds_file: path of the CSV file, read as UTF-8.
      model_dir: output directory forwarded to the log and aggregate writers.
      global_step: step number forwarded to the log and aggregate writers.
      eval_tag: tag forwarded to the log and aggregate writers.
      enable_logging: forwarded to `LogWriter`.

    Raises:
      KeyError: if the CSV header lacks one of the required columns (raised
        on the first data row).
    """
    scorers_dict = {
        _ROUGE_METRIC:
        rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL", "rougeLsum"],
                                 use_stemmer=True),
        _BLEU_METRIC:
        bleu_scorer.BleuScorer(),
        _REPETITION_METRIC:
        repetition_scorer.RepetitionScorer(
            ["regs1", "regs2", "regs3", "regsTCR"]),
        _LENGTH_METRIC:
        length_scorer.LengthScorer(["word", "char"])
    }
    aggregators_dict = {k: scoring.BootstrapAggregator() for k in scorers_dict}

    # Use the caller's global_step/eval_tag (not hard-coded defaults) so the
    # per-example logs are tagged consistently with the aggregates below.
    with LogWriter((), model_dir, global_step, eval_tag,
                   enable_logging) as log_writer:
        # newline="" lets the csv module itself handle newlines embedded in
        # quoted fields (e.g. multi-line predictions), as its docs require.
        with open(preds_file, newline="", encoding="utf-8") as csv_file:
            for i, row in enumerate(csv.DictReader(csv_file)):
                prompt = row["prompt"]
                targets = row["targets"]
                predictions = row["predictions"]

                log_writer.write(
                    {
                        "inputs": prompt,
                        "targets": targets,
                        "predictions": predictions,
                    }, i)

                for key, scorer in scorers_dict.items():
                    aggregators_dict[key].add_scores(
                        scorer.score(targets, predictions))

    aggregates_dict = {k: v.aggregate() for k, v in aggregators_dict.items()}
    length_histograms = scorers_dict[_LENGTH_METRIC].histograms(as_string=True)
    _write_aggregates(model_dir, global_step, eval_tag, aggregates_dict,
                      length_histograms)
    _write_aggregate_summaries(model_dir, global_step, eval_tag,
                               aggregates_dict)