def test_init_from_ann_file(self):
    """
    Tests initialization from valid ann file
    :return:
    """
    annotations = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                              annotation_type='ann')
    self.assertIsNotNone(annotations.get_entity_annotations())
def compute_counts(self):
    """
    Computes entity and relation counts over all documents in this dataset.

    :return: a dictionary of entity and relation counts.
    """
    dataset_counts = {'entities': {}, 'relations': {}}

    for data_file in self:
        annotation = Annotations(data_file.ann_path)
        annotation_counts = annotation.compute_counts()

        dataset_counts['entities'] = {
            x: dataset_counts['entities'].get(x, 0) + annotation_counts['entities'].get(x, 0)
            for x in set(dataset_counts['entities']).union(annotation_counts['entities'])
        }

        dataset_counts['relations'] = {
            x: dataset_counts['relations'].get(x, 0) + annotation_counts['relations'].get(x, 0)
            for x in set(dataset_counts['relations']).union(annotation_counts['relations'])
        }

    return dataset_counts
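# Hedged usage sketch (added for illustration, not part of the original source): shows how the
# dictionary returned by compute_counts() could be inspected. Only the documented return shape
# ({'entities': {label: count}, 'relations': {label: count}}) is assumed; the sorting and
# printing below are illustrative and not an API of this codebase.
def print_label_counts(dataset):
    counts = dataset.compute_counts()
    for section in ('entities', 'relations'):
        print(section)
        # most frequent labels first
        for label, count in sorted(counts[section].items(), key=lambda kv: kv[1], reverse=True):
            print("  %s\t%d" % (label, count))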
def test_get_entity_annotations_list(self):
    """
    Tests the validity of the annotation list
    :return:
    """
    annotations = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                              annotation_type='ann')
    self.assertIsInstance(annotations.get_entity_annotations(), list)
def test_get_entity_annotations_dict(self):
    """
    Tests the validity of the annotation dict
    :return:
    """
    annotations = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                              annotation_type='ann')
    self.assertIsInstance(annotations.get_entity_annotations(return_dictionary=True), dict)
def test_difference(self):
    """
    Tests that when a given Annotations object uses the difference() method with another
    Annotations object created from the same source file, the result is empty.
    """
    annotations1 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    annotations2 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    result = annotations1.difference(annotations2)
    self.assertFalse(result)
def compute_ambiguity(self, dataset):
    """
    Finds occurrences of spans from 'dataset' that intersect with a span from this annotation
    but do not have this span's label. If 'dataset' comprises a model's predictions, this method
    provides a strong indicator of the model's inability to disambiguate between entities.
    For a full analysis, compute a confusion matrix.

    :param dataset: a Dataset object containing a predicted version of this dataset.
    :return: a dictionary containing the ambiguity computations on each gold, predicted file pair.
    """
    if not isinstance(dataset, Dataset):
        raise ValueError("dataset must be instance of Dataset")

    # verify files are consistent
    diff = set(file.ann_path.split(os.sep)[-1] for file in self).difference(
        set(file.ann_path.split(os.sep)[-1] for file in dataset))
    if diff:
        raise ValueError("Dataset of predictions is missing the files: " + str(list(diff)))

    # dictionary storing ambiguity over the dataset
    ambiguity_dict = {}

    for gold_data_file in self:
        prediction_iter = iter(dataset)
        prediction_data_file = next(prediction_iter)
        while str(gold_data_file) != str(prediction_data_file):
            prediction_data_file = next(prediction_iter)

        gold_annotation = Annotations(gold_data_file.ann_path)
        pred_annotation = Annotations(prediction_data_file.ann_path)

        # compute ambiguity on the Annotations file level
        ambiguity_dict[str(gold_data_file)] = gold_annotation.compute_ambiguity(pred_annotation)

    return ambiguity_dict
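# Hedged usage sketch (illustrative only): summarizes the per-file dictionary returned by
# compute_ambiguity(). Only the documented return shape is assumed (a dict keyed by gold file
# name); how the gold and predicted Dataset objects are constructed is left to the caller.
def report_ambiguity(gold_dataset, predicted_dataset):
    ambiguity = gold_dataset.compute_ambiguity(predicted_dataset)
    for file_name, file_ambiguity in ambiguity.items():
        print(file_name, file_ambiguity)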
def test_different_file_diff(self):
    """
    Tests that when two Annotations objects created from different files are used in the
    difference() method, the output is a non-empty collection.
    """
    annotations1 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    annotations2 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[1]),
                               annotation_type='ann')
    result = annotations1.difference(annotations2)
    self.assertGreater(len(result), 0)
def test_stats_returns_accurate_dict(self):
    """
    Tests that when stats() is run on an annotation with three unique entities, the entity
    counts contain exactly three keys.
    """
    # Use file two because it contains three unique entities
    annotations = Annotations(self.ann_file_path_two)
    stats = annotations.stats()
    num_entities = stats["entity_counts"]
    self.assertEqual(len(num_entities), 3)
def test_good_con_data(self):
    """Tests to see if valid con data can be used to instantiate an Annotations object."""
    with open(join(self.test_dir, "test_con.con"), 'w+') as c, \
            open(join(self.test_dir, "test_con_text.txt"), 'w+') as t:
        c.write(con_text)
        t.write(con_source_text)
        # ensure the written content is on disk before Annotations reads the files
        c.flush()
        t.flush()
        annotations = Annotations(c.name, annotation_type='con', source_text_path=t.name)
        self.assertIsInstance(annotations.get_entity_annotations(), list)
def test_compare_by_index_wrong_type(self):
    """
    Tests that when a valid Annotations object calls the compare_by_index() method with an
    object that is not an Annotations, it raises ValueError.
    """
    annotations1 = Annotations(self.ann_file_path_one, annotation_type='ann')
    annotations2 = "This is not an Annotations object"
    with self.assertRaises(ValueError):
        annotations1.compare_by_index(annotations2)
def test_same_file_diff(self):
    """
    Tests that when a given Annotations object uses the diff() method with another Annotations
    object created from the same source file, it returns an empty list.
    """
    annotations1 = Annotations(self.ann_file_path_one, annotation_type='ann')
    annotations2 = Annotations(self.ann_file_path_one, annotation_type='ann')
    result = annotations1.diff(annotations2)
    self.assertEqual(result, [])
def test_compare_by_index_stats_equivalent_annotations_avg_accuracy_one(self):
    """
    Tests that when compare_by_index_stats() is called on two Annotations created from the
    same file, the average accuracy is one.
    """
    annotations1 = Annotations(self.ann_file_path_one)
    annotations2 = Annotations(self.ann_file_path_one)
    comparison = annotations1.compare_by_index_stats(annotations2)
    actual = comparison["avg_accuracy"]
    self.assertEqual(actual, 1)
def test_compare_by_index_stats_nearly_identical_annotations_avg_accuracy(self):
    """
    Tests that when compare_by_index_stats() is called with nearly identical Annotations,
    the average accuracy is less than 1.
    """
    annotations1 = Annotations(self.ann_file_path_one_modified)
    annotations2 = Annotations(self.ann_file_path_one)
    comparison = annotations1.compare_by_index_stats(annotations2)
    accuracy = comparison["avg_accuracy"]
    self.assertLess(accuracy, 1)
def test_compare_by_index_all_entities_matched(self):
    """
    Tests that when compare_by_index() is called with nearly-identical data, the calculations
    accurately pair annotations that refer to the same instance of an entity but are not
    identical. Specifically, tests that the list "NOT_MATCHED" is empty.
    """
    annotations1 = Annotations(self.ann_file_path_one_modified)
    annotations2 = Annotations(self.ann_file_path_one)
    comparison = annotations1.compare_by_index(annotations2, strict=1)
    self.assertListEqual(comparison["NOT_MATCHED"], [])
def test_different_file_diff(self):
    """
    Tests that when two different files with the same number of annotations are used in the
    diff() method, the output is a list with at least one value.
    """
    # Note that both of these files contain ten annotations
    annotations1 = Annotations(self.ann_file_path_one, annotation_type='ann')
    annotations2 = Annotations(self.ann_file_path_two, annotation_type='ann')
    result = annotations1.diff(annotations2)
    self.assertGreater(len(result), 0)
def test_confusion_matrix(self):
    """
    Tests that compute_confusion_matrix() returns a square matrix with one row and one column
    per entity label.
    """
    annotations1 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    annotations2 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[1]),
                               annotation_type='ann')
    annotations1.add_entity(*annotations2.get_entity_annotations()[0])

    confusion_matrix = annotations1.compute_confusion_matrix(annotations2, self.entities)
    self.assertEqual(len(confusion_matrix[0]), len(self.entities))
    self.assertEqual(len(confusion_matrix), len(self.entities))
def test_to_html_with_source_text(self):
    """
    Tests that when to_html() is called on a valid object with an output file path defined,
    the output HTML file is created (thus, that the method ran to completion).
    """
    if isfile(self.html_output_file_path):
        self.fail("This test requires that the html output file does not exist when the test"
                  " is initiated, but it already exists.")

    annotations = Annotations(self.ann_file_path_one, annotation_type='ann',
                              source_text_path=self.ann_file_one_source_path)
    annotations.to_html(self.html_output_file_path)
    self.assertTrue(isfile(self.html_output_file_path))
def test_ann_conversions(self):
    """Tests converting and un-converting a valid Annotations object to an ANN file."""
    annotations = Annotations(self.ann_file_path_one, annotation_type='ann')
    annotations.to_ann(write_location=join(self.test_dir, "intermediary.ann"))
    annotations2 = Annotations(join(self.test_dir, "intermediary.ann"), annotation_type='ann')
    self.assertEqual(
        annotations.get_entity_annotations(return_dictionary=True),
        annotations2.get_entity_annotations(return_dictionary=True)
    )
def get_labels(self):
    """
    Get all of the entities/labels used in the dataset.

    :return: A set of strings. Each string is a label used.
    """
    labels = set()
    data_files = self.all_data_files

    for datafile in data_files:
        ann_path = datafile.get_annotation_path()
        annotations = Annotations(ann_path)
        labels.update(annotations.get_labels())

    return labels
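# Hedged usage sketch (illustrative only): because get_labels() returns a set, ordinary set
# operations can be used directly, e.g. to flag labels that appear in a prediction Dataset but
# not in the gold Dataset. Both Dataset arguments are assumed to be constructed by the caller.
def unexpected_labels(gold_dataset, predicted_dataset):
    return predicted_dataset.get_labels() - gold_dataset.get_labels()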
def test_init_from_broken_ann_file(self):
    """
    Tests initialization from a correctly structured but ill-formatted ann file
    :return:
    """
    with self.assertRaises(InvalidAnnotationError):
        Annotations(join(self.test_dir, "broken_ann_file.ann"), annotation_type='ann')
def test_init_from_invalid_ann(self):
    """
    Tests initialization from a nonexistent annotation file path
    :return:
    """
    with self.assertRaises(FileNotFoundError):
        Annotations(join(self.dataset.get_data_directory(), self.ann_files[0][:-1]),
                    annotation_type='ann')
def test_init_from_non_dict_or_string(self):
    """
    Tests that initialization from an object that is neither a dictionary nor a string
    raises TypeError.
    :return:
    """
    with self.assertRaises(TypeError):
        Annotations(list(), annotation_type='ann')
def test_init_from_dict(self):
    """
    Tests initialization from a dictionary
    :return:
    """
    annotations = Annotations({'entities': {}, 'relations': []})
    self.assertIsInstance(annotations, Annotations)
def test_init_from_invalid_dict(self):
    """
    Tests initialization from an invalid dict
    :return:
    """
    with self.assertRaises(InvalidAnnotationError):
        Annotations({})
def compute_confusion_matrix(self, dataset, leniency=0):
    """
    Generates a confusion matrix where this Dataset serves as the gold standard annotations and
    `dataset` serves as the predicted annotations. A typical workflow would involve creating a
    Dataset object with the prediction directory outputted by a model and then passing it into
    this method.

    :param dataset: a Dataset object containing a predicted version of this dataset.
    :param leniency: a floating point value between [0,1] defining the leniency of the character
        spans to count as different. A value of zero considers only exact character matches while
        a positive value considers entities that differ by up to
        :code:`ceil(leniency * len(span)/2)` on either side.
    :return: a two-element tuple containing a label array (of entity names) and a matrix where
        rows are gold labels and columns are predicted labels. matrix[i][j] counts how many times
        entities[i] in this dataset was predicted as entities[j] in `dataset`.
    """
    if not isinstance(dataset, Dataset):
        raise ValueError("dataset must be instance of Dataset")

    # verify files are consistent
    diff = set(file.ann_path.split(os.sep)[-1] for file in self).difference(
        set(file.ann_path.split(os.sep)[-1] for file in dataset))
    if diff:
        raise ValueError("Dataset of predictions is missing the files: " + str(list(diff)))

    # sort entities in ascending order by count
    entities = [key for key, _ in sorted(self.compute_counts()['entities'].items(),
                                         key=lambda x: x[1])]
    confusion_matrix = [[0 for _ in range(len(entities))] for _ in range(len(entities))]

    for gold_data_file in self:
        prediction_iter = iter(dataset)
        prediction_data_file = next(prediction_iter)
        while str(gold_data_file) != str(prediction_data_file):
            prediction_data_file = next(prediction_iter)

        gold_annotation = Annotations(gold_data_file.ann_path)
        pred_annotation = Annotations(prediction_data_file.ann_path)

        # compute the matrix on the Annotations file level
        ann_confusion_matrix = gold_annotation.compute_confusion_matrix(
            pred_annotation, entities, leniency=leniency)
        for i in range(len(confusion_matrix)):
            for j in range(len(confusion_matrix)):
                confusion_matrix[i][j] += ann_confusion_matrix[i][j]

    return entities, confusion_matrix
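# Hedged usage sketch (not part of the original source): pretty-prints the (entities, matrix)
# pair returned by compute_confusion_matrix(). Rows are gold labels and columns are predicted
# labels, as documented above; the tab-separated layout is illustrative only.
def print_confusion_matrix(gold_dataset, predicted_dataset, leniency=0):
    entities, matrix = gold_dataset.compute_confusion_matrix(predicted_dataset, leniency=leniency)
    print("\t".join([""] + entities))
    for label, row in zip(entities, matrix):
        print("\t".join([label] + [str(count) for count in row]))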
def test_good_con_data_without_text(self):
    """Tests to see if not having a source text file will raise FileNotFoundError."""
    with open(join(self.test_dir, "test_con.con"), 'w+') as c:
        c.write(con_text)
        with self.assertRaises(FileNotFoundError):
            Annotations(c.name, annotation_type='con', source_text_path=None)
def test_compute_ambiguity(self):
    """
    Tests that adding a single incorrectly-labeled copy of an existing entity produces exactly
    one ambiguity between otherwise identical Annotations.
    """
    annotations1 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    annotations2 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    label, start, end, text = annotations2.get_entity_annotations()[0]
    annotations2.add_entity('incorrect_label', start, end, text)
    self.assertEqual(len(annotations1.compute_ambiguity(annotations2)), 1)
def test_bad_con_data(self):
    """Tests to see if invalid con data will raise InvalidAnnotationError."""
    with open(join(self.test_dir, "bad_con.con"), 'w+') as c, \
            open(join(self.test_dir, "test_con_text.txt"), 'w+') as t:
        c.write("This string wishes it was a valid con file.")
        t.write("It doesn't matter what's in this file as long as it exists.")
        # ensure the written content is on disk before Annotations reads the files
        c.flush()
        t.flush()
        with self.assertRaises(InvalidAnnotationError):
            Annotations(c.name, annotation_type='con', source_text_path=t.name)
def test_intersection(self):
    """
    Tests that intersection() returns exactly the entities shared by two Annotations objects.
    """
    annotations1 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[0]),
                               annotation_type='ann')
    annotations2 = Annotations(join(self.dataset.get_data_directory(), self.ann_files[1]),
                               annotation_type='ann')
    annotations1.add_entity(*annotations2.get_entity_annotations()[0])
    annotations1.add_entity(*annotations2.get_entity_annotations()[1])

    expected = {
        annotations2.get_entity_annotations()[0],
        annotations2.get_entity_annotations()[1],
    }
    self.assertEqual(annotations1.intersection(annotations2), expected)
def get_training_data(self, data_format='spacy'):
    """
    Get training data in a specified format.

    :param data_format: The specified format as a string.
    :return: The requested data in the requested format.
    """
    # Only spaCy format is currently supported.
    if data_format != 'spacy':
        raise TypeError("Format %s not supported" % data_format)

    training_data = []

    # Add each entry in the dataset that has an annotation to training_data
    for data_file in self.all_data_files:
        txt_path = data_file.get_text_path()
        ann_path = data_file.get_annotation_path()
        annotations = Annotations(ann_path, source_text_path=txt_path)
        training_data.append(annotations.get_entity_annotations(format='spacy'))

    return training_data
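# Hedged usage sketch (illustrative only): iterates over the list returned by
# get_training_data(). Each element is assumed to follow the usual spaCy convention of
# (text, {'entities': [(start, end, label), ...]}) because format='spacy' is requested above;
# that per-document shape is an assumption, not confirmed by this file.
def count_training_entities(dataset):
    total = 0
    for text, annotation_dict in dataset.get_training_data(data_format='spacy'):
        total += len(annotation_dict.get('entities', []))
    return total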