def test_samples_with_targets(self):
    '''
    Tests `TargetTextCollection.samples_with_targets`, which should return a
    sub-collection containing only the samples that have at least one target.

    Covers the cases where all, none, and only some of the samples contain
    targets, including the edge case where a sample's targets are an empty
    list rather than None.
    '''
    # Case 1: every TargetText in the collection contains targets, so the
    # sub-collection should be equal to the original collection.
    test_collection = TargetTextCollection(self._target_text_examples())
    sub_collection = test_collection.samples_with_targets()
    assert test_collection == sub_collection
    assert len(test_collection) == 3

    # Case 2: none of the TargetTexts contain targets. The keys are copied
    # into a list first because the collection is mutated inside the loop.
    for sample_id in list(test_collection.keys()):
        del test_collection[sample_id]
        test_collection.add(TargetText(text='nothing', text_id=sample_id))
    assert len(test_collection) == 3
    sub_collection = test_collection.samples_with_targets()
    assert len(sub_collection) == 0
    assert sub_collection != test_collection

    # Cases 3 and 4: only 2 of the 3 TargetTexts contain targets. The sample
    # 'another_id' is replaced by one without targets, first with targets
    # left as None (default), then with targets/spans as empty lists (edge
    # case). Both must be excluded from the sub-collection.
    for no_target_kwargs in ({}, {'targets': [], 'spans': []}):
        test_collection = TargetTextCollection(self._target_text_examples())
        del test_collection['another_id']
        no_target_case = TargetText(text='something else',
                                    text_id='another_id',
                                    **no_target_kwargs)
        test_collection.add(no_target_case)
        sub_collection = test_collection.samples_with_targets()
        assert len(sub_collection) == 2
        assert sub_collection != test_collection
        # Ensure it is specifically the target-less sample that was dropped,
        # not just any one sample.
        assert 'another_id' not in sub_collection.keys()
def average_target_per_sentences(collection: TargetTextCollection,
                                 sentence_must_contain_targets: bool) -> float:
    '''
    :param collection: Collection to calculate average target per sentence
                       (ATS) on.
    :param sentence_must_contain_targets: Whether or not the sentences within
                                          the collection must contain at
                                          least one target. This filtering
                                          affects the value of the
                                          denominator stated in the returns.
    :returns: The ATS for the given collection, which is:
              Number of targets / number of sentences.
              If there are no sentences to divide by (an empty collection,
              or no sentence contains a target when
              `sentence_must_contain_targets` is True) then 0.0 is returned
              instead of raising ZeroDivisionError.
    '''
    number_targets = float(collection.number_targets())
    if sentence_must_contain_targets:
        number_sentences = len(collection.samples_with_targets())
    else:
        number_sentences = len(collection)
    # Guard the division: with no sentences the target count is presumably
    # also 0 (a 0/0 ratio), so report an average of 0.0 rather than crash.
    if number_sentences == 0:
        return 0.0
    return number_targets / float(number_sentences)