Exemplo n.º 1
0
 def test_seq_len(self):
     """Run ConfusionMatrixMetric over two batches that share a ``seq_len`` tensor.

     In the first batch only the first sample has a non-zero length (2).
     Before the second batch, every other sample's length is set to 1 in place.
     ``get_metric(reset=False)`` is used between the batches, presumably so the
     accumulated state is kept for the second call.
     """
     N = 256
     seq_len = torch.zeros(N).long()
     seq_len[0] = 2
     pred = {'pred': torch.ones(N, 2)}
     # target_dict holds a reference to seq_len, so the in-place update below
     # changes the lengths seen by the second metric call.
     target = {'target': torch.ones(N, 2), 'seq_len': seq_len}
     metric = ConfusionMatrixMetric()
     metric(pred_dict=pred, target_dict=target)
     res = metric.get_metric(reset=False)
     self.assertIsInstance(res, dict)
     seq_len[1:] = 1
     metric(pred_dict=pred, target_dict=target)
     res = metric.get_metric()
     self.assertIsInstance(res, dict)
Exemplo n.º 2
0
    def test_ConfusionMatrixMetric1(self):
        """Evaluate a 2-D ``pred`` (4, 3) against a 1-D ``target`` (4,) and print it."""
        metric = ConfusionMatrixMetric()
        metric(
            pred_dict={"pred": torch.zeros(4, 3)},
            target_dict={'target': torch.zeros(4)},
        )
        result = metric.get_metric()
        print(result)
Exemplo n.º 3
0
 def test_ConfusionMatrixMetric6(self):
     """(6) Check the parameter map: custom key names are matched to pred/target."""
     metric = ConfusionMatrixMetric(pred='predictions', target='targets')
     predictions = torch.randn(4, 3, 2)
     targets = torch.zeros(4, 3)
     metric(
         pred_dict={"predictions": predictions},
         target_dict={'targets': targets},
     )
     print(metric.get_metric())
Exemplo n.º 4
0
 def test_ConfusionMatrixMetric4(self):
     """(4) "check reset": ``get_metric()`` with default arguments returns a dict.

     NOTE(review): despite the title, nothing here inspects the metric's state
     after the reset — only the return type is checked.
     """
     metric = ConfusionMatrixMetric()
     metric(
         pred_dict={"pred": torch.randn(4, 3, 2)},
         target_dict={'target': torch.ones(4, 3)},
     )
     outcome = metric.get_metric()
     self.assertIsInstance(outcome, dict)
     print(outcome)
Exemplo n.º 5
0
    def test_vocab(self):
        """Build ConfusionMatrixMetric with a Vocabulary and evaluate one batch."""
        vocab = Vocabulary()
        vocab.update("this is a word list".split())

        metric = ConfusionMatrixMetric(vocab=vocab)
        metric(
            pred_dict={"pred": torch.zeros(4, 3)},
            target_dict={'target': torch.zeros(4)},
        )
        print(metric.get_metric())
Exemplo n.º 6
0
 def test_duplicate(self):
     """Regression check for a potential bug in 0.4.1: duplicated parameter names.

     Both dicts carry the default keys ('pred', 'target') alongside the mapped
     keys ('predictions', 'targets'); this situation must not be treated as a
     duplicated formal parameter.
     """
     metric = ConfusionMatrixMetric(pred='predictions', target='targets')
     predictions = torch.zeros(4, 3, 2)
     lengths = torch.ones(4) * 3
     pred_dict = {"predictions": predictions, "seq_len": lengths, 'pred': 0}
     target_dict = {'targets': torch.zeros(4, 3), 'target': 0}
     metric(pred_dict=pred_dict, target_dict=target_dict)
     print(metric.get_metric())
Exemplo n.º 7
0
 def test_ConfusionMatrixMetric8(self):
     """(8) check _fast_metric: a ``seq_len`` whose batch size mismatches must raise.

     NOTE(review): the metric is built without a key mapping, while the dicts
     use 'predictions'/'targets' (compare test_ConfusionMatrixMetric6, which
     passes pred='predictions', target='targets'). The expected exception may
     therefore come from the unmatched parameter names rather than from the
     seq_len mismatch — confirm which check this test is meant to hit.
     """
     # Setup stays outside assertRaises so a failure while building the
     # inputs cannot be mistaken for the expected exception.
     metric = ConfusionMatrixMetric()
     pred_dict = {
         "predictions": torch.zeros(4, 3, 2),
         "seq_len": torch.ones(3) * 3,  # 3 entries, while pred/target have 4 rows
     }
     target_dict = {'targets': torch.zeros(4, 3)}
     with self.assertRaises(Exception):
         metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
Exemplo n.º 8
0
    def test_ConfusionMatrixMetric3(self):
        """(3) A second batch with a corrupted target size must raise.

        The first, well-formed batch (pred (4, 3, 2), target (4, 3), the same
        shapes test_ConfusionMatrixMetric4 evaluates successfully) is run
        outside ``assertRaises``. An unexpected failure on valid input then
        fails the test instead of satisfying the expected exception.
        """
        metric = ConfusionMatrixMetric()
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)

        # Corrupted batch: target is 1-D (4,) while pred is still (4, 3, 2).
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        with self.assertRaises(Exception):
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
Exemplo n.º 9
0
    def test_ConfusionMatrixMetric2(self):
        """(2) Evaluating a batch with a corrupted target size must raise.

        pred is (4, 3, 2) while target is (4,). Only the evaluation is wrapped
        in ``assertRaises``. The inputs and the metric are built outside it, so
        a setup failure is not mistaken for the expected exception.
        """
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        metric = ConfusionMatrixMetric()

        with self.assertRaises(Exception):
            metric(
                pred_dict=pred_dict,
                target_dict=target_dict,
            )
            print(metric.get_metric())