예제 #1
0
def string_to_span(logger, span_string, where=None):
    """
    Parse a span string of the form ``(start_x,start_y)-(end_x,end_y)``.

    Arguments:
        logger: object providing ``record_event``; it is also handed to the
            resulting Span.
        span_string: the text to parse.
        where: optional location information forwarded to the logger when
            ``span_string`` does not have the expected format.

    Returns a Span built from the four captured coordinate strings, or None
    (after recording an 'UNEXPECTED_SPAN_FORMAT' event) when the string does
    not match.
    """
    # Raw string: in a plain literal '\(' and '\S' are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    pattern = re.compile(r'^\((\S+),(\S+)\)-\((\S+),(\S+)\)$')
    match = pattern.match(span_string)
    if not match:
        logger.record_event('UNEXPECTED_SPAN_FORMAT', span_string, where)
        return None
    # The coordinates are passed through as the raw captured strings.
    return Span(logger, *match.groups())
예제 #2
0
 def get_corrected_span(self, span):
     """
     Clip ``span`` to this boundary one coordinate at a time.

     A start coordinate smaller than the boundary's start is replaced by the
     boundary's start value, and an end coordinate larger than the boundary's
     end is replaced by the boundary's end value; every other coordinate keeps
     the value stored on ``span``. Comparisons use float() conversions, but
     the values placed in the result are the raw stored values.

     NOTE(review): a span lying entirely outside the boundary is not rejected,
     so the result can have a start beyond its end -- confirm callers cope.

     Returns a new Span created with this object's 'logger'.
     """
     names = ('start_x', 'start_y', 'end_x', 'end_y')
     bound_sx, bound_sy, bound_ex, bound_ey = (float(self.get(n)) for n in names)
     span_sx, span_sy, span_ex, span_ey = (float(span.get(n)) for n in names)

     def pick(use_boundary, name):
         # Read the coordinate from the boundary or from the span, never both.
         return self.get(name) if use_boundary else span.get(name)

     new_sx = pick(span_sx < bound_sx, 'start_x')
     new_sy = pick(span_sy < bound_sy, 'start_y')
     new_ex = pick(span_ex > bound_ex, 'end_x')
     new_ey = pick(span_ey > bound_ey, 'end_y')
     return Span(self.get('logger'), new_sx, new_sy, new_ex, new_ey)
예제 #3
0
    def validate(self, span):
        """
        Validate if the span is inside the document boundary.

        Arguments:
            span:
                span could be an aida.Span object, or a string of the
                form:
                    (start_x,start_y)-(end_x,end_y)
                where every coordinate is a non-negative integer.

        Returns True if both corners of the span lie within this boundary
        (bounds inclusive), False otherwise.

        Raises:
            ValueError: if ``span`` is a string not of the form above. This is
                a subclass of Exception, so existing ``except Exception``
                handlers still catch it.
            TypeError: if ``span`` is neither a string nor a Span.
        """
        if isinstance(span, str):
            search_obj = re.search(r'^\((\d+),(\d+)\)-\((\d+),(\d+)\)$', span)
            if not search_obj:
                raise ValueError(
                    '{} is not of a form (start_x,start_y)-(end_x,end_y)'.
                    format(span))
            span = Span(self.logger, *search_obj.groups())

        if not isinstance(span, Span):
            # The previous message formatted ``isinstance.__name__`` and so
            # always reported 'isinstance' instead of this method.
            raise TypeError(
                '{} called with argument of unexpected type {}'.format(
                    'validate', type(span).__name__))

        keys = ['start_x', 'start_y', 'end_x', 'end_y']
        min_x, min_y, max_x, max_y = (float(self.get(k)) for k in keys)
        sx, sy, ex, ey = (float(span.get(k)) for k in keys)
        return (min_x <= sx <= max_x and min_x <= ex <= max_x
                and min_y <= sy <= max_y and min_y <= ey <= max_y)
예제 #4
0
def spanstring_to_object(logger, span_string, where=None):
    """
    Parse ``document_id:document_element_id:(sx,sy)-(ex,ey)`` into an Object.

    Arguments:
        logger: passed to Object and Span, and used to record format errors.
        span_string: the text to parse.
        where: optional location information stored on the result, or
            forwarded to the logger when the format is unexpected.

    Returns an Object. On a successful parse it has 'span_string',
    'document_id', 'document_element_id', 'keyframe_id', 'span' and 'where'
    set. Otherwise an 'UNEXPECTED_SPAN_FORMAT' event is recorded and the
    Object is returned without those attributes.
    """
    # Raw string: in a plain literal '\(' and '\S' are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    pattern = re.compile(r'^(.*?):(.*?):\((\S+),(\S+)\)-\((\S+),(\S+)\)$')
    match = pattern.match(span_string)
    mention = Object(logger)
    if not match:
        logger.record_event('UNEXPECTED_SPAN_FORMAT', span_string, where)
        return mention
    document_id = match.group(1)
    # parse_document_element_id (defined elsewhere) splits the middle field
    # into an element id and a keyframe id.
    document_element_id, keyframe_id = parse_document_element_id(match.group(2))
    span = Span(logger, *match.group(3, 4, 5, 6))
    mention.set('span_string', span_string)
    mention.set('document_id', document_id)
    mention.set('document_element_id', document_element_id)
    mention.set('keyframe_id', keyframe_id)
    mention.set('span', span)
    mention.set('where', where)
    return mention
예제 #5
0
 def get_corrected_span(self, span):
     """
     Return ``span`` clipped to this boundary.

     * A span of (0,0)-(0,0) is treated as "the whole element", and this
       object's own 'span' value is returned.
     * A span that starts after the boundary ends, or ends before the
       boundary starts (in either axis), cannot be corrected: None.
     * Otherwise, each start coordinate below the boundary start is replaced
       by the boundary start, and each end coordinate above the boundary end
       by the boundary end. The result is a new Span built with this object's
       'logger'.
     """
     keys = ['start_x', 'start_y', 'end_x', 'end_y']
     min_x, min_y, max_x, max_y = (float(self.get(k)) for k in keys)
     sx, sy, ex, ey = (float(span.get(k)) for k in keys)
     # Test the (0,0)-(0,0) sentinel explicitly. The former check
     # ``sx + sy + ex + ey == 0`` also matched e.g. (200,-200)-(0,0).
     if sx == sy == ex == ey == 0:
         return self.get('span')
     if sx > max_x or sy > max_y or ex < min_x or ey < min_y:
         # can't correct, return None
         return None
     sx = self.get('start_x') if sx < min_x else span.get('start_x')
     sy = self.get('start_y') if sy < min_y else span.get('start_y')
     ex = self.get('end_x') if ex > max_x else span.get('end_x')
     ey = self.get('end_y') if ey > max_y else span.get('end_y')
     return Span(self.get('logger'), sx, sy, ex, ey)
예제 #6
0
 def get_span(self):
     """Build a Span from this object's stored corner coordinates and logger."""
     corners = tuple(
         self.get(name) for name in ('start_x', 'start_y', 'end_x', 'end_y'))
     return Span(self.get('logger'), *corners)
예제 #7
0
    def validate_value_provenance_triple(self, responses, schema, entry, attribute):
        """
        Validate, and where possible correct, a provenance value of the form
        ``document_id:document_element_id[_keyframe]:(sx,sy)-(ex,ey)``.

        The provenance is read from ``entry`` under ``attribute``'s 'name'.
        Problems are reported through ``self.record_event``. Two corrections
        are written back into ``entry``:
        - a file extension appended to the document element id is stripped;
        - a span off the element boundary is replaced by the boundary's
          corrected span.

        Arguments:
            responses: provides 'document_mappings' and the
                '<modality>_boundaries' / 'keyframe_boundaries' lookups.
            schema: not used by this method; kept for interface compatibility.
            entry: the response entry (read with get, updated with set).
            attribute: descriptor whose 'name' selects the provenance value.

        Returns True if the provenance is valid (possibly after correction),
        False otherwise.
        """
        where = entry.get('where')
        attribute_name = attribute.get('name')

        provenance = entry.get(attribute_name)
        if len(provenance.split(':')) != 3:
            self.record_event('INVALID_PROVENANCE_FORMAT', provenance, where)
            return False

        # The element-id group also admits '.'. With plain '\w+?' an id that
        # carries a file extension (e.g. 'elem.mp4') could never match, so
        # the extension correction below was unreachable.
        match = re.match(
            r'^(\w+?):([\w.]+?):\((\S+),(\S+)\)-\((\S+),(\S+)\)$', provenance)
        if not match:
            self.record_event('INVALID_PROVENANCE_FORMAT', provenance, where)
            return False

        document_id = match.group(1)
        document_element_id = match.group(2)
        start_x, start_y, end_x, end_y = match.group(3, 4, 5, 6)

        # if provided, obtain keyframe_id and update document_element_id
        keyframe_match = re.match(r'^(\w*?)_(\d+)$', document_element_id)
        keyframe_num = None
        if keyframe_match:
            document_element_id, keyframe_num = keyframe_match.group(1, 2)

        # check if the document element has file extension appended to it
        # if so, report warning, and apply correction
        document_mappings = responses.get('document_mappings')
        extensions = tuple('.' + extension for extension in document_mappings.get('encodings'))
        if document_element_id.endswith(extensions):
            self.record_event('ID_WITH_EXTENSION', 'document element id', document_element_id, where)
            document_element_id = os.path.splitext(document_element_id)[0]
            provenance = '{}:{}:({},{})-({},{})'.format(document_id, document_element_id, start_x, start_y, end_x, end_y)
            entry.set(attribute_name, provenance)

        if document_id != entry.get('document_id'):
            self.record_event('MULTIPLE_DOCUMENTS', document_id, entry.get('document_id'), where)
            return False

        documents = document_mappings.get('documents')
        document_elements = document_mappings.get('document_elements')

        if document_id not in documents:
            self.record_event('UNKNOWN_ITEM', 'document', document_id, where)
            return False
        document = documents.get(document_id)

        if document_element_id not in document_elements:
            self.record_event('UNKNOWN_ITEM', 'document element', document_element_id, where)
            return False
        document_element = document_elements.get(document_element_id)

        modality = document_element.get('modality')
        if modality is None:
            self.record_event('UNKNOWN_MODALITY', document_element_id, where)
            return False

        # Only video elements are addressed by keyframe; the keyframe must
        # have known boundaries.
        keyframe_id = None
        if modality == 'video' and keyframe_num:
            keyframe_id = '{}_{}'.format(document_element_id, keyframe_num)
            if keyframe_id not in responses.get('keyframe_boundaries'):
                self.record_event('MISSING_ITEM_WITH_KEY', 'KeyFrameID', keyframe_id, where)
                return False

        if not document.get('document_elements').exists(document_element_id):
            self.record_event('PARENT_CHILD_RELATION_FAILURE', document_element_id, document_id, where)
            return False

        for coordinate in [start_x, start_y, end_x, end_y]:
            if not is_number(coordinate):
                self.record_event('NOT_A_NUMBER', coordinate, where)
                return False
            if float(coordinate) < 0:
                self.record_event('NEGATIVE_NUMBER', coordinate, where)
                return False

        for start, end in [(start_x, end_x), (start_y, end_y)]:
            if float(start) > float(end):
                self.record_event('START_BIGGER_THAN_END', start, end, provenance, where)
                return False

        # An entry in the coreference metric output file is invalid if:
        #  (a) a video mention of an entity was asserted using VideoJustification, or
        #  (b) a video mention of an relation/event was asserted using KeyFrameVideoJustification
        if entry.get('schema').get('name') == 'AIDA_PHASE2_TASK1_CM_RESPONSE' and modality == 'video':
            if keyframe_id and entry.get('metatype') != 'Entity':
                self.record_event('UNEXPECTED_JUSTIFICATION', provenance, entry.get('metatype'), entry.get('cluster_id'), 'KeyFrameVideoJustification', entry.get('where'))
                return False
            elif not keyframe_id and entry.get('metatype') not in ['Relation', 'Event']:
                self.record_event('UNEXPECTED_JUSTIFICATION', provenance, entry.get('metatype'), entry.get('cluster_id'), 'VideoJustification', entry.get('where'))
                return False

        # Keyframe-level video provenance is checked against keyframe
        # boundaries; everything else against the element's modality boundary.
        use_keyframe = modality == 'video' and keyframe_id
        boundaries = responses.get('{}_boundaries'.format('keyframe' if use_keyframe else modality))
        document_element_boundary = boundaries.get(keyframe_id if use_keyframe else document_element_id)
        span = Span(self.logger, start_x, start_y, end_x, end_y)
        if not document_element_boundary.validate(span):
            corrected_span = document_element_boundary.get('corrected_span', span)
            if corrected_span is None:
                # Nothing to write back. Formatting None would store the
                # literal text 'None' in the provenance.
                self.record_event('SPAN_OFF_BOUNDARY', span, document_element_boundary, document_element_id, where)
                return False
            corrected_provenance = '{}:{}:{}'.format(document_id, keyframe_id if keyframe_id else document_element_id, str(corrected_span))
            entry.set(attribute_name, corrected_provenance)
            self.record_event('SPAN_OFF_BOUNDARY', span, document_element_boundary, document_element_id, where)
        return True
예제 #8
0
    def validate_provenance(self, responses, schema, entry, attribute_name,
                            provenance, apply_correction):
        """
        Validate a provenance string for one attribute of a response entry.

        Arguments:
            responses: provides 'document_mappings' (with 'encodings',
                'documents' and 'document_elements') and the
                '<modality>_boundaries' / 'keyframe_boundaries' lookups.
            schema: the schema being validated; only its 'task' is read here.
            entry: the response entry (read with get; updated with set when
                a correction is applied).
            attribute_name: name of the entry attribute holding the
                provenance; corrected values are written back under it.
            provenance: the provenance string to validate.
            apply_correction: when True, a document element id that carries a
                file extension is stripped of it, and an off-boundary span is
                replaced by the boundary's corrected span when one exists.
                Both corrections are written into ``entry`` with a warning
                event. When False, those cases are recorded as errors and the
                provenance is rejected.

        Returns True if the provenance is valid (possibly after correction),
        False otherwise. Problems are reported through ``self.record_event``.
        """
        where = entry.get('where')

        # For task3 the literal 'NULL' is accepted for these two justification
        # attributes without further checks.
        if schema.get(
                'task'
        ) == 'task3' and attribute_name == 'subject_informative_justification_span_text' and provenance == 'NULL':
            return True

        if schema.get(
                'task'
        ) == 'task3' and attribute_name == 'predicate_justification_spans_text' and provenance == 'NULL':
            return True

        # Format checking and parsing are delegated to helpers defined
        # elsewhere on this class.
        if not self.validate_provenance_format(provenance, where):
            return False

        # keyframe_num may be empty/None when no keyframe is given (see the
        # ``if keyframe_num`` test further down).
        document_id, document_element_id, keyframe_num, start_x, start_y, end_x, end_y = self.parse_provenance(
            provenance)

        # check if the document element has file extension appended to it
        # if so, report warning, and apply correction
        extensions = tuple([
            '.' + extension for extension in responses.get(
                'document_mappings').get('encodings')
        ])
        if document_element_id.endswith(extensions):
            if apply_correction:
                self.record_event('ID_WITH_EXTENSION', 'document element id',
                                  document_element_id, where)
                document_element_id = os.path.splitext(document_element_id)[0]
                # NOTE(review): the rebuilt provenance omits keyframe_num; if
                # parse_provenance can return a keyframe together with an
                # extension, the keyframe suffix is lost here -- confirm.
                provenance = '{}:{}:({},{})-({},{})'.format(
                    document_id, document_element_id, start_x, start_y, end_x,
                    end_y)
                entry.set(attribute_name, provenance)
            else:
                self.record_event('ID_WITH_EXTENSION_ERROR',
                                  'document element id', document_element_id,
                                  where)
                return False

        # The provenance must refer to the same document as its entry.
        if document_id != entry.get('document_id'):
            self.record_event('MULTIPLE_DOCUMENTS', document_id,
                              entry.get('document_id'), where)
            return False

        documents = responses.get('document_mappings').get('documents')
        document_elements = responses.get('document_mappings').get(
            'document_elements')

        if document_id not in documents:
            self.record_event('UNKNOWN_ITEM', 'document', document_id, where)
            return False
        document = documents.get(document_id)

        if document_element_id not in document_elements:
            self.record_event('UNKNOWN_ITEM', 'document element',
                              document_element_id, where)
            return False
        document_element = document_elements.get(document_element_id)

        modality = document_element.get('modality')
        if modality is None:
            self.record_event('UNKNOWN_MODALITY', document_element_id, where)
            return False

        # Only video elements use a keyframe id, and it must have known
        # keyframe boundaries.
        keyframe_id = None
        if modality == 'video':
            if keyframe_num:
                keyframe_id = '{}_{}'.format(document_element_id, keyframe_num)
                if keyframe_id not in responses.get('keyframe_boundaries'):
                    self.record_event('MISSING_ITEM_WITH_KEY', 'KeyFrameID',
                                      keyframe_id, where)
                    return False

        # The element must be registered as a child of the document.
        if not document.get('document_elements').exists(document_element_id):
            self.record_event('PARENT_CHILD_RELATION_FAILURE',
                              document_element_id, document_id, where)
            return False

        # Coordinate checks are delegated to validate_coordinates (defined
        # elsewhere on this class).
        if not self.validate_coordinates(provenance, start_x, start_y, end_x,
                                         end_y, where):
            return False

        # An entry in the coreference metric output file is invalid if:
        #  (a) a video mention of an entity was asserted using VideoJustification, or
        #  (b) a video mention of an relation/event was asserted using KeyFrameVideoJustification
        # Updating the following for Phase 3 is unnecessary since there are no videos in the collection
        # NOTE(review): the schema name is read from entry.get('schema'), not
        # from the ``schema`` argument -- presumably the same object; confirm.
        if entry.get('schema').get('name') in [
                'AIDA_PHASE2_TASK1_CM_RESPONSE',
                'AIDA_PHASE2_TASK2_ZH_RESPONSE'
        ] and modality == 'video':
            if keyframe_id and entry.get('metatype') != 'Entity':
                self.record_event('UNEXPECTED_JUSTIFICATION', provenance,
                                  entry.get('metatype'),
                                  entry.get('cluster_id'),
                                  'KeyFrameVideoJustification',
                                  entry.get('where'))
                return False
            elif not keyframe_id and entry.get('metatype') not in [
                    'Relation', 'Event'
            ]:
                self.record_event('UNEXPECTED_JUSTIFICATION', provenance,
                                  entry.get('metatype'),
                                  entry.get('cluster_id'),
                                  'VideoJustification', entry.get('where'))
                return False

        # Keyframe-level video provenance is checked against keyframe
        # boundaries; everything else against the element's modality boundary.
        document_element_boundary = responses.get('{}_boundaries'.format(
            'keyframe' if modality == 'video' and keyframe_id else modality
        )).get(keyframe_id
               if modality == 'video' and keyframe_id else document_element_id)
        span = Span(self.logger, start_x, start_y, end_x, end_y)
        if not document_element_boundary.validate(span):
            # corrected_span may be None when the span cannot be corrected.
            corrected_span = document_element_boundary.get(
                'corrected_span', span)
            if corrected_span is None or not apply_correction:
                self.record_event('SPAN_OFF_BOUNDARY_ERROR', span,
                                  document_element_boundary,
                                  document_element_id, where)
                return False
            corrected_provenance = '{}:{}:{}'.format(
                document_id,
                keyframe_id if keyframe_id else document_element_id,
                corrected_span.__str__())
            entry.set(attribute_name, corrected_provenance)
            self.record_event('SPAN_OFF_BOUNDARY_CORRECTED', span,
                              corrected_span, document_element_boundary,
                              document_element_id, where)
        return True