Example #1
 def report(self):
     """Report the metrics over all data seen so far."""
     m = {}
     total = self.metrics['cnt']
     m['exs'] = total
     if total > 0:
         if self.flags['print_prediction_metrics']:
             if 'accuracy' in self.metrics_list:
                 m['accuracy'] = round_sigfigs(
                     self.metrics['correct'] /
                     max(1, self.metrics['correct_cnt']), 4)
             if 'f1' in self.metrics_list:
                 m['f1'] = round_sigfigs(
                     self.metrics['f1'] / max(1, self.metrics['f1_cnt']), 4)
         if self.flags['has_text_cands']:
             for k in self.eval_pr:
                 m['hits@' + str(k)] = round_sigfigs(
                     self.metrics['hits@' + str(k)] /
                     max(1, self.metrics['hits@_cnt']),
                     3,
                 )
         for k in self.metrics_list:
             if self.metrics[k + '_cnt'] > 0 and k != 'correct' and k != 'f1':
                 m[k] = round_sigfigs(
                     self.metrics[k] / max(1, self.metrics[k + '_cnt']), 4)
     return m
Example #2
 def report(self):
     metrics = {}
     # heads up, if you have multiple optimizers, or different parameter
     # groups, this could be misleading
     current_lr = self.optimizer.param_groups[0]['lr']
     metrics['lr'] = round_sigfigs(current_lr, 4)
     metrics['num_updates'] = self._number_training_updates
     return metrics
Example #3
 def report(self, compute_time=None):
     m = {}
     with self._lock():
         m['total'] = self.metrics['total']
         if m['total'] > 0:
             # m['num_unk'] = self.metrics['num_unk']
             # m['num_tokens'] = self.metrics['num_tokens']
             m['loss'] = round_sigfigs(self.metrics['loss'] / self.metrics['num_tokens'], 3)
             m['ppl'] = round_sigfigs(math.exp(self.metrics['loss'] / self.metrics['num_tokens']), 4)
     return m
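The loss/perplexity pattern in this example (and in several later ones) divides the summed negative log-likelihood by the token count and exponentiates the result. A quick numeric sanity check with invented running totals:

    import math

    loss_sum, num_tokens = 230.26, 100   # hypothetical accumulated metrics
    mean_loss = loss_sum / num_tokens    # about 2.3026 nats per token
    ppl = math.exp(mean_loss)            # about 10.0
    print(round(mean_loss, 4), round(ppl, 4))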
Example #4
 def report(self):
     """Report per-dialogue round metrics."""
     m = {k: {} for k in ["first_round", "second_round", "third_round+"]}
     for k, v in self.metrics.items():
         if v["num_samples"] > 0:
             m[k]["hits@1/100"] = round_sigfigs(
                 v["hits@1/100"] / v["num_samples"], 4)
             m[k]["loss"] = round_sigfigs(v["loss"] / v["num_samples"], 4)
             if "med_rank" in v:
                 m[k]["med_rank"] = np.median(v["med_rank"])
     return m
Example #5
    def test_round_sigfigs(self):
        x = 0
        y = 0
        assert round_sigfigs(x, 2) == y

        x = 100
        y = 100
        assert round_sigfigs(x, 2) == y

        x = 0.01
        y = 0.01
        assert round_sigfigs(x, 2) == y

        x = 0.00123
        y = 0.001
        assert round_sigfigs(x, 1) == y

        x = 0.37
        y = 0.4
        assert round_sigfigs(x, 1) == y

        x = 2353
        y = 2350
        assert round_sigfigs(x, 3) == y

        x = 3547345734
        y = 3547350000
        assert round_sigfigs(x, 6) == y

        x = 0.0000046246
        y = 0.00000462
        assert round_sigfigs(x, 3) == y
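The assertions above pin down the expected behavior of round_sigfigs: round to the requested number of significant figures and pass zero through unchanged. A minimal sketch of a helper that satisfies these test cases (not necessarily the library's actual implementation, which may also need to unwrap tensors):

    import math

    def round_sigfigs(x, sigfigs=4):
        # Round x to the given number of significant figures.
        x = float(x)
        if x == 0:
            return 0.0
        # Exponent of the most significant digit.
        exponent = math.floor(math.log10(abs(x)))
        return round(x, -exponent + sigfigs - 1)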
Example #6
 def report(self):
     # Report the metrics over all data seen so far.
     m = {}
     m['total'] = self.metrics['cnt']
     if self.metrics['cnt'] > 0:
         m['accuracy'] = round_sigfigs(
             self.metrics['correct'] / self.metrics['cnt'], 4)
         m['f1'] = round_sigfigs(self.metrics['f1'] / self.metrics['cnt'], 4)
         m['hits@k'] = {}
         for k in self.eval_pr:
             m['hits@k'][k] = round_sigfigs(
                 self.metrics['hits@' + str(k)] / self.metrics['cnt'], 4)
     return m
Example #7
 def report(self):
     # Report the metrics over all data seen so far.
     m = {}
     m['total'] = self.metrics['cnt']
     if self.metrics['cnt'] > 0:
         m['accuracy'] = round_sigfigs(
             self.metrics['correct'] / self.metrics['cnt'], 4)
         m['f1'] = round_sigfigs(
             self.metrics['f1'] / self.metrics['cnt'], 4)
         m['hits@k'] = {}
         for k in self.eval_pr:
             m['hits@k'][k] = round_sigfigs(
                 self.metrics['hits@' + str(k)] / self.metrics['cnt'], 4)
     return m
Example #8
 def report(self):
     # Report the metrics over all data seen so far.
     m = {}
     total = self.metrics['cnt']
     m['total'] = total
     if total > 0:
         m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
         m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
         m['hits@k'] = {}
         for k in self.eval_pr:
             m['hits@k'][k] = round_sigfigs(
                 self.metrics['hits@' + str(k)] / total, 3)
         for k in self.custom_keys:
             m[k] = round_sigfigs(self.metrics[k] / total, 3)
     return m
Example #9
 def report(self):
     # Report the metrics over all data seen so far.
     m = {}
     total = self.metrics['cnt']
     m['total'] = total
     if total > 0:
         m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
         m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
         m['hits@k'] = {}
         for k in self.eval_pr:
             m['hits@k'][k] = round_sigfigs(
                 self.metrics['hits@' + str(k)] / total, 4)
         for k in self.custom_keys:
             m[k] = round_sigfigs(self.metrics[k] / total, 4)
     return m
Example #10
    def report(self):
        """
        Report the current metrics.

        :return:
            a metrics dict
        """
        m = {}
        if self.metrics['num_samples'] > 0:
            m['hits@1/100'] = round_sigfigs(
                self.metrics['hits@1/100'] / self.metrics['num_samples'], 4)
            m['loss'] = round_sigfigs(
                self.metrics['loss'] / self.metrics['num_samples'], 4)
            if 'med_rank' in self.metrics:
                m['med_rank'] = np.median(self.metrics['med_rank'])
        return m
Example #11
    def report(self):
        """
        Report loss and perplexity from model's perspective.

        Note that this includes predicting __END__ and __UNK__ tokens and may
        differ from a truly independent measurement.
        """
        base = super().report()
        m = {}
        m["num_tokens"] = self.counts["num_tokens"]
        m["num_batches"] = self.counts["num_batches"]
        m["loss"] = self.metrics["loss"] / m["num_batches"]
        m["base_loss"] = self.metrics["base_loss"] / m["num_batches"]
        m["kge_loss"] = self.metrics["kge_loss"] / m["num_batches"]
        m["l2_loss"] = self.metrics["l2_loss"] / m["num_batches"]
        m["acc"] = self.metrics["acc"] / m["num_tokens"]
        m["auc"] = self.metrics["auc"] / m["num_tokens"]
        # Top-k recommendation Recall
        for x in sorted(self.metrics):
            if x.startswith("recall") and self.counts[x] > 200:
                m[x] = self.metrics[x] / self.counts[x]
                m["num_tokens_" + x] = self.counts[x]
        # for x in ["1", "10", "50"]:
        #     if f"recall@{x}" in self.metrics and self.metrics[f"recall@{x}"] != []:
        #         m[f"recall@{x}"] = self.metrics[f"recall@{x}"] / m["num_tokens"]
        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            base[k] = round_sigfigs(v, 4)
        return base
Example #12
    def report(self):
        """
        Report loss and perplexity from model's perspective.

        Note that this includes predicting __END__ and __UNK__ tokens and may
        differ from a truly independent measurement.
        """
        base = super().report()
        m = {}
        num_tok = self.metrics['num_tokens']
        if num_tok > 0:
            m['loss'] = self.metrics['loss']
            if self.metrics['correct_tokens'] > 0:
                m['token_acc'] = self.metrics['correct_tokens'] / num_tok
            m['nll_loss'] = self.metrics['nll_loss'] / num_tok
            try:
                m['ppl'] = math.exp(m['nll_loss'])
            except OverflowError:
                m['ppl'] = float('inf')
        if self.metrics['total_skipped_batches'] > 0:
            m['total_skipped_batches'] = self.metrics['total_skipped_batches']
        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            base[k] = round_sigfigs(v, 4)
        return base
Example #13
    def report(self):
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        # These are the metrics we'll pass up the way, and their new names
        train_metrics = {"train_loss", "ups", "wps", "gnorm", "clip"}
        valid_metrics = {"valid_loss"}

        metrics = train_metrics if self.is_training else valid_metrics

        m = {k: self.trainer.meters[k].avg for k in metrics}

        # additionally output perplexity. note that fairseq models use base 2
        # in cross_entropy:
        # github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L55
        if "train_loss" in m:
            m["train_ppl"] = np.exp2(m["train_loss"])
        if "valid_loss" in m:
            m["ppl"] = np.exp2(m["valid_loss"])

        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)

        return m
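As the comment in this example notes, fairseq reports cross entropy in base 2, so perplexity is recovered with 2**loss (np.exp2) rather than math.exp. A tiny illustration with an invented loss value:

    import numpy as np

    train_loss = 3.0                # hypothetical per-token loss in bits
    print(np.exp2(train_loss))      # 8.0, i.e. a perplexity of 8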
Example #14
 def report(self):
     r = super().report()
     bsz = max(self.metrics['bsz'], 1)
     for k in ['know_loss', 'know_acc', 'know_chance']:
         # round and average across all items since last report
         r[k] = round_sigfigs(self.metrics[k] / bsz, 4)
     return r
Example #15
    def report(self):
        """Report loss as well as precision, recall, and F1 metrics."""
        m = super().report()
        examples = self.metrics['examples']
        if examples > 0:
            m['examples'] = examples
            m['mean_loss'] = self.metrics['loss'] / examples

            # get prec/recall metrics
            confmat = self.metrics['confusion_matrix']
            if self.opt.get('get_all_metrics'):
                metrics_list = self.class_list
            else:
                # only give prec/recall metrics for ref class
                metrics_list = [self.ref_class]

            examples_per_class = []
            for class_i in metrics_list:
                class_total = self._report_prec_recall_metrics(confmat, class_i, m)
                examples_per_class.append(class_total)

            if len(examples_per_class) > 1:
                # get weighted f1
                f1 = 0
                total_exs = sum(examples_per_class)
                for i in range(len(self.class_list)):
                    f1 += (examples_per_class[i] / total_exs) * m[
                        'class_{}_f1'.format(self.class_list[i])
                    ]
                m['weighted_f1'] = f1

        for k, v in m.items():
            m[k] = round_sigfigs(v, 4)

        return m
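The weighted_f1 above is simply the per-class F1 scores averaged with class-frequency weights. A small arithmetic illustration with invented counts and scores:

    # Hypothetical: 80 examples of class_A with F1 0.9, 20 of class_B with F1 0.5.
    examples_per_class = [80, 20]
    class_f1 = [0.9, 0.5]
    total_exs = sum(examples_per_class)
    weighted_f1 = sum(n / total_exs * f for n, f in zip(examples_per_class, class_f1))
    print(round(weighted_f1, 4))  # 0.82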
Example #16
def aggregate_metrics(reporters):
    # reporters is a list of teachers or worlds
    m = {}
    m['tasks'] = {}
    sums = {'accuracy': 0, 'f1': 0, 'loss': 0, 'ppl': 0}
    if nltkbleu is not None:
        sums['bleu'] = 0
    num_tasks = 0
    total = 0
    for i in range(len(reporters)):
        tid = reporters[i].getID()
        mt = reporters[i].report()
        while tid in m['tasks']:
            # prevent name clobbering if using multiple tasks with the same ID
            tid += '_'
        m['tasks'][tid] = mt
        total += mt['exs']
        found_any = False
        for k in sums.keys():
            if k in mt:
                sums[k] += mt[k]
                found_any = True
        if found_any:
            num_tasks += 1
    m['exs'] = total
    m['accuracy'] = 0
    if num_tasks > 0:
        for k in sums.keys():
            m[k] = round_sigfigs(sums[k] / num_tasks, 4)
    return m
Example #17
def eval_ppl(opt):
    """Evaluates the the perplexity and f1 of a model (and hits@1 if model has
    ranking enabled.
    """
    dict_agent = build_dict()

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent],
                        default_world=PerplexityWorld)
    world.dict = dict_agent

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['total'] / world.num_examples() * 100, 2),
                report))
            log_time.reset()
    if world.epoch_done():
        print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
Example #18
 def _format_interactive_output(self, probs, prediction_id):
     """Format interactive mode output with scores."""
     preds = []
     for i, pred_id in enumerate(prediction_id.tolist()):
         prob = round_sigfigs(probs[i][pred_id], 4)
         preds.append('Predicted class: {}\nwith probability: {}'.format(
             self.class_list[pred_id], prob))
     return preds
Example #19
 def report(self):
     m = {}
     if self.metrics['num_tokens'] > 0:
         m['loss'] = self.metrics['loss'] / self.metrics['num_tokens']
         m['ppl'] = math.exp(m['loss'])
     for k, v in m.items():
         # clean up: rounds to sigfigs and converts tensors to floats
         m[k] = round_sigfigs(v, 4)
     return m
Example #20
def aggregate_task_reports(reports, tasks, micro=False):
    """
    Aggregate separate task reports into a single report.

    :param reports: list of report dicts from separate tasks
    :param tasks: list of tasks
    :param micro: average per example if True, else average over tasks

    :return: aggregated report dict
    """
    if len(reports) == 1:
        # singular task
        return reports[0]
    # multiple tasks, aggregate metrics
    metrics = {}
    exs = {}
    total_report = {'tasks': {}}
    # collect metrics from all reports
    for i, report in enumerate(reports):
        total_report['tasks'][tasks[i]] = report
        for metric, val in report.items():
            if metric == 'exs':
                exs[tasks[i]] = val
            else:
                metrics.setdefault(metric, {})[tasks[i]] = val
    # now aggregate
    total_exs = sum(exs.values())
    total_report['exs'] = total_exs
    for metric, task_vals in metrics.items():
        if all([isinstance(v, Number) for v in task_vals.values()]):
            if micro:
                # average over the number of examples
                vals = [task_vals[task] * exs[task] for task in tasks]
                total_report[metric] = round_sigfigs(sum(vals) / total_exs, 4)
            else:  # macro
                # average over tasks
                vals = task_vals.values()
                total_report[metric] = round_sigfigs(sum(vals) / len(vals), 4)
    # add a warning describing how metrics were averaged across tasks.
    total_report['warning'] = 'metrics are averaged across tasks'
    if micro:
        total_report['warning'] += (
            ' and weighted by the number of examples per task'
        )
    return total_report
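To make the micro/macro distinction concrete, here is a hypothetical call with two invented task reports (assuming aggregate_task_reports and its round_sigfigs dependency are in scope):

    reports = [
        {'exs': 100, 'accuracy': 0.9},
        {'exs': 300, 'accuracy': 0.5},
    ]
    tasks = ['taskA', 'taskB']

    # macro: (0.9 + 0.5) / 2 = 0.7
    print(aggregate_task_reports(reports, tasks, micro=False)['accuracy'])
    # micro: (0.9 * 100 + 0.5 * 300) / 400 = 0.6
    print(aggregate_task_reports(reports, tasks, micro=True)['accuracy'])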
Example #21
 def _nice_format(self, dictionary):
     rounded = {}
     for k, v in dictionary.items():
         if isinstance(v, dict):
             rounded[k] = self._nice_format(v)
         elif isinstance(v, float):
             rounded[k] = round_sigfigs(v, 4)
         else:
             rounded[k] = v
     return rounded
Example #22
 def _format_interactive_output(self, probs, prediction_id):
     """Nicely format interactive mode output when we want to also
     print the scores associated with the predictions.
     """
     preds = []
     for i, pred_id in enumerate(prediction_id.tolist()):
         prob = round_sigfigs(probs[i][pred_id], 4)
         preds.append('Predicted class: {}\nwith probability: {}'.format(
                      self.class_list[pred_id], prob))
     return preds
Example #23
 def report(self):
     m = {}
     m['loss'] = self.metrics['loss']
     ranks = np.asarray(self.metrics['r@'])
     m['r@1'] = len(np.where(ranks < 1)[0]) / len(ranks)
     m['r@5'] = len(np.where(ranks < 5)[0]) / len(ranks)
     m['r@10'] = len(np.where(ranks < 10)[0]) / len(ranks)
     for k, v in m.items():
         # clean up: rounds to sigfigs and converts tensors to floats
         m[k] = round_sigfigs(v, 4)
     return m
Example #24
 def report(self):
     # Report the metrics over all data seen so far.
     m = {}
     total = self.metrics['cnt']
     m['total'] = total
     if total > 0:
         m['accuracy'] = round_sigfigs(self.metrics['correct'] / total, 4)
         m['f1'] = round_sigfigs(self.metrics['f1'] / total, 4)
         m['hits@k'] = {}
         for k in self.eval_pr:
             m['hits@k'][k] = round_sigfigs(
                 self.metrics['hits@' + str(k)] / total, 3)
         for k in self.custom_keys:
             if k in self.metrics:
                 v = self.metrics[k]
                 if type(v) not in (int, float) or v != 0:
                     m[k] = round_sigfigs(v / total, 3)
         for k in self.custom_metrics:
             if self.metrics[k + '_cnt'] > 0:
                 m[k] = round_sigfigs(self.metrics[k] / self.metrics[k + '_cnt'], 4)
     return m
Example #25
 def report(self):
     m = {}
     if self.metrics['num_tokens'] > 0:
         m['loss'] = self.metrics['loss'] / self.metrics['num_tokens']
         m['ppl'] = math.exp(m['loss'])
         m['d_1'] = len(self.metrics['1_gram']) / self.metrics['out_tokens']
         m['d_2'] = len(self.metrics['2_gram']) / self.metrics['out_tokens']
         m['BLEU'] = corpus_bleu(self.refs, self.hypos, smoothing_function=self.sf.method5)
     for k, v in m.items():
         # clean up: rounds to sigfigs and converts tensors to floats
         m[k] = round_sigfigs(v, 4)
     return m
Example #26
 def report(self):
     """Report loss and mean_rank from model's perspective."""
     m = {}
     examples = self.metrics['examples']
     if examples > 0:
         m['examples'] = examples
         m['loss'] = self.metrics['loss']
         m['mean_loss'] = self.metrics['loss'] / examples
         m['mean_rank'] = self.metrics['rank'] / examples
     for k, v in m.items():
         # clean up: rounds to sigfigs and converts tensors to floats
         m[k] = round_sigfigs(v, 4)
     return m
Example #27
 def report(self):
     # Report the metrics over all data seen so far.
     m = {}
     total = self.metrics['cnt']
     m['total'] = total
     if total > 0:
         if self.flags['print_prediction_metrics']:
             m['accuracy'] = round_sigfigs(
                 self.metrics['correct'] / self.metrics['correct_cnt'], 4)
             m['f1'] = round_sigfigs(
                 self.metrics['f1'] / self.metrics['f1_cnt'], 4)
             if self.flags['has_text_cands']:
                 m['hits@k'] = {}
                 for k in self.eval_pr:
                     m['hits@k'][k] = round_sigfigs(
                         self.metrics['hits@' + str(k)] /
                         self.metrics['hits@_cnt'], 3)
         for k in self.metrics_list:
             if self.metrics[k + '_cnt'] > 0 and k != 'correct' and k != 'f1':
                 m[k] = round_sigfigs(
                     self.metrics[k] / self.metrics[k + '_cnt'], 4)
     return m
Example #28
    def report(self):
        """Report loss and perplexity from model's perspective.

        Note that this includes predicting __END__ and __UNK__ tokens and may
        differ from a truly independent measurement.
        """
        m = {}
        if self.metrics['num_tokens'] > 0:
            m['loss'] = self.metrics['loss'] / self.metrics['num_tokens']
            m['ppl'] = math.exp(m['loss'])
        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)
        return m
Example #29
    def report(self):
        m = super().report()
        if self.metrics['total_batches'] > 0:
            m['rank_loss'] = self.metrics['rank_loss']
        if self.num_predicted_count > 0:
            m['overlap_prediction'] = self.overlap_count[
                'predicted'] / self.num_predicted_count
            for i in range(5):
                m['overlap_ranked{}'.format(i)] = self.overlap_count[
                    'ranked{}'.format(i)] / self.num_predicted_count
            m['injected_ranked0'] = self.injpred_selected_count / self.pred_count
        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)

        return m
Example #30
 def report(self):
     """Report loss and mean_rank from model's perspective."""
     base = super().report()
     m = {}
     examples = self.metrics['examples']
     if examples > 0:
         m['examples'] = examples
         m['loss'] = self.metrics['loss']
         m['mean_loss'] = self.metrics['loss'] / examples
         m['mean_rank'] = self.metrics['rank'] / examples
         if self.opt['candidates'] == 'batch':
             m['train_accuracy'] = self.metrics['train_accuracy'] / examples
     for k, v in m.items():
         # clean up: rounds to sigfigs and converts tensors to floats
         base[k] = round_sigfigs(v, 4)
     return base
Example #31
    def __load_training_batch(self, observations):
        if observations and len(
                observations) > 0 and observations[0] and self.is_combine_attr:
            if not self.random_policy:
                with torch.no_grad():
                    current_states = self._build_states(observations)
                action_probs = self.policy(current_states)
                if self.action_log_time.time() > self.log_every_n_secs and len(
                        self.tasks) > 1:
                    with torch.no_grad():
                        # log the action distributions
                        action_p = ','.join([
                            str(round_sigfigs(x, 4))
                            for x in action_probs[0].data.tolist()
                        ])
                        log = '[ {} {} ]'.format('Action probs:', action_p)
                        print(log)
                        self.action_log_time.reset()
                sample_from = Categorical(action_probs[0])
                action = sample_from.sample()
                train_step = observations[0]['train_step']
                self.saved_actions[train_step] = sample_from.log_prob(action)
                self.saved_state_actions[train_step] = torch.cat(
                    [current_states, action_probs], dim=1)
                selected_task = action.item()
                self.subtask_counter[self.subtasks[selected_task]] += 1

                probs = action_probs[0].tolist()
                selection_report = {}
                for idx, t in enumerate(self.subtasks):
                    selection_report['p_{}'.format(t)] = probs[idx]
                    self.p_selections[t].append(probs[idx])
                    selection_report['c_{}'.format(t)] = self.subtask_counter[t]
                    self.c_selections[t].append(self.subtask_counter[t])
                self.writer.add_metrics(setting='Teacher/task_selection',
                                        step=train_step,
                                        report=selection_report)
            else:
                selected_task = random.choice(range(len(self.tasks)))
                self.subtask_counter[self.subtasks[selected_task]] += 1
        else:
            selected_task = 0

        return self.__load_batch(observations, task_idx=selected_task)
Example #32
 def report(self):
     """Report metrics from model's perspective."""
     m = TorchAgent.report(self)  # Skip TorchRankerAgent; totally redundant
     examples = self.metrics['examples']
     if examples > 0:
         m['examples'] = examples
         if 'dialog' in self.subtasks and self.metrics['dia_exs'] > 0:
             m['dia_loss'] = self.metrics['dia_loss'] / self.metrics[
                 'dia_exs']
             m['dia_rank'] = self.metrics['dia_rank'] / self.metrics[
                 'dia_exs']
             m['dia_acc'] = self.metrics['dia_correct'] / self.metrics[
                 'dia_exs']
             m['dia_exs'] = self.metrics['dia_exs']
         if 'feedback' in self.subtasks and self.metrics['fee_exs'] > 0:
             m['fee_loss'] = self.metrics['fee_loss'] / self.metrics[
                 'fee_exs']
             m['fee_rank'] = self.metrics['fee_rank'] / self.metrics[
                 'fee_exs']
             m['fee_acc'] = self.metrics['fee_correct'] / self.metrics[
                 'fee_exs']
             m['fee_exs'] = self.metrics['fee_exs']
         if 'satisfaction' in self.subtasks and self.metrics['sat_exs'] > 0:
             tp = self.metrics['sat_tp']
             tn = self.metrics['sat_tn']
             fp = self.metrics['sat_fp']
             fn = self.metrics['sat_fn']
             assert tp + tn + fp + fn == self.metrics['sat_exs']
             m['sat_loss'] = self.metrics['sat_loss'] / self.metrics[
                 'sat_exs']
             m['sat_pr'] = tp / (tp + fp + EPS)
             m['sat_re'] = tp / (tp + fn + EPS)
             pr = m['sat_pr']
             re = m['sat_re']
             m['sat_f1'] = (2 * pr * re) / (pr + re) if (pr and re) else 0.0
             m['sat_acc'] = (tp + tn) / self.metrics['sat_exs']
             m['sat_exs'] = self.metrics['sat_exs']
     for k, v in m.items():
         # clean up: rounds to sigfigs and converts tensors to floats
         if isinstance(v, float):
             m[k] = round_sigfigs(v, 4)
         else:
             m[k] = v
     return m