def get_knn_f1_score(embeddings_and_labels, sampler):
    """Compute per-class F1 scores of nearest-neighbour label predictions.

    Query and reference sets come from the ``'train'`` split via
    ``tester.set_reference_and_query``. For each label level in
    ``tester.label_levels``, every query is predicted to have the label of
    the first neighbour returned by ``stat_utils.get_knn``. The predictions
    are scored against the query's own label at that level. Queries whose
    label is "lone" (as reported by ``get_lone_query_labels``) are excluded,
    unless that would leave no queries at all.

    Args:
        embeddings_and_labels: Container passed through to
            ``tester.set_reference_and_query`` and
            ``tester.embeddings_come_from_same_source``.
        sampler: Object whose ``labels`` attribute lists the classes to
            score; it is forwarded as ``labels=`` to ``f1_score``.

    Returns:
        The ``f1_score(..., average=None)`` result (one score per entry of
        ``sampler.labels``) for the last label level processed.
    """
    (query_embeddings, query_labels,
     reference_embeddings, reference_labels) = tester.set_reference_and_query(
        embeddings_and_labels, 'train')
    # Independent of the label level, so it is computed once, not per level.
    embeddings_come_from_same_source = tester.embeddings_come_from_same_source(
        embeddings_and_labels)
    for L in tester.label_levels:
        curr_query_labels = query_labels[:, L]
        curr_reference_labels = reference_labels[:, L]
        label_counts, num_k = get_label_counts(curr_reference_labels)
        knn_indices, _ = stat_utils.get_knn(
            reference_embeddings, query_embeddings, num_k,
            embeddings_come_from_same_source)
        # One row per query. Column 0 is presumably the closest neighbour
        # (assumes get_knn orders results by distance; confirm).
        knn_labels = curr_reference_labels[knn_indices]
        nearest_labels = knn_labels[:, :1].flatten()

        lone_query_labels = get_lone_query_labels(
            curr_query_labels, curr_reference_labels, label_counts,
            embeddings_come_from_same_source)
        not_lone_query_mask = ~np.isin(curr_query_labels, lone_query_labels)
        if not np.any(not_lone_query_mask):
            print(
                "Warning: None of the query labels are in the reference set and I barely know what that means."
            )
            # Filtering would leave nothing to score, so score every query.
            y_true = curr_query_labels
            y_pred = nearest_labels
        else:
            y_true = curr_query_labels[not_lone_query_mask]
            y_pred = nearest_labels[not_lone_query_mask]
        # y_true must be the *query* labels: predictions are made per query.
        f1_scores = f1_score(y_true, y_pred,
                             labels=sampler.labels, average=None)
    return f1_scores
def test_get_label_counts(self):
    """Check the per-label tallies and num_k returned by get_label_counts."""
    label_counts, num_k = accuracy_calculator.get_label_counts(
        [0, 1, 3, 2, 3, 1, 3, 3, 4, 6, 5, 10, 4, 4, 4, 4, 6, 6, 5]
    )
    # assertEqual shows both values on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(
        label_counts, {0: 1, 1: 2, 2: 1, 3: 4, 4: 5, 5: 2, 6: 3, 10: 1}
    )
    # Expected num_k equals the highest count in the input (label 4, 5 times).
    self.assertEqual(num_k, 5)
def test_get_lone_query_labels(self):
    """Check get_lone_query_labels with the same-source flag True and False."""
    # Case 1: embeddings_come_from_same_source=True.
    query_labels = np.array([0, 1, 2, 3, 4, 5, 6])
    reference_labels = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5, 6])
    reference_label_counts, _ = accuracy_calculator.get_label_counts(
        reference_labels)
    lone_query_labels = accuracy_calculator.get_lone_query_labels(
        query_labels, reference_labels, reference_label_counts, True)
    # The expected lone labels (1, 3, 4, 5, 6) are those appearing once in
    # reference_labels. assert_array_equal also checks shape, so an empty or
    # wrong-length result cannot pass through broadcasting as it could with
    # np.all(a == b).
    np.testing.assert_array_equal(
        np.unique(lone_query_labels), np.array([1, 3, 4, 5, 6]))

    # Case 2: embeddings_come_from_same_source=False.
    query_labels = np.array([0, 1, 2, 3, 4])
    reference_labels = np.array([0, 0, 0, 1, 2, 2, 4, 5, 6])
    # Recompute counts so they describe this reference set, not case 1's.
    reference_label_counts, _ = accuracy_calculator.get_label_counts(
        reference_labels)
    lone_query_labels = accuracy_calculator.get_lone_query_labels(
        query_labels, reference_labels, reference_label_counts, False)
    # 3 is the only query label absent from reference_labels.
    np.testing.assert_array_equal(np.unique(lone_query_labels), np.array([3]))