def get_knn_f1_score(embeddings_and_labels, sampler):
    """Score 1-nearest-neighbour label prediction with per-label F1.

    Every query embedding is assigned the label of its closest reference
    embedding, and those predictions are scored against the query's own
    labels.

    ``tester``, ``stat_utils``, ``get_label_counts``,
    ``get_lone_query_labels`` and ``f1_score`` are not defined in this
    block; they are resolved from the enclosing scope.

    Args:
        embeddings_and_labels: Forwarded unchanged to
            ``tester.set_reference_and_query`` (with split ``'train'``) and
            to ``tester.embeddings_come_from_same_source``.
        sampler: Object whose ``labels`` attribute is forwarded as the
            ``labels=`` argument of ``f1_score``.

    Returns:
        The result of ``f1_score(..., average=None)`` for the first entry
        of ``tester.label_levels`` (presumably one score per entry of
        ``sampler.labels`` -- confirm against the ``f1_score`` in scope).
        ``None`` if ``tester.label_levels`` is empty.
    """
    query_embeddings, query_labels, reference_embeddings, reference_labels = (
        tester.set_reference_and_query(embeddings_and_labels, 'train'))

    for level in tester.label_levels:
        curr_query_labels = query_labels[:, level]
        curr_reference_labels = reference_labels[:, level]
        label_counts, num_k = get_label_counts(curr_reference_labels)
        embeddings_come_from_same_source = (
            tester.embeddings_come_from_same_source(embeddings_and_labels))
        knn_indices, knn_distances = stat_utils.get_knn(
            reference_embeddings, query_embeddings, num_k,
            embeddings_come_from_same_source)
        # Map neighbour indices (into the reference set) to their labels.
        knn_labels = curr_reference_labels[knn_indices]

        lone_query_labels = get_lone_query_labels(
            curr_query_labels, curr_reference_labels, label_counts,
            embeddings_come_from_same_source)
        not_lone_query_mask = ~np.isin(curr_query_labels, lone_query_labels)
        # The mask is only used to emit this diagnostic; it does not filter
        # the samples that are scored below.
        if not not_lone_query_mask.any():
            print(
                "Warning: None of the query labels are in the reference set and I barely know what that means."
            )

        # The predictions are the nearest-neighbour labels looked up via
        # knn_indices, which should hold one row per query, so the ground
        # truth has to be the query labels. Reference labels only happen to
        # line up when the query and reference sets are the same.
        nearest_neighbor_labels = knn_labels[:, :1].flatten()
        f1_scores = f1_score(curr_query_labels,
                             nearest_neighbor_labels,
                             labels=sampler.labels,
                             average=None)
        # NOTE(review): returning here means only the first label level is
        # ever scored; kept as-is so the return shape callers see is
        # unchanged.
        return f1_scores
Beispiel #2
0
 def test_get_label_counts(self):
     """Check the per-label tally and ``num_k`` from ``get_label_counts``.

     For this input the expected ``num_k`` of 5 equals the largest tally
     (label 4 occurs five times).
     """
     labels = [0, 1, 3, 2, 3, 1, 3, 3, 4, 6, 5, 10, 4, 4, 4, 4, 6, 6, 5]
     label_counts, num_k = accuracy_calculator.get_label_counts(labels)
     # assertEqual (rather than assertTrue(a == b)) reports the differing
     # values when the check fails.
     self.assertEqual(
         label_counts,
         {0: 1, 1: 2, 2: 1, 3: 4, 4: 5, 5: 2, 6: 3, 10: 1},
     )
     self.assertEqual(num_k, 5)
    def test_get_lone_query_labels(self):
        """Check ``get_lone_query_labels`` with the same-source flag on and off.

        Scenario 1 (flag ``True``): the expected lone labels 1, 3, 4, 5, 6
        are exactly the labels that occur once in the reference labels.

        Scenario 2 (flag ``False``): the expected lone label 3 is the only
        query label that is absent from the reference labels.
        """
        # Scenario 1: the flag is True.
        query_labels = np.array([0, 1, 2, 3, 4, 5, 6])
        reference_labels = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5, 6])
        reference_label_counts, _ = accuracy_calculator.get_label_counts(
            reference_labels)

        lone_query_labels = accuracy_calculator.get_lone_query_labels(
            query_labels, reference_labels, reference_label_counts, True)
        # Compare as lists: an element-wise `==` wrapped in np.all broadcasts
        # and would let a result of the wrong length slip through.
        self.assertEqual(
            np.unique(lone_query_labels).tolist(), [1, 3, 4, 5, 6])

        # Scenario 2: the flag is False.
        query_labels = np.array([0, 1, 2, 3, 4])
        reference_labels = np.array([0, 0, 0, 1, 2, 2, 4, 5, 6])
        # Recompute the counts so they describe this reference set rather
        # than the one used in scenario 1.
        reference_label_counts, _ = accuracy_calculator.get_label_counts(
            reference_labels)

        lone_query_labels = accuracy_calculator.get_lone_query_labels(
            query_labels, reference_labels, reference_label_counts, False)
        # A length-sensitive comparison matters most here: against a
        # one-element expected array, np.all(empty == [3]) is vacuously True.
        self.assertEqual(np.unique(lone_query_labels).tolist(), [3])