Example No. 1
    def test_AccuaryMetric5(self):
        # (5) check reset
        metric = AccuracyMetric()
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})

        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3) + 1}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 0.5})
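
These snippets are test methods lifted out of their surrounding class, so the imports and test harness are not shown. A minimal sketch of the assumed setup follows; the import path fastNLP.core.metrics and the class name TestAccuracyMetric are assumptions, not part of the examples.

import unittest

import torch
from fastNLP.core.metrics import AccuracyMetric  # assumed import path

class TestAccuracyMetric(unittest.TestCase):
    # paste the test_* methods from the examples here
    pass

if __name__ == '__main__':
    unittest.main()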
Example No. 2
    def test_AccuaryMetric7(self):
        # (7) check map, match
        metric = AccuracyMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.zeros(4, 3, 2)}
        target_dict = {'targets': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        self.assertDictEqual(metric.get_metric(), {'acc': 1})
Example No. 3
    def test_AccuracyMetric1(self):
        # (1) only plain pred and target passed (no key mapping)
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = AccuracyMetric()

        metric(pred_dict=pred_dict, target_dict=target_dict, )
        print(metric.get_metric())
Example No. 4
    def test_AccuaryMetric7(self):
        # (7) check map, match
        metric = AccuracyMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.randn(4, 3, 2)}
        target_dict = {'targets': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        res = metric.get_metric()
        ans = (torch.argmax(pred_dict["predictions"], dim=2).float()
               == target_dict["targets"]).float().mean()
        self.assertAlmostEqual(res["acc"], float(ans), places=4)
Example No. 5
    def test_AccuaryMetric5(self):
        # (5) check reset
        metric = AccuracyMetric()
        pred_dict = {"pred": torch.randn(4, 3, 2)}
        target_dict = {'target': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        res = metric.get_metric(reset=False)
        ans = (torch.argmax(pred_dict["pred"], dim=2).float()
               == target_dict["target"]).float().mean()
        self.assertAlmostEqual(res["acc"], float(ans), places=4)
Example No. 6
    def test_AccuaryMetric4(self):
        # (4) check accuracy against a manual computation
        metric = AccuracyMetric()
        pred_dict = {"pred": torch.randn(4, 3, 2)}
        target_dict = {'target': torch.ones(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        ans = torch.argmax(pred_dict["pred"], dim=2).to(
            target_dict["target"]) == target_dict["target"]
        res = metric.get_metric()
        self.assertTrue(isinstance(res, dict))
        self.assertTrue("acc" in res)
        self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3)
Example No. 7
    def test_AccuaryMetric9(self):
        # (9) check map, include an unused key
        try:
            metric = AccuracyMetric(pred='prediction', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")
Example No. 8
    def test_AccuaryMetric8(self):
        # (8) check map that does not match; stop_fast_param disables the fast param map
        try:
            metric = AccuracyMetric(pred='predictions', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")
Example No. 9
    def test_AccuaryMetric10(self):
        # (10) check _fast_metric
        try:
            metric = AccuracyMetric()
            pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")
Example No. 10
    def test_AccuracyMetric2(self):
        # (2) with mismatched shapes
        try:
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric = AccuracyMetric()

            metric(pred_dict=pred_dict, target_dict=target_dict, )
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")
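
Examples Nos. 7 to 10 expect the metric call to raise when the provided dict keys do not match the configured ones, and they express that with a try/except plus a final failure. A more idiomatic sketch of the same expectation using unittest's assertRaises; the method name here is hypothetical, and Exception is caught broadly to mirror the original tests.

    def test_AccuracyMetric_key_mismatch(self):
        # the pred_dict key deliberately does not match the configured 'predictions',
        # so calling the metric is expected to raise
        metric = AccuracyMetric(pred='predictions', target='targets')
        pred_dict = {"prediction": torch.zeros(4, 3, 2)}
        target_dict = {'targets': torch.zeros(4, 3)}
        with self.assertRaises(Exception):
            metric(pred_dict=pred_dict, target_dict=target_dict)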