def test_ConfusionMatrixMetric1(self):
    """(1) Smoke test: ``pred`` with one more dimension than ``target``.

    ``pred`` is a (4, 3) tensor and ``target`` is a (4,) tensor. The result
    of ``get_metric()`` is asserted to be a dict (the same check used by
    ``test_ConfusionMatrixMetric4``), so a broken return value fails the
    test instead of only being printed.
    """
    pred_dict = {"pred": torch.zeros(4, 3)}
    target_dict = {'target': torch.zeros(4)}
    metric = ConfusionMatrixMetric()
    metric(pred_dict=pred_dict, target_dict=target_dict)
    res = metric.get_metric()
    print(res)
    # Without an assertion this test could never fail on a wrong result.
    self.assertIsInstance(res, dict)
def test_ConfusionMatrixMetric6(self):
    """(6) check map, match: inputs are read from renamed keys.

    The metric is constructed with ``pred='predictions'`` and
    ``target='targets'``, and the dicts supply exactly those keys. The
    result is asserted to be a dict (as in ``test_ConfusionMatrixMetric4``)
    so the test can fail on a broken ``get_metric``.
    """
    metric = ConfusionMatrixMetric(pred='predictions', target='targets')
    pred_dict = {"predictions": torch.randn(4, 3, 2)}
    target_dict = {'targets': torch.zeros(4, 3)}
    metric(pred_dict=pred_dict, target_dict=target_dict)
    res = metric.get_metric()
    print(res)
    self.assertIsInstance(res, dict)
def test_ConfusionMatrixMetric4(self):
    """(4) check reset: one batch, then ``get_metric()`` must return a dict.

    NOTE(review): despite the label, nothing here inspects the metric's
    state after ``get_metric()`` (which is called with its default
    arguments); only the return type is checked. Confirm whether a real
    reset check was intended.
    """
    cm_metric = ConfusionMatrixMetric()
    scores = torch.randn(4, 3, 2)
    labels = torch.ones(4, 3)
    cm_metric(pred_dict={"pred": scores}, target_dict={'target': labels})
    outcome = cm_metric.get_metric()
    self.assertIsInstance(outcome, dict)
    print(outcome)
def test_vocab(self):
    """Metric built with a ``Vocabulary`` still evaluates a batch.

    The vocabulary is filled from a small word list and passed to
    ``ConfusionMatrixMetric(vocab=...)``. The result is asserted to be a
    dict (as in ``test_ConfusionMatrixMetric4``) rather than only printed.
    """
    vocab = Vocabulary()
    word_list = "this is a word list".split()
    vocab.update(word_list)
    pred_dict = {"pred": torch.zeros(4, 3)}
    target_dict = {'target': torch.zeros(4)}
    metric = ConfusionMatrixMetric(vocab=vocab)
    metric(pred_dict=pred_dict, target_dict=target_dict)
    res = metric.get_metric()
    print(res)
    self.assertIsInstance(res, dict)
def test_duplicate(self):
    """Regression test: mapped keys must not clash with the original names.

    Potential bug in 0.4.1: duplicated formal parameters must not occur.
    The metric maps ``pred``->``'predictions'`` and ``target``->``'targets'``,
    while the dicts *also* carry unrelated ``'pred'``/``'target'`` entries
    (set to 0). Evaluation must use the mapped tensors and succeed; the
    result is asserted to be a dict (as in ``test_ConfusionMatrixMetric4``).
    """
    metric = ConfusionMatrixMetric(pred='predictions', target='targets')
    pred_dict = {
        "predictions": torch.zeros(4, 3, 2),
        "seq_len": torch.ones(4) * 3,
        'pred': 0,
    }
    target_dict = {'targets': torch.zeros(4, 3), 'target': 0}
    metric(pred_dict=pred_dict, target_dict=target_dict)
    res = metric.get_metric()
    print(res)
    self.assertIsInstance(res, dict)
def test_ConfusionMatrixMetric8(self):
    """(8) check _fast_metric: the given inputs must make the metric raise.

    ``seq_len`` has 3 entries while the batch has 4 rows.
    NOTE(review): the metric is built with no key mapping, but the dicts
    use ``'predictions'``/``'targets'`` (other tests pass ``'pred'`` and
    ``'target'`` to a default-constructed metric). The exception may
    therefore come from the missing keys rather than the ``seq_len``
    mismatch. Confirm which failure this test is meant to cover.
    """
    with self.assertRaises(Exception):
        cm_metric = ConfusionMatrixMetric()
        preds = {
            "predictions": torch.zeros(4, 3, 2),
            "seq_len": torch.ones(3) * 3,
        }
        golds = {'targets': torch.zeros(4, 3)}
        cm_metric(pred_dict=preds, target_dict=golds)
        outcome = cm_metric.get_metric()
        print(outcome)
def test_ConfusionMatrixMetric3(self):
    """(3) A second batch with a mis-sized target must raise.

    The first batch pairs a (4, 3, 2) ``pred`` with a (4, 3) ``target``.
    The second keeps the same ``pred`` shape but passes a (4,) ``target``.
    """
    with self.assertRaises(Exception):
        cm_metric = ConfusionMatrixMetric()
        # Well-formed first batch.
        cm_metric(pred_dict={"pred": torch.zeros(4, 3, 2)},
                  target_dict={'target': torch.zeros(4, 3)})
        # Corrupted second batch: target lost a dimension.
        cm_metric(pred_dict={"pred": torch.zeros(4, 3, 2)},
                  target_dict={'target': torch.zeros(4)})
        outcome = cm_metric.get_metric()
        print(outcome)
def test_ConfusionMatrixMetric2(self):
    """(2) A single batch with incompatible shapes must raise.

    ``pred`` is (4, 3, 2) while ``target`` is only (4,).
    """
    with self.assertRaises(Exception):
        preds = {"pred": torch.zeros(4, 3, 2)}
        golds = {'target': torch.zeros(4)}
        cm_metric = ConfusionMatrixMetric()
        cm_metric(pred_dict=preds, target_dict=golds)
        outcome = cm_metric.get_metric()
        print(outcome)
def test_ConfusionMatrixMetric5(self):
    """(5) NumPy arrays are rejected: evaluating them must raise."""
    with self.assertRaises(Exception):
        cm_metric = ConfusionMatrixMetric()
        cm_metric(
            pred_dict={"pred": np.zeros((4, 3, 2))},
            target_dict={'target': np.zeros((4, 3))},
        )
def test_seq_len(self):
    """Evaluate twice with a ``seq_len`` tensor that is mutated in between.

    ``seq_len`` lives in ``target_dict`` and starts as all zeros except the
    first entry (2). After the first evaluation, which is read with
    ``reset=False``, every entry but the first is set to 1 in place, and
    the same dicts are evaluated again.
    """
    batch = 256
    lengths = torch.zeros(batch, dtype=torch.long)
    lengths[0] = 2
    preds = {'pred': torch.ones(batch, 2)}
    golds = {'target': torch.ones(batch, 2), 'seq_len': lengths}
    cm_metric = ConfusionMatrixMetric()

    cm_metric(pred_dict=preds, target_dict=golds)
    cm_metric.get_metric(reset=False)

    # In-place update: golds['seq_len'] is the same tensor object.
    lengths[1:] = 1
    cm_metric(pred_dict=preds, target_dict=golds)
    cm_metric.get_metric()
def test_ConfusionMatrixMetric7(self):
    """(7) check map, include unused: extra keys in pred_dict are tolerated.

    The metric reads ``pred`` from ``'prediction'`` and ``target`` from
    ``'targets'``; ``pred_dict`` also carries an ``'unused'`` entry that the
    call must not choke on.
    """
    cm_metric = ConfusionMatrixMetric(pred='prediction', target='targets')
    preds = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
    golds = {'targets': torch.zeros(4, 3)}
    cm_metric(pred_dict=preds, target_dict=golds)