Example #1
    def test_confusion_matrix(self):
        # Load two annotation files from the dataset and copy the first entity of the
        # second file into the first so that the two annotation sets overlap.
        annotations1 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]), annotation_type='ann')
        annotations2 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[1]), annotation_type='ann')
        annotations1.add_entity(*annotations2.get_entity_annotations()[0])

        # The returned confusion matrix should be square, with one row and one column per entity label.
        self.assertEqual(len(annotations1.compute_confusion_matrix(annotations2, self.entities)[0]), len(self.entities))
        self.assertEqual(len(annotations1.compute_confusion_matrix(annotations2, self.entities)), len(self.entities))
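
The test above only checks the shape of the matrix. The following is a minimal usage sketch (not part of the original example) showing how the per-label counts returned by Annotations.compute_confusion_matrix could be inspected; annotations1, annotations2, and an entities list are assumed to be built as in the test, and the row/column convention follows the docstring in Example #2 below.

# Hypothetical usage sketch (not from the source): inspect per-label counts in the
# matrix returned by Annotations.compute_confusion_matrix. Assumes annotations1,
# annotations2, and an entities list built as in the test above.
matrix = annotations1.compute_confusion_matrix(annotations2, entities)
for i, gold_label in enumerate(entities):
    for j, predicted_label in enumerate(entities):
        if matrix[i][j]:
            print(f"{gold_label} annotated as {predicted_label}: {matrix[i][j]} time(s)")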
Example #2
    def compute_confusion_matrix(self, dataset, leniency=0):
        """
        Generates a confusion matrix in which this Dataset serves as the gold-standard annotations and `dataset`
        serves as the predicted annotations. A typical workflow is to create a Dataset object over the prediction
        directory produced by a model and then pass it to this method.

        :param dataset: a Dataset object containing a predicted version of this dataset.
        :param leniency: a float in [0, 1] controlling how leniently character spans are matched. A value of zero
            counts only exact character-span matches, while a positive value also counts entities whose span
            boundaries differ by up to :code:`ceil(leniency * len(span)/2)` characters on either side (for example,
            with leniency=0.2 and a 10-character span, boundaries may shift by up to ceil(0.2 * 10 / 2) = 1 character).
        :return: a two-element tuple containing a list of entity labels and a matrix in which rows are gold labels
            and columns are predicted labels; matrix[i][j] is the number of times entities[i] in this dataset was
            predicted as entities[j] in `dataset`.
        """
        if not isinstance(dataset, Dataset):
            raise ValueError("dataset must be instance of Dataset")

        # Verify that every annotation file in this dataset has a counterpart in the prediction dataset.
        gold_files = {file.ann_path.split(os.sep)[-1] for file in self}
        predicted_files = {file.ann_path.split(os.sep)[-1] for file in dataset}
        diff = gold_files - predicted_files
        if diff:
            raise ValueError("Dataset of predictions is missing the files: " +
                             str(list(diff)))

        # Sort entity labels in ascending order of their frequency in this dataset.
        entities = [
            key for key, _ in sorted(self.compute_counts()['entities'].items(),
                                     key=lambda x: x[1])
        ]
        # Square matrix of zeros; rows are gold labels, columns are predicted labels.
        confusion_matrix = [[0] * len(entities) for _ in range(len(entities))]

        for gold_data_file in self:
            # Locate the prediction file that corresponds to this gold file.
            prediction_iter = iter(dataset)
            prediction_data_file = next(prediction_iter)
            while str(gold_data_file) != str(prediction_data_file):
                prediction_data_file = next(prediction_iter)

            gold_annotation = Annotations(gold_data_file.ann_path)
            pred_annotation = Annotations(prediction_data_file.ann_path)

            # Compute the confusion matrix at the annotation-file level and add it into the running total.
            ann_confusion_matrix = gold_annotation.compute_confusion_matrix(
                pred_annotation, entities, leniency=leniency)
            for i in range(len(confusion_matrix)):
                for j in range(len(confusion_matrix)):
                    confusion_matrix[i][j] += ann_confusion_matrix[i][j]

        return entities, confusion_matrix
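
A minimal usage sketch for the Dataset-level method, assuming (not from the source) two hypothetical Dataset instances: gold_dataset built over the gold-standard annotation directory and predicted_dataset built over a model's prediction directory.

# Hypothetical usage sketch (gold_dataset and predicted_dataset are illustrative names):
# gold_dataset wraps the gold-standard annotations, predicted_dataset wraps the
# directory of predictions written out by a model.
entities, matrix = gold_dataset.compute_confusion_matrix(predicted_dataset, leniency=0.1)
for i, gold_label in enumerate(entities):
    row = ", ".join(f"{matrix[i][j]} x {entities[j]}" for j in range(len(entities)))
    print(f"gold {gold_label}: {row}")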