def text_eval(encoder, features_iter, model_dir, global_step, eval_tag, enable_logging, inputs_pattern="^inputs[0-9]*$", targets_key="targets", predictions_key="outputs", additional_keys=(), num_reserved=None):
  """Evaluates a set of text targets/predictions.

  Decodes id features into text, logs each example via LogWriter, scores
  targets against predictions with ROUGE/BLEU/repetition/length scorers,
  bootstrap-aggregates the scores, and writes the aggregates to model_dir.

  Args:
    encoder: text encoder passed to ids2str for decoding id tensors.
    features_iter: iterable of feature dicts (id matrices/vectors).
    model_dir: directory where aggregate results are written.
    global_step: step number used to tag outputs.
    eval_tag: tag string used to name outputs.
    enable_logging: whether LogWriter actually writes per-example logs.
    inputs_pattern: regex matching feature keys that hold model inputs.
    targets_key: feature key holding target ids.
    predictions_key: feature key holding predicted ids.
    additional_keys: extra feature keys to decode and log per example.
    num_reserved: forwarded to ids2str (reserved-token handling).
  """
  decode_fn = lambda x: ids2str(encoder, x, num_reserved)
  scorers_dict = {}
  scorers_dict[_ROUGE_METRIC] = rouge_scorer.RougeScorer(
      ["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=True)
  scorers_dict[_BLEU_METRIC] = bleu_scorer.BleuScorer()
  scorers_dict[_REPETITION_METRIC] = repetition_scorer.RepetitionScorer(
      ["regs1", "regs2", "regs3", "regsTCR"])
  scorers_dict[_LENGTH_METRIC] = length_scorer.LengthScorer(["word", "char"])
  # One bootstrap aggregator per scorer, keyed identically.
  aggregators_dict = {k: scoring.BootstrapAggregator() for k in scorers_dict}

  with LogWriter(additional_keys, model_dir, global_step, eval_tag,
                 enable_logging) as log_writer:
    for i, features in enumerate(features_iter):
      # Collect every feature whose key matches inputs_pattern (sorted for
      # deterministic ordering across runs).
      inputs_list = []
      for k in sorted(features):
        if re.match(inputs_pattern, k):
          single_inputs = decode_matrix(decode_fn, features[k])
          # decode_matrix may return a list (matrix input) or a single
          # string (vector input); flatten either way.
          if isinstance(single_inputs, list):
            inputs_list.extend(single_inputs)
          else:
            inputs_list.append(single_inputs)
      # NOTE(review): the original also computed "\n".join(inputs_list) into
      # an `inputs` local that was never used; the dead store was removed.
      targets = decode_fn(features[targets_key])
      preds = decode_fn(features[predictions_key])
      text_dict = {
          "inputs": inputs_list,
          "targets": targets,
          "predictions": preds
      }
      for key in additional_keys:
        if key == "selected_ids":
          # selected_ids needs the full feature dict to resolve indices.
          text_dict[key] = decode_selected_indices(decode_fn, features)
        else:
          text_dict[key] = decode_matrix(decode_fn, features[key])
      log_writer.write(text_dict, i)
      for key, scorer in scorers_dict.items():
        scores_i = scorer.score(targets, preds)
        aggregators_dict[key].add_scores(scores_i)

  aggregates_dict = {k: v.aggregate() for k, v in aggregators_dict.items()}
  length_histograms = scorers_dict[_LENGTH_METRIC].histograms(as_string=True)
  _write_aggregates(model_dir, global_step, eval_tag, aggregates_dict,
                    length_histograms)
  _write_aggregate_summaries(model_dir, global_step, eval_tag,
                             aggregates_dict)
# NOTE(review): this second `def text_eval` shadows the earlier definition of
# the same name at import time — confirm whether one of them should be renamed.
def text_eval(preds_file, model_dir, global_step: int = 0, eval_tag: str = "", enable_logging: bool = True):
  """Evaluates a set of text targets/predictions read from a CSV file.

  Reads rows with 'prompt', 'targets', and 'predictions' columns, logs each
  example, scores targets vs. predictions with ROUGE/BLEU/repetition/length
  scorers, bootstrap-aggregates the scores, and writes results to model_dir.

  Args:
    preds_file: path to a CSV file with 'prompt'/'targets'/'predictions'
      columns.
    model_dir: directory where aggregate results are written.
    global_step: step number used to tag outputs.
    eval_tag: tag string used to name outputs.
    enable_logging: whether LogWriter actually writes per-example logs.
  """
  scorers_dict = {
      _ROUGE_METRIC: rouge_scorer.RougeScorer(
          ["rouge1", "rouge2", "rougeL", "rougeLsum"], use_stemmer=True),
      _BLEU_METRIC: bleu_scorer.BleuScorer(),
      _REPETITION_METRIC: repetition_scorer.RepetitionScorer(
          ["regs1", "regs2", "regs3", "regsTCR"]),
      _LENGTH_METRIC: length_scorer.LengthScorer(["word", "char"])
  }
  aggregators_dict = {k: scoring.BootstrapAggregator() for k in scorers_dict}

  # Bug fix: previously LogWriter was given hard-coded step 0 and tag "",
  # while the aggregate writers below used the real global_step/eval_tag,
  # so per-example logs and aggregates were tagged inconsistently.
  with LogWriter((), model_dir, global_step, eval_tag,
                 enable_logging) as log_writer:
    with open(preds_file) as csv_file:
      reader = csv.DictReader(csv_file)
      for i, row in enumerate(reader):
        text_dict = {
            "inputs": row['prompt'],
            "targets": row['targets'],
            "predictions": row['predictions']
        }
        log_writer.write(text_dict, i)
        for key, scorer in scorers_dict.items():
          scores_i = scorer.score(row['targets'], row['predictions'])
          aggregators_dict[key].add_scores(scores_i)

  aggregates_dict = {k: v.aggregate() for k, v in aggregators_dict.items()}
  length_histograms = scorers_dict[_LENGTH_METRIC].histograms(as_string=True)
  _write_aggregates(model_dir, global_step, eval_tag, aggregates_dict,
                    length_histograms)
  _write_aggregate_summaries(model_dir, global_step, eval_tag,
                             aggregates_dict)