def test_get_lone_query_labels_multi_dim(self):
    """Check lone-label detection when every label is a 2-column row.

    The query labels are compared against themselves, first with
    ``accuracy_calculator.EQUALITY`` and then with a custom comparison under
    which two labels match when their first columns are equal and their
    second columns differ. For each comparison function,
    ``get_lone_query_labels`` is called with ``same_source`` set to True and
    to False, and both returned values (the lone labels and the
    not-lone mask) are checked.
    """

    def custom_label_comparison_fn(x, y):
        return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

    def assert_same(actual, expected):
        # ``actual == expected`` broadcasts, so ``np.all`` alone would also
        # accept an empty result (all of nothing is True) or a result with
        # repeated rows when the expectation has a single row. Pin the
        # shape before comparing elementwise.
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(np.all(actual == expected))

    query_labels = np.array([
        (1, 3),
        (0, 3),
        (0, 3),
        (0, 3),
        (0, 2),
        (1, 2),
        (4, 5),
    ])

    for comparison_fn in [
        accuracy_calculator.EQUALITY,
        custom_label_comparison_fn,
    ]:
        # The second returned value is not needed by this test.
        label_counts, _ = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )

        if comparison_fn is accuracy_calculator.EQUALITY:
            expectations = [
                (
                    True,
                    np.array([[0, 2], [1, 2], [1, 3], [4, 5]]),
                    np.array(
                        [False, True, True, True, False, False, False]),
                ),
                (
                    False,
                    # Empty placeholder: no lone labels are expected here.
                    np.array([[]]),
                    np.array([True, True, True, True, True, True, True]),
                ),
            ]
        else:
            # (4, 5) is the only row sharing its first column with no
            # other row, so the expectation is the same for both flags.
            only_lone = np.array([[4, 5]])
            only_lone_mask = np.array(
                [True, True, True, True, True, True, False])
            expectations = [
                (True, only_lone, only_lone_mask),
                (False, only_lone, only_lone_mask),
            ]

        for same_source, expected_lone, expected_mask in expectations:
            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels, label_counts, same_source, comparison_fn)

            if expected_lone.size == 0:
                # The placeholder's shape is not meaningful; only require
                # that nothing was reported as lone.
                self.assertEqual(lone_query_labels.size, 0)
            else:
                assert_same(lone_query_labels, expected_lone)
            assert_same(not_lone_query_mask, expected_mask)
def test_get_lone_query_labels(self):
    """Check the lone labels reported for 1-D integer labels.

    Match counts are computed for the query labels against a separate
    reference set using ``accuracy_calculator.EQUALITY``; the lone labels
    returned by ``get_lone_query_labels`` are then checked with
    ``same_source`` set to True and to False.
    """
    query_labels = np.array([0, 1, 2, 3, 4, 5, 6])
    reference_labels = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5])

    # The second returned value is not needed by this test.
    label_counts, _ = accuracy_calculator.get_label_match_counts(
        query_labels,
        reference_labels,
        accuracy_calculator.EQUALITY,
    )

    expectations = [
        (True, np.array([1, 3, 4, 5, 6])),
        (False, np.array([6])),
    ]
    for same_source, expected_lone in expectations:
        lone_query_labels, _ = accuracy_calculator.get_lone_query_labels(
            query_labels,
            label_counts,
            same_source,
            accuracy_calculator.EQUALITY,
        )
        # ``==`` broadcasts, so against the length-1 expectation ``[6]`` an
        # empty result (all of nothing is True) or ``[6, 6]`` would pass
        # ``np.all`` on its own. Pin the shape first.
        self.assertEqual(lone_query_labels.shape, expected_lone.shape)
        self.assertTrue(np.all(lone_query_labels == expected_lone))
def test_get_lone_query_labels(self):
    """Check the lone labels reported for 1-D integer label tensors.

    Match counts are computed for the query labels against a separate
    reference set using ``accuracy_calculator.EQUALITY``; the lone labels
    returned by ``get_lone_query_labels`` are then checked with
    ``same_source`` set to True and to False. All tensors live on
    ``TEST_DEVICE``.
    """
    query_labels = torch.tensor([0, 1, 2, 3, 4, 5, 6], device=TEST_DEVICE)
    reference_labels = torch.tensor(
        [0, 0, 0, 1, 2, 2, 3, 4, 5], device=TEST_DEVICE)

    label_counts = accuracy_calculator.get_label_match_counts(
        query_labels,
        reference_labels,
        accuracy_calculator.EQUALITY,
    )

    expectations = [
        (True, torch.tensor([1, 3, 4, 5, 6], device=TEST_DEVICE)),
        (False, torch.tensor([6], device=TEST_DEVICE)),
    ]
    for same_source, expected_lone in expectations:
        lone_query_labels, _ = accuracy_calculator.get_lone_query_labels(
            query_labels,
            label_counts,
            same_source,
            accuracy_calculator.EQUALITY,
        )
        # ``==`` broadcasts, so against the length-1 expectation ``[6]`` an
        # empty result (all of nothing is True) or ``[6, 6]`` would pass
        # ``torch.all`` on its own. Pin the shape first.
        self.assertEqual(lone_query_labels.shape, expected_lone.shape)
        self.assertTrue(torch.all(lone_query_labels == expected_lone))
def test_get_lone_query_labels_multi_dim(self):
    """Exercise match counting and lone-label detection on 2-column labels.

    Two comparison functions are applied to the query labels compared
    against themselves: one that requires both columns to be equal, and one
    that requires equal first columns and differing second columns. For
    each, the unique labels and their match counts are checked, followed by
    the output of ``get_lone_query_labels`` with ``same_source`` True and
    False.
    """

    def both_columns_equal(a, b):
        return (a[..., 0] == b[..., 0]) & (a[..., 1] == b[..., 1])

    def first_equal_second_differs(a, b):
        return (a[..., 0] == b[..., 0]) & (a[..., 1] != b[..., 1])

    def on_device(values):
        # Every tensor in this test is created on the shared test device.
        return torch.tensor(values, device=TEST_DEVICE)

    query_labels = on_device([
        (1, 3),
        (0, 3),
        (0, 3),
        (0, 3),
        (1, 2),
        (4, 5),
    ])
    expected_unique_rows = [[0, 3], [1, 2], [1, 3], [4, 5]]

    # The second comparison yields the same expectation for both flags.
    differs_lone_rows = [[0, 3], [4, 5]]
    differs_mask = [True, False, False, False, True, False]

    # (comparison fn, expected counts, [(same_source, lone rows, mask)])
    scenarios = [
        (
            both_columns_equal,
            [3, 1, 1, 1],
            [
                (
                    True,
                    [[1, 2], [1, 3], [4, 5]],
                    [False, True, True, True, False, False],
                ),
                # ``[[]]`` is an empty placeholder: nothing is lone here.
                (False, [[]], [True, True, True, True, True, True]),
            ],
        ),
        (
            first_equal_second_differs,
            [0, 1, 1, 0],
            [
                (True, differs_lone_rows, differs_mask),
                (False, differs_lone_rows, differs_mask),
            ],
        ),
    ]

    for comparison_fn, count_values, lone_cases in scenarios:
        label_counts = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )
        unique_labels, counts = label_counts

        self.assertTrue(torch.all(on_device(count_values) == counts))
        self.assertTrue(
            torch.all(on_device(expected_unique_rows) == unique_labels))

        for same_source, lone_rows, mask_values in lone_cases:
            expected_lone = on_device(lone_rows)
            expected_mask = on_device(mask_values)

            result = accuracy_calculator.get_lone_query_labels(
                query_labels, label_counts, same_source, comparison_fn)
            lone_query_labels, not_lone_query_mask = result

            if expected_lone.numel() != 0:
                self.assertTrue(
                    torch.all(lone_query_labels == expected_lone))
            else:
                self.assertTrue(lone_query_labels.numel() == 0)
            self.assertTrue(
                torch.all(not_lone_query_mask == expected_mask))
def test_get_lone_query_labels_custom(self):
    """Check match counts and lone labels under custom comparison functions.

    ``fn1`` treats two labels as matching when they differ by less than 2;
    ``fn2`` treats them as matching when they differ by more than 99. The
    query labels are compared against themselves, and
    ``get_lone_query_labels`` is called with its third argument set to True.
    All tensors live on ``TEST_DEVICE``.
    """

    def fn1(x, y):
        return abs(x - y) < 2

    def fn2(x, y):
        return abs(x - y) > 99

    def assert_same(actual, expected):
        # ``actual == expected`` broadcasts, so ``torch.all`` alone would
        # accept an empty result (all of nothing is True) or a repeated
        # value when the expectation has length 1, e.g. ``[100]``. Pin the
        # shape before comparing elementwise.
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(torch.all(actual == expected))

    query_labels = torch.tensor(
        [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100], device=TEST_DEVICE)

    for comparison_fn in [fn1, fn2]:
        correct_unique_labels = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100], device=TEST_DEVICE)

        if comparison_fn is fn1:
            correct_counts = torch.tensor(
                [3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1], device=TEST_DEVICE)
            correct_lone_query_labels = torch.tensor(
                [100], device=TEST_DEVICE)
            # Only the final query label (100) is expected to be lone.
            correct_not_lone_query_mask = torch.tensor(
                [True] * 11 + [False], device=TEST_DEVICE)
        elif comparison_fn is fn2:
            # 0 and 100 are the only labels more than 99 apart.
            correct_counts = torch.tensor(
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], device=TEST_DEVICE)
            correct_lone_query_labels = torch.tensor(
                [1, 2, 3, 4, 5, 6, 7, 8, 9], device=TEST_DEVICE)
            correct_not_lone_query_mask = torch.tensor(
                [True, True] + [False] * 9 + [True], device=TEST_DEVICE)

        label_counts = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )
        unique_labels, counts = label_counts
        assert_same(unique_labels, correct_unique_labels)
        assert_same(counts, correct_counts)

        (
            lone_query_labels,
            not_lone_query_mask,
        ) = accuracy_calculator.get_lone_query_labels(
            query_labels, label_counts, True, comparison_fn)

        assert_same(lone_query_labels, correct_lone_query_labels)
        assert_same(not_lone_query_mask, correct_not_lone_query_mask)
def test_get_lone_query_labels_custom(self):
    """Check match counts and lone labels under custom comparison functions.

    ``fn1`` treats two labels as matching when they differ by less than 2;
    ``fn2`` treats them as matching when they differ by more than 99. The
    query labels are compared against themselves, and
    ``get_lone_query_labels`` is called with its third argument set to True.
    """

    def fn1(x, y):
        return abs(x - y) < 2

    def fn2(x, y):
        return abs(x - y) > 99

    def assert_same(actual, expected):
        # ``actual == expected`` broadcasts, so ``np.all`` alone would
        # accept an empty result (all of nothing is True) or a repeated
        # value when the expectation has length 1, e.g. ``[100]``. Pin the
        # shape before comparing elementwise.
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(np.all(actual == expected))

    query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

    for comparison_fn in [fn1, fn2]:
        correct_unique_labels = np.array(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

        if comparison_fn is fn1:
            correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1])
            correct_lone_query_labels = np.array([100])
            # Only the final query label (100) is expected to be lone.
            correct_not_lone_query_mask = np.array([True] * 11 + [False])
        elif comparison_fn is fn2:
            # 0 and 100 are the only labels more than 99 apart.
            correct_counts = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2])
            correct_lone_query_labels = np.array(
                [1, 2, 3, 4, 5, 6, 7, 8, 9])
            correct_not_lone_query_mask = np.array(
                [True, True] + [False] * 9 + [True])

        # The second returned value is not needed by this test.
        label_counts, _ = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )
        unique_labels, counts = label_counts
        assert_same(unique_labels, correct_unique_labels)
        assert_same(counts, correct_counts)

        (
            lone_query_labels,
            not_lone_query_mask,
        ) = accuracy_calculator.get_lone_query_labels(
            query_labels, label_counts, True, comparison_fn)

        assert_same(lone_query_labels, correct_lone_query_labels)
        assert_same(not_lone_query_mask, correct_not_lone_query_mask)