def get_knn_f1_score(embeddings_and_labels, sampler):
    """Return F1 scores computed from nearest-neighbour label look-ups.

    Relies on names defined outside this function: ``tester``,
    ``stat_utils``, ``get_label_counts``, ``get_lone_query_labels``,
    ``f1_score`` and ``np``.

    Args:
        embeddings_and_labels: Passed unchanged to
            ``tester.set_reference_and_query`` (together with the split name
            ``'train'``) and to ``tester.embeddings_come_from_same_source``.
        sampler: Only ``sampler.labels`` is read; it is forwarded as the
            ``labels`` argument of ``f1_score``.

    Returns:
        Whatever ``f1_score(..., average=None)`` returns for the last entry
        of ``tester.label_levels`` (presumably one score per label when
        ``f1_score`` is scikit-learn's -- TODO confirm).
    """
    query_embeddings, query_labels, reference_embeddings, reference_labels = tester.set_reference_and_query(
        embeddings_and_labels, 'train')
    for L in tester.label_levels:
        # Column L of each label matrix holds the labels for this level.
        curr_query_labels = query_labels[:, L]
        curr_reference_labels = reference_labels[:, L]
        label_counts, num_k = get_label_counts(curr_reference_labels)
        embeddings_come_from_same_source = tester.embeddings_come_from_same_source(
            embeddings_and_labels)
        # knn_distances is not used below; only the neighbour indices are.
        knn_indices, knn_distances = stat_utils.get_knn(
            reference_embeddings, query_embeddings, num_k,
            embeddings_come_from_same_source)
        # Map neighbour indices to the reference labels at those positions.
        knn_labels = curr_reference_labels[knn_indices]
        lone_query_labels = get_lone_query_labels(
            curr_query_labels, curr_reference_labels, label_counts,
            embeddings_come_from_same_source)
        # True where a query label is NOT one of the lone labels. The mask is
        # used only to decide whether to print the warning; it does not
        # filter the data scored below.
        not_lone_query_mask = ~np.isin(curr_query_labels, lone_query_labels)
        if not any(not_lone_query_mask):
            print(
                "Warning: None of the query labels are in the reference set and I barely know what that means."
            )
        # Only the first neighbour column is used as the prediction.
        # NOTE(review): the first argument is the *reference* labels, not the
        # query labels -- presumably fine only when query and reference are
        # the same set; confirm.
        f1_scores = f1_score(curr_reference_labels,
                             knn_labels[:, :1].flatten(),
                             labels=sampler.labels,
                             average=None)
    # NOTE(review): the original formatting was flattened, so it is not
    # certain whether this return belongs after the loop (as written here) or
    # inside it -- confirm. As written, an empty tester.label_levels leaves
    # f1_scores unbound.
    return f1_scores
def test_get_lone_query_labels(self):
    """Check ``get_lone_query_labels`` for both values of the same-source flag.

    Case 1 (flag ``True``): the expected lone labels are ``[1, 3, 4, 5, 6]``,
    i.e. the labels that occur exactly once in ``reference_labels``.

    Case 2 (flag ``False``): the expected lone label is ``[3]``, the only
    query label that is absent from ``reference_labels``.
    """
    query_labels = np.array([0, 1, 2, 3, 4, 5, 6])
    reference_labels = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5, 6])
    reference_label_counts, _ = accuracy_calculator.get_label_counts(
        reference_labels)
    lone_query_labels = accuracy_calculator.get_lone_query_labels(
        query_labels, reference_labels, reference_label_counts, True)
    # np.array_equal compares shape as well as values, so an empty or
    # differently sized result cannot pass through broadcasting.
    self.assertTrue(
        np.array_equal(np.unique(lone_query_labels),
                       np.array([1, 3, 4, 5, 6])))

    query_labels = np.array([0, 1, 2, 3, 4])
    reference_labels = np.array([0, 0, 0, 1, 2, 2, 4, 5, 6])
    # Recompute the counts so they describe the reference set that is
    # actually passed below instead of the one from the first case.
    reference_label_counts, _ = accuracy_calculator.get_label_counts(
        reference_labels)
    lone_query_labels = accuracy_calculator.get_lone_query_labels(
        query_labels, reference_labels, reference_label_counts, False)
    # A one-element expectation is exactly where `np.all(a == b)` would be
    # vacuously true for an empty `a`.
    self.assertTrue(
        np.array_equal(np.unique(lone_query_labels), np.array([3])))
def test_get_lone_query_labels_multi_dim(self):
    """Check ``get_lone_query_labels`` on two-column labels.

    The query set doubles as the reference set. Two comparison functions are
    exercised: ``accuracy_calculator.EQUALITY`` and a custom one that matches
    when the first column is equal and the second column differs. For each,
    both values of the same-source flag are checked against the expected
    lone labels and the expected not-lone mask.
    """

    def custom_label_comparison_fn(x, y):
        # Match: same first column, different second column.
        return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

    query_labels = np.array([
        (1, 3),
        (0, 3),
        (0, 3),
        (0, 3),
        (0, 2),
        (1, 2),
        (4, 5),
    ])

    for comparison_fn in [
            accuracy_calculator.EQUALITY, custom_label_comparison_fn
    ]:
        label_counts, _ = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )

        # Each case is (same_source flag, expected lone labels, expected mask).
        if comparison_fn is accuracy_calculator.EQUALITY:
            cases = [
                (
                    True,
                    np.array([[0, 2], [1, 2], [1, 3], [4, 5]]),
                    np.array(
                        [False, True, True, True, False, False, False]),
                ),
                (
                    False,
                    # Empty placeholder: only its size is inspected below.
                    np.array([[]]),
                    np.array([True, True, True, True, True, True, True]),
                ),
            ]
        else:
            only_lone = np.array([[4, 5]])
            only_lone_mask = np.array(
                [True, True, True, True, True, True, False])
            cases = [
                (True, only_lone, only_lone_mask),
                (False, only_lone, only_lone_mask),
            ]

        # Distinct loop names keep the case tuple from rebinding the
        # variables used to build `cases`.
        for same_source, expected_lone, expected_mask in cases:
            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels, label_counts, same_source, comparison_fn)
            if expected_lone.size == 0:
                self.assertTrue(lone_query_labels.size == 0)
            else:
                # array_equal also checks shape, so a (0, 2) result cannot
                # broadcast against a (1, 2) expectation and pass vacuously.
                self.assertTrue(
                    np.array_equal(lone_query_labels, expected_lone))
            self.assertTrue(
                np.array_equal(not_lone_query_mask, expected_mask))
def test_get_lone_query_labels(self):
    """Check ``get_lone_query_labels`` with the ``EQUALITY`` comparison.

    With the same-source flag set, the expected lone labels are
    ``[1, 3, 4, 5, 6]`` (every query label that occurs at most once in
    ``reference_labels``). Without it, only ``[6]`` is expected, the one
    query label that never occurs in ``reference_labels``.
    """
    query_labels = np.array([0, 1, 2, 3, 4, 5, 6])
    reference_labels = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5])
    label_counts, _ = accuracy_calculator.get_label_match_counts(
        query_labels,
        reference_labels,
        accuracy_calculator.EQUALITY,
    )

    for same_source, correct in [
        (True, np.array([1, 3, 4, 5, 6])),
        (False, np.array([6])),
    ]:
        lone_query_labels, _ = accuracy_calculator.get_lone_query_labels(
            query_labels,
            label_counts,
            same_source,
            accuracy_calculator.EQUALITY,
        )
        # np.array_equal compares shape and values. `np.all(a == b)` would
        # broadcast an empty result against the one-element expectation
        # `[6]` and report success without comparing anything.
        self.assertTrue(np.array_equal(lone_query_labels, correct))
def test_get_lone_query_labels(self):
    """Check ``get_lone_query_labels`` on tensors with ``EQUALITY``.

    With the same-source flag set, the expected lone labels are
    ``[1, 3, 4, 5, 6]`` (every query label that occurs at most once in
    ``reference_labels``). Without it, only ``[6]`` is expected, the one
    query label that never occurs in ``reference_labels``.
    """
    query_labels = torch.tensor([0, 1, 2, 3, 4, 5, 6], device=TEST_DEVICE)
    reference_labels = torch.tensor([0, 0, 0, 1, 2, 2, 3, 4, 5],
                                    device=TEST_DEVICE)
    label_counts = accuracy_calculator.get_label_match_counts(
        query_labels,
        reference_labels,
        accuracy_calculator.EQUALITY,
    )

    for same_source, correct in [
        (True, torch.tensor([1, 3, 4, 5, 6], device=TEST_DEVICE)),
        (False, torch.tensor([6], device=TEST_DEVICE)),
    ]:
        lone_query_labels, _ = accuracy_calculator.get_lone_query_labels(
            query_labels,
            label_counts,
            same_source,
            accuracy_calculator.EQUALITY,
        )
        # Compare shapes first: `==` broadcasts, so an empty result against
        # the one-element expectation would give an empty comparison, and
        # torch.all of an empty tensor is True.
        self.assertEqual(lone_query_labels.shape, correct.shape)
        self.assertTrue(torch.all(lone_query_labels == correct))
def test_get_lone_query_labels_multi_dim(self):
    """Check match counts and lone labels for two-column tensor labels.

    The query set doubles as the reference set. Two comparison functions are
    exercised: exact equality on both columns, and a custom one that matches
    when the first column is equal and the second differs. For each, the
    unique labels and match counts from ``get_label_match_counts`` are
    checked, followed by the lone labels and not-lone mask from
    ``get_lone_query_labels`` for both values of the same-source flag.
    """

    def equality2D(x, y):
        return (x[..., 0] == y[..., 0]) & (x[..., 1] == y[..., 1])

    def custom_label_comparison_fn(x, y):
        # Match: same first column, different second column.
        return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

    def assert_same_tensor(actual, expected):
        # `==` broadcasts, so check shapes explicitly; otherwise an empty or
        # singleton result could pass without a real element-wise comparison.
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(torch.all(actual == expected))

    query_labels = torch.tensor(
        [
            (1, 3),
            (0, 3),
            (0, 3),
            (0, 3),
            (1, 2),
            (4, 5),
        ],
        device=TEST_DEVICE,
    )

    for comparison_fn in [equality2D, custom_label_comparison_fn]:
        label_counts = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )
        unique_labels, counts = label_counts
        correct_unique_labels = torch.tensor(
            [[0, 3], [1, 2], [1, 3], [4, 5]], device=TEST_DEVICE)
        if comparison_fn is equality2D:
            correct_counts = torch.tensor([3, 1, 1, 1], device=TEST_DEVICE)
        else:
            correct_counts = torch.tensor([0, 1, 1, 0], device=TEST_DEVICE)
        assert_same_tensor(counts, correct_counts)
        assert_same_tensor(unique_labels, correct_unique_labels)

        # Each case is (same_source flag, expected lone labels, expected mask).
        if comparison_fn is equality2D:
            correct = [
                (
                    True,
                    torch.tensor([[1, 2], [1, 3], [4, 5]],
                                 device=TEST_DEVICE),
                    torch.tensor([False, True, True, True, False, False],
                                 device=TEST_DEVICE),
                ),
                (
                    False,
                    # Empty placeholder: only numel() is inspected below.
                    torch.tensor([[]], device=TEST_DEVICE),
                    torch.tensor([True, True, True, True, True, True],
                                 device=TEST_DEVICE),
                ),
            ]
        else:
            correct = [
                (
                    True,
                    torch.tensor([[0, 3], [4, 5]], device=TEST_DEVICE),
                    torch.tensor([True, False, False, False, True, False],
                                 device=TEST_DEVICE),
                ),
                (
                    False,
                    torch.tensor([[0, 3], [4, 5]], device=TEST_DEVICE),
                    torch.tensor([True, False, False, False, True, False],
                                 device=TEST_DEVICE),
                ),
            ]

        for same_source, correct_lone, correct_mask in correct:
            (
                lone_query_labels,
                not_lone_query_mask,
            ) = accuracy_calculator.get_lone_query_labels(
                query_labels, label_counts, same_source, comparison_fn)
            if correct_lone.numel() == 0:
                self.assertTrue(lone_query_labels.numel() == 0)
            else:
                assert_same_tensor(lone_query_labels, correct_lone)
            assert_same_tensor(not_lone_query_mask, correct_mask)
def test_get_lone_query_labels_custom(self):
    """Check match counts and lone labels for custom scalar comparisons.

    The query set doubles as the reference set and the same-source flag is
    ``True``. ``fn1`` matches labels less than 2 apart; ``fn2`` matches labels
    more than 99 apart. For each, the unique labels and match counts from
    ``get_label_match_counts`` are checked, followed by the lone labels and
    not-lone mask from ``get_lone_query_labels``.
    """

    def fn1(x, y):
        return abs(x - y) < 2

    def fn2(x, y):
        return abs(x - y) > 99

    def assert_same_tensor(actual, expected):
        # `==` broadcasts, so check shapes explicitly; otherwise an empty
        # result compared with a one-element expectation (e.g. `[100]`)
        # would pass vacuously.
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(torch.all(actual == expected))

    query_labels = torch.tensor([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100],
                                device=TEST_DEVICE)

    for comparison_fn in [fn1, fn2]:
        correct_unique_labels = torch.tensor(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100], device=TEST_DEVICE)

        if comparison_fn is fn1:
            # Only 100 has no match other than itself.
            correct_counts = torch.tensor(
                [3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1], device=TEST_DEVICE)
            correct_lone_query_labels = torch.tensor([100],
                                                     device=TEST_DEVICE)
            correct_not_lone_query_mask = torch.tensor(
                [
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    False,
                ],
                device=TEST_DEVICE,
            )
        else:
            # fn2: only 0 and 100 are more than 99 apart, so 1..9 are lone.
            correct_counts = torch.tensor(
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], device=TEST_DEVICE)
            correct_lone_query_labels = torch.tensor(
                [1, 2, 3, 4, 5, 6, 7, 8, 9], device=TEST_DEVICE)
            correct_not_lone_query_mask = torch.tensor(
                [
                    True,
                    True,
                    False,
                    False,
                    False,
                    False,
                    False,
                    False,
                    False,
                    False,
                    False,
                    True,
                ],
                device=TEST_DEVICE,
            )

        label_counts = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )
        unique_labels, counts = label_counts
        assert_same_tensor(unique_labels, correct_unique_labels)
        assert_same_tensor(counts, correct_counts)

        (
            lone_query_labels,
            not_lone_query_mask,
        ) = accuracy_calculator.get_lone_query_labels(
            query_labels, label_counts, True, comparison_fn)
        assert_same_tensor(lone_query_labels, correct_lone_query_labels)
        assert_same_tensor(not_lone_query_mask, correct_not_lone_query_mask)
def test_get_lone_query_labels_custom(self):
    """Check match counts and lone labels for custom scalar comparisons.

    The query set doubles as the reference set and the same-source flag is
    ``True``. ``fn1`` matches labels less than 2 apart; ``fn2`` matches labels
    more than 99 apart. For each, the unique labels and match counts from
    ``get_label_match_counts`` are checked, followed by the lone labels and
    not-lone mask from ``get_lone_query_labels``.
    """

    def fn1(x, y):
        return abs(x - y) < 2

    def fn2(x, y):
        return abs(x - y) > 99

    query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

    for comparison_fn in [fn1, fn2]:
        correct_unique_labels = np.array(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

        if comparison_fn is fn1:
            # Only 100 has no match other than itself.
            correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1])
            correct_lone_query_labels = np.array([100])
            correct_not_lone_query_mask = np.array([
                True,
                True,
                True,
                True,
                True,
                True,
                True,
                True,
                True,
                True,
                True,
                False,
            ])
        else:
            # fn2: only 0 and 100 are more than 99 apart, so 1..9 are lone.
            correct_counts = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2])
            correct_lone_query_labels = np.array(
                [1, 2, 3, 4, 5, 6, 7, 8, 9])
            correct_not_lone_query_mask = np.array([
                True,
                True,
                False,
                False,
                False,
                False,
                False,
                False,
                False,
                False,
                False,
                True,
            ])

        label_counts, _ = accuracy_calculator.get_label_match_counts(
            query_labels,
            query_labels,
            comparison_fn,
        )
        unique_labels, counts = label_counts
        # np.array_equal compares shape and values; `np.all(a == b)` can
        # pass vacuously when an empty `a` broadcasts against a one-element
        # expectation such as `[100]`.
        self.assertTrue(np.array_equal(unique_labels, correct_unique_labels))
        self.assertTrue(np.array_equal(counts, correct_counts))

        (
            lone_query_labels,
            not_lone_query_mask,
        ) = accuracy_calculator.get_lone_query_labels(
            query_labels, label_counts, True, comparison_fn)
        self.assertTrue(
            np.array_equal(lone_query_labels, correct_lone_query_labels))
        self.assertTrue(
            np.array_equal(not_lone_query_mask,
                           correct_not_lone_query_mask))