Ejemplo n.º 1
0
    def validate_end(self, outputs, num_samples):
        """Reduce the collected per-batch validation outputs to one metrics dict.

        Args:
            outputs: iterable of per-batch dicts; each must provide a
                ``'conf_mat'`` entry and a ``'val_loss'`` tensor. It is
                iterated twice, so it should be a re-iterable container.
            num_samples: divisor applied to the summed validation loss.

        Returns:
            dict holding ``'val_loss'`` (converted via ``.item()``) followed
            by every entry returned by ``compute_conf_mat_metrics``.
        """
        # Add up the per-batch confusion matrices (presumably tensors of a
        # common shape -- TODO confirm against the producer of `outputs`).
        total_conf_mat = sum(batch['conf_mat'] for batch in outputs)

        # Stack the per-batch losses, total them, and scale by num_samples.
        loss_total = torch.stack([batch['val_loss'] for batch in outputs]).sum()
        scaled_loss = loss_total / num_samples

        conf_mat_stats = compute_conf_mat_metrics(
            total_conf_mat, self.cfg.data.class_names)

        # Later keys win, so confusion-matrix metrics override on collision.
        return {'val_loss': scaled_loss.item(), **conf_mat_stats}
Ejemplo n.º 2
0
 def test2(self):
     """A 2x2 matrix with zero diagonal must yield zero for every metric."""
     names = ['a', 'b']
     # All counts sit off the diagonal, so no prediction is correct.
     matrix = torch.tensor([[0, 2.], [2, 0]])
     computed = compute_conf_mat_metrics(matrix, names)

     expected_keys = (
         'avg_precision', 'avg_recall', 'avg_f1',
         'a_precision', 'a_recall', 'a_f1',
         'b_precision', 'b_recall', 'b_f1',
     )
     expected = dict.fromkeys(expected_keys, 0.0)
     self.assertDictEqual(computed, expected)
Ejemplo n.º 3
0
    def test3(self):
        """Check per-class and averaged metrics on a non-degenerate 2x2 matrix.

        Expected values are derived by hand below. ``eps=0.0`` is passed,
        presumably to disable any smoothing term so the results match the
        exact formulas (TODO confirm against ``compute_conf_mat_metrics``).
        """
        label_names = ['a', 'b']
        conf_mat = torch.tensor([[1, 2], [1, 2.]])
        metrics = compute_conf_mat_metrics(conf_mat, label_names, eps=0.0)

        def f1(prec, rec):
            # Harmonic mean of precision and recall.
            return 2 * (prec * rec) / (prec + rec)

        def mean(a, b):
            return (a + b) / 2

        a_prec = 1 / 2
        a_rec = 1 / 3
        a_f1 = f1(a_prec, a_rec)
        b_prec = 2 / 4
        b_rec = 2 / 3
        b_f1 = f1(b_prec, b_rec)
        avg_prec = mean(a_prec, b_prec)
        avg_rec = mean(a_rec, b_rec)
        # The expected 'avg_f1' is the F1 of the averaged precision/recall,
        # not the mean of the per-class F1 scores.
        avg_f1 = f1(avg_prec, avg_rec)

        exp_metrics = {
            'avg_precision': avg_prec,
            'avg_recall': avg_rec,
            'avg_f1': avg_f1,
            'a_precision': a_prec,
            'a_recall': a_rec,
            'a_f1': a_f1,
            'b_precision': b_prec,
            'b_recall': b_rec,
            'b_f1': b_f1
        }

        # Exactly the expected metric names must be reported.
        self.assertEqual(set(metrics), set(exp_metrics))
        # Compare with a tolerance rather than rounding both sides: rounding
        # can push two nearly-equal values to opposite sides of a boundary
        # (e.g. 0.44449999 vs 0.44450001) and fail spuriously.
        for key, expected in exp_metrics.items():
            with self.subTest(metric=key):
                self.assertAlmostEqual(metrics[key], expected, places=3)