예제 #1
0
 def is_predicate_justification_correct(self, system_predicate_justifications, gold_predicate_justifications, max_num_justifications=2):
     """Return True if a top-ranked system predicate justification overlaps a gold one.

     System justifications are ranked by their
     'predicate_justification_confidence' value (highest first), and only the
     first ``max_num_justifications`` of them are examined. A system
     justification is counted as correct when its span has a strictly positive
     intersection-over-union with at least one gold justification span.

     Args:
         system_predicate_justifications: mapping whose values expose
             ``.get('predicate_justification')`` (a span string) and
             ``.get('predicate_justification_confidence')``.
         gold_predicate_justifications: mapping whose values expose
             ``.get('predicate_justification')`` (a span string).
         max_num_justifications: how many highest-confidence system
             justifications to check; defaults to 2 (the previously hard-coded
             limit). Values <= 0 check nothing and yield False.

     Returns:
         bool: True as soon as an overlapping pair is found, otherwise False.
     """
     document_mappings = self.get('gold_responses').get('document_mappings')
     document_boundaries = self.get('gold_responses').get('document_boundaries')

     def to_mention(justification):
         # Parse the span string and attach the document information that the IoU computation uses.
         span_string = justification.get('predicate_justification')
         return augment_mention_object(spanstring_to_object(self.logger, span_string), document_mappings, document_boundaries)

     ranked_system_justifications = sorted(system_predicate_justifications.values(), key=lambda pj: pj.get('predicate_justification_confidence'), reverse=True)
     top_system_justifications = ranked_system_justifications[:max(max_num_justifications, 0)]
     if not top_system_justifications:
         return False
     # Gold mentions do not depend on the system justification being checked, so build them only once.
     gold_mentions = [to_mention(gold_justification) for gold_justification in gold_predicate_justifications.values()]
     for system_justification in top_system_justifications:
         system_mention = to_mention(system_justification)
         if any(get_intersection_over_union(system_mention, gold_mention) > 0 for gold_mention in gold_mentions):
             return True
     return False
예제 #2
0
파일: cluster.py 프로젝트: panx27/aida
 def add_mention(self, span_string, t_cv, cm_cv, j_cv, where):
     """Build a mention from ``span_string`` and register it in ``self.get('mentions')``.

     The span string is parsed with ``spanstring_to_object`` (``where`` is
     forwarded to it; presumably a location used for diagnostics — confirm).
     The result is augmented with this object's document mappings and
     boundaries. The mention's 'ID' and 'span_string' are both set to
     ``span_string``, and ``t_cv``, ``cm_cv`` and ``j_cv`` are stored under
     keys of the same names. The mention is then added to the mentions
     container, keyed by its ID.
     """
     parsed_span = spanstring_to_object(self.get('logger'), span_string, where)
     new_mention = augment_mention_object(parsed_span,
                                          self.get('document_mappings'),
                                          self.get('document_boundaries'))
     attribute_assignments = (
         ('ID', span_string),
         ('span_string', span_string),
         ('t_cv', t_cv),
         ('cm_cv', cm_cv),
         ('j_cv', j_cv),
     )
     for attribute_name, attribute_value in attribute_assignments:
         new_mention.set(attribute_name, attribute_value)
     mention_container = self.get('mentions')
     mention_container.add(key=new_mention.get('ID'), value=new_mention)
예제 #3
0
 def contains_strict(self, mention, types, metatype):
     """Return True if ``mention`` overlaps a region registered under one of ``types`` exactly.

     Regions are looked up in ``self.get('regions')`` under the key
     ``'<document_id>:<keyframe_id or document_element_id>:<cluster_type>'``.
     The keyframe id is used when the mention has a truthy one. Each stored
     span string is qualified with the same document/element prefix, turned
     into a mention object, and compared with ``mention``. Any strictly
     positive intersection-over-union counts as containment.

     ``metatype`` is accepted for interface parity but is not used here.
     """
     document_element_id = mention.get('document_element_id')
     keyframe_id = mention.get('keyframe_id')
     document_id = mention.get('document_id')
     doce_or_kf_id = keyframe_id or document_element_id
     regions = self.get('regions')
     for cluster_type in types:
         region_key = '{}:{}:{}'.format(document_id, doce_or_kf_id, cluster_type)
         if region_key in regions:
             for span_string in regions.get(region_key):
                 qualified_span = '{}:{}:{}'.format(document_id, doce_or_kf_id, span_string)
                 candidate_region = augment_mention_object(spanstring_to_object(self.logger, qualified_span),
                                                           self.get('document_mappings'),
                                                           self.get('document_boundaries'))
                 if get_intersection_over_union(mention, candidate_region) > 0:
                     return True
     return False
예제 #4
0
 def contains_relaxed(self, mention, types, metatype):
     """Return True if ``mention`` overlaps a region whose type shares a top-level type with ``types``.

     Every key of ``self.get('regions')`` has the form
     ``'<document_id>:<doce_or_kf_id>:<cluster_type>'``. A key is a candidate
     when all of the following hold:

     - its document id equals the mention's;
     - its element id equals the mention's keyframe id (if truthy), otherwise
       the mention's document element id;
     - ``get_top_level_type`` of its type equals the top-level type of any
       entry in ``types``.

     The key's type metatype comes from
     ``self.get('ontology_type_mappings')``. Each region span under a candidate
     key is qualified with the key's document/element prefix and compared with
     ``mention``. Any strictly positive intersection-over-union counts as
     containment.

     Args:
         mention: mention object exposing ``.get('document_id')``,
             ``.get('document_element_id')`` and ``.get('keyframe_id')``.
         types: iterable of cluster types, matched at top-level granularity.
         metatype: metatype used to derive the top-level type of each entry
             in ``types``.

     Returns:
         bool
     """
     document_id = mention.get('document_id')
     document_element_id = mention.get('document_element_id')
     keyframe_id = mention.get('keyframe_id')
     doce_or_kf_id = keyframe_id if keyframe_id else document_element_id
     # The wanted top-level types depend only on ``types``: compute them once,
     # not once per (type, region key) pair.
     wanted_top_level_types = {get_top_level_type(cluster_type, metatype) for cluster_type in types}
     if not wanted_top_level_types:
         return False
     regions = self.get('regions')
     for key in regions:
         # maxsplit=2 keeps any ':' inside the type component instead of raising ValueError on unpacking.
         document_id_, doce_or_kf_id_, cluster_type_ = key.split(':', 2)
         # Apply the cheap id filters before the ontology lookup.
         if document_id != document_id_ or doce_or_kf_id != doce_or_kf_id_:
             continue
         metatype_ = self.get('ontology_type_mappings').get('type_to_metatype', cluster_type_)
         if get_top_level_type(cluster_type_, metatype_) not in wanted_top_level_types:
             continue
         for raw_region in regions.get(key):
             fq_region_string = '{docid}:{doce_or_kf_id}:{region_string}'.format(docid=document_id_,
                                                    doce_or_kf_id=doce_or_kf_id_,
                                                    region_string=str(raw_region))
             region = augment_mention_object(spanstring_to_object(self.logger, fq_region_string), self.get('document_mappings'), self.get('document_boundaries'))
             if get_intersection_over_union(mention, region) > 0:
                 return True
     return False