def _get_mentions(
    self,
    bp: BasePhrase,
    document: Document,
    relax_exophors: Dict[str, str],
    candidates: List[int],
) -> List[str]:
    if bp.dtid in document.mentions:
        ment_strings: List[str] = []
        src_mention = document.mentions[bp.dtid]
        tgt_mentions = document.get_siblings(src_mention, relax=False)
        exophors = [document.entities[eid].exophor
                    for eid in src_mention.eids
                    if document.entities[eid].is_special]
        for mention in tgt_mentions:
            if mention.dmid not in candidates:
                logger.debug(f'mention: {mention} in {self.doc_id} is not in candidates and ignored')
                continue
            ment_strings.append(str(mention.dmid))
        for exophor in exophors:
            if exophor in relax_exophors:
                ment_strings.append(relax_exophors[exophor])  # e.g. 不特定:人1 -> 不特定:人
        if ment_strings:
            return ment_strings
        else:
            return ['NA']  # cataphora are forced to [NA]
    else:
        return ['NA']
def _rewrite_rel(
    self,
    knp_lines: List[str],
    example: PasExample,
    arguments_set: List[List[int]],  # (max_seq_len, cases)
    document: Document,  # with <格解析> (case analysis) tags
) -> List[str]:
    overts = self._extract_overt(document)
    output_knp_lines = []
    dtid = 0
    sent_idx = 0
    for line in knp_lines:
        if not line.startswith('+ '):
            output_knp_lines.append(line)
            if line == 'EOS':
                sent_idx += 1
            continue
        assert '<rel ' not in line
        match = self.TAG_PAT.match(line)
        if match is not None:
            rel_string = self._rel_string(document.bp_list()[dtid], example, arguments_set, document,
                                          overts[dtid])
            rel_idx = match.end()
            output_knp_lines.append(line[:rel_idx] + rel_string + line[rel_idx:])
        else:
            self.logger.warning(f'invalid format line: {line}')
            output_knp_lines.append(line)
        dtid += 1
    return output_knp_lines
def _rel_string(
    self,
    bp: BasePhrase,
    example: PasExample,
    arguments_set: List[List[int]],  # (max_seq_len, cases)
    document: Document,
    overt_dict: Dict[str, int],
) -> str:
    rels: List[RelTag] = []
    dmid2bp = {document.mrph2dmid[mrph]: bp
               for bp in document.bp_list()
               for mrph in bp.mrph_list()}
    assert len(example.arguments_set) == len(dmid2bp)
    for mrph in bp.mrph_list():
        dmid: int = document.mrph2dmid[mrph]
        token_index: int = example.orig_to_tok_index[dmid]
        arguments: List[int] = arguments_set[token_index]
        # gold_args is empty for morphemes that are not analysis targets (e.g. particles);
        # at inference time, analysis-target morphemes have ['NULL'] instead
        is_targets: Dict[str, bool] = {rel: bool(args) for rel, args in example.arguments_set[dmid].items()}
        assert len(self.relations) == len(arguments)
        for relation, argument in zip(self.relations, arguments):
            if not is_targets[relation]:
                continue
            if self.use_knp_overt and relation in overt_dict:
                # overt argument taken from KNP case analysis
                prediction_dmid = overt_dict[relation]
            elif argument in self.index_to_special:
                # special token
                special_arg = self.index_to_special[argument]
                if special_arg in self.exophors:  # exclude [NULL] and [NA]
                    rels.append(RelTag(relation, special_arg, None, None))
                continue
            else:
                # normal argument
                prediction_dmid = example.tok_to_orig_index[argument]
                if prediction_dmid is None:
                    # [SEP] or [CLS]
                    self.logger.warning('[SEP] or [CLS] was chosen as an argument; tentatively treating it as NULL')
                    continue
            prediction_bp: BasePhrase = dmid2bp[prediction_dmid]
            rels.append(RelTag(relation, prediction_bp.core, prediction_bp.sid, prediction_bp.tid))
    return ''.join(rel.to_string() for rel in rels)
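# For reference, a minimal sketch of the RelTag tuple consumed above and the KNP-style
# <rel> string it is assumed to render to. The actual class in this project may differ;
# it is named RelTagSketch here to avoid clashing with the real one.
from typing import NamedTuple, Optional


class RelTagSketch(NamedTuple):
    type_: str
    target: str
    sid: Optional[str]
    tid: Optional[int]

    def to_string(self) -> str:
        string = f'<rel type="{self.type_}" target="{self.target}"'
        if self.sid is not None:  # endophora carry the sentence ID and tag ID of the argument
            string += f' sid="{self.sid}" id="{self.tid}"'
        return string + '/>'


# e.g. RelTagSketch('ガ', '彼', 'w201106-0000060050-1', 1).to_string()
#   -> '<rel type="ガ" target="彼" sid="w201106-0000060050-1" id="1"/>'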
def __init__(self, document_pred: Document, document_gold: Document, cases: List[str], bridging: bool,
             coreference: bool, relax_exophors: Dict[str, str], pas_target: str):
    assert document_pred.doc_id == document_gold.doc_id
    self.doc_id: str = document_gold.doc_id
    self.document_pred: Document = document_pred
    self.document_gold: Document = document_gold
    self.cases: List[str] = cases
    self.pas: bool = pas_target != ''
    self.bridging: bool = bridging
    self.coreference: bool = coreference
    self.comp_result: Dict[tuple, str] = {}
    self.relax_exophors: Dict[str, str] = relax_exophors

    self.predicates_pred: List[BasePhrase] = []
    self.bridgings_pred: List[BasePhrase] = []
    self.mentions_pred: List[BasePhrase] = []
    for bp in document_pred.bp_list():
        if is_pas_target(bp, verbal=(pas_target in ('pred', 'all')), nominal=(pas_target in ('noun', 'all'))):
            self.predicates_pred.append(bp)
        if self.bridging and is_bridging_target(bp):
            self.bridgings_pred.append(bp)
        if self.coreference and is_coreference_target(bp):
            self.mentions_pred.append(bp)
    self.predicates_gold: List[BasePhrase] = []
    self.bridgings_gold: List[BasePhrase] = []
    self.mentions_gold: List[BasePhrase] = []
    for bp in document_gold.bp_list():
        if is_pas_target(bp, verbal=(pas_target in ('pred', 'all')), nominal=(pas_target in ('noun', 'all'))):
            self.predicates_gold.append(bp)
        if self.bridging and is_bridging_target(bp):
            self.bridgings_gold.append(bp)
        if self.coreference and is_coreference_target(bp):
            self.mentions_gold.append(bp)
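# A hypothetical construction sketch (the class name SubScorer and the argument values
# are illustrative, not taken from this file):
#   sub_scorer = SubScorer(document_pred, document_gold,
#                          cases=['ガ', 'ヲ', 'ニ', 'ガ2'],
#                          bridging=True, coreference=True,
#                          relax_exophors=relax_exophors,
#                          pas_target='all')  # '' disables PAS; 'pred'/'noun'/'all' select target predicates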
def _pas_string(
    self,
    pas: Pas,
    cfid: str,
    document: Document,
) -> str:
    sid2index: Dict[str, int] = {sid: i for i, sid in enumerate(document.sid2sentence.keys())}
    dtype2caseflag = {'overt': 'C', 'dep': 'N', 'intra': 'O', 'inter': 'O', 'exo': 'E'}
    case_elements = []
    for case in self.cases + (['ノ'] * self.bridging):
        items = ['-'] * 6
        items[0] = case
        args = pas.arguments[case]
        if args:
            arg: BaseArgument = args[0]
            items[1] = dtype2caseflag[arg.dep_type]  # case flag (C/N/O/D/E/U)
            items[2] = str(arg)  # surface form
            if isinstance(arg, Argument):
                items[3] = str(sid2index[pas.sid] - sid2index[arg.sid])  # distance in sentences
                items[4] = str(arg.tid)  # tag id
                items[5] = str(document.get_entities(arg)[0].eid)  # entity id
            else:
                assert isinstance(arg, SpecialArgument)
                items[3] = str(-1)
                items[4] = str(-1)
                items[5] = str(arg.eid)  # entity id
        else:
            items[1] = 'U'
        case_elements.append('/'.join(items))
    return f"<述語項構造:{cfid}:{';'.join(case_elements)}>"
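# A hypothetical example of the tag produced above (all values are illustrative): a
# predicate whose ガ is an overt argument 彼 in the same sentence (tag 3, entity 0) and
# whose ヲ is unassigned would yield
#   <述語項構造:dummy:dummy:ガ/C/彼/0/3/0;ヲ/U/-/-/-/->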
def _add_pas_analysis(
    self,
    knp_lines: List[str],
    document: Document,
) -> List[str]:
    dtid2pas = {pas.dtid: pas for pas in document.pas_list()}
    dtid = 0
    output_knp_lines = []
    for line in knp_lines:
        if not line.startswith('+ '):
            output_knp_lines.append(line)
            continue
        if dtid in dtid2pas:
            pas_string = self._pas_string(dtid2pas[dtid], 'dummy:dummy', document)
            output_knp_lines.append(line + pas_string)
        else:
            output_knp_lines.append(line)
        dtid += 1
    return output_knp_lines
def load(
    self,
    document: Document,
    cases: List[str],
    exophors: List[str],
    coreference: bool,
    bridging: bool,
    relations: List[str],
    kc: bool,
    pas_targets: List[str],
    tokenizer: BertTokenizer,
) -> None:
    self.doc_id = document.doc_id
    process_all = (kc is False) or (document.doc_id.split('-')[-1] == '00')
    last_sent = document.sentences[-1] if len(document) > 0 else None
    relax_exophors = {}
    for exophor in exophors:
        relax_exophors[exophor] = exophor
        if exophor in ('不特定:人', '不特定:物', '不特定:状況'):
            for n in '123456789':
                relax_exophors[exophor + n] = exophor
    dmid2arguments: Dict[int, Dict[str, List[BaseArgument]]] = {pred.dmid: document.get_arguments(pred)
                                                                for pred in document.get_predicates()}
    head_dmids = []
    for sentence in document:
        process: bool = process_all or (sentence is last_sent)
        head_dmids += [bp.dmid for bp in sentence.bps]
        for bp in sentence.bps:
            for mrph in bp.mrph_list():
                self.words.append(mrph.midasi)
                self.dtids.append(bp.dtid)
                self.ddeps.append(bp.parent.dtid if bp.parent is not None else -1)
                arguments = OrderedDict((rel, []) for rel in relations)
                arg_candidates = ment_candidates = []
                if document.mrph2dmid[mrph] == bp.dmid and process is True:
                    if is_pas_target(bp, verbal=('pred' in pas_targets), nominal=('noun' in pas_targets)):
                        arg_candidates = [x for x in head_dmids if x != bp.dmid]
                        for case in cases:
                            dmid2args = {dmid: arguments[case] for dmid, arguments in dmid2arguments.items()}
                            arguments[case] = self._get_args(bp.dmid, dmid2args, relax_exophors, arg_candidates)
                    if bridging and is_bridging_target(bp):
                        arg_candidates = [x for x in head_dmids if x != bp.dmid]
                        dmid2args = {dmid: arguments['ノ'] for dmid, arguments in dmid2arguments.items()}
                        arguments['ノ'] = self._get_args(bp.dmid, dmid2args, relax_exophors, arg_candidates)
                    if coreference and is_coreference_target(bp):
                        ment_candidates = [x for x in head_dmids if x < bp.dmid]  # do not resolve cataphora
                        arguments['='] = self._get_mentions(bp, document, relax_exophors, ment_candidates)
                self.arguments_set.append(arguments)
                self.arg_candidates_set.append(arg_candidates)
                self.ment_candidates_set.append(ment_candidates)
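# A standalone sketch of the relax_exophors table built at the top of load(), extracted
# here for illustration: indexed exophor variants such as 不特定:人1 are mapped back to
# their base exophor so that they count as the same referent.
from typing import Dict, List


def build_relax_exophors(exophors: List[str]) -> Dict[str, str]:
    relax_exophors = {}
    for exophor in exophors:
        relax_exophors[exophor] = exophor
        if exophor in ('不特定:人', '不特定:物', '不特定:状況'):
            for n in '123456789':
                relax_exophors[exophor + n] = exophor
    return relax_exophors


assert build_relax_exophors(['著者', '不特定:人'])['不特定:人3'] == '不特定:人'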
def coverage(doc: Document) -> RetValue:
    ret = RetValue()
    for predicate in doc.get_predicates():
        ex = Example(doc, predicate)
        arguments = doc.get_arguments(predicate)
        is_pred_gold = any(arguments[case] for case in PRED_CASES)
def write(
    self,
    arguments_sets: List[List[List[int]]],
    destination: Union[Path, TextIO, None],
    skip_untagged: bool = True,
) -> List[Document]:
    """Write final predictions to the file."""
    if isinstance(destination, Path):
        self.logger.info(f'Writing predictions to: {destination}')
        destination.mkdir(exist_ok=True)
    elif not (destination is None or isinstance(destination, io.TextIOBase)):
        self.logger.warning('invalid output destination')

    did2examples = {ex.doc_id: ex for ex in self.examples}
    did2arguments_sets = {ex.doc_id: arguments_set
                          for ex, arguments_set in zip(self.examples, arguments_sets)}

    did2knps: Dict[str, List[str]] = defaultdict(list)
    for document in self.documents:
        did = document.doc_id
        input_knp_lines = document.knp_string.strip().split('\n')
        if did in did2examples:
            output_knp_lines = self._rewrite_rel(input_knp_lines, did2examples[did], did2arguments_sets[did],
                                                 document)
        else:
            if skip_untagged:
                continue
            output_knp_lines = []
            for line in input_knp_lines:
                if line.startswith('+ '):
                    line = self.REL_PAT.sub('', line)  # remove gold data
                    assert '<rel ' not in line
                output_knp_lines.append(line)
        knp_strings: List[str] = []
        buff = ''
        for knp_line in output_knp_lines:
            buff += knp_line + '\n'
            if knp_line.strip() == 'EOS':
                knp_strings.append(buff)
                buff = ''
        if self.kc:
            orig_did, idx = did.split('-')
            if idx == '00':
                did2knps[orig_did] += knp_strings
            else:
                did2knps[orig_did].append(knp_strings[-1])
        else:
            did2knps[did] = knp_strings

    documents_pred: List[Document] = []  # for kc, documents merged back into their original form
    for did, knp_strings in did2knps.items():
        document_pred = Document(''.join(knp_strings),
                                 did,
                                 self.reader.target_cases,
                                 self.reader.target_corefs,
                                 self.reader.relax_cases,
                                 extract_nes=False,
                                 use_pas_tag=False)
        documents_pred.append(document_pred)
        if destination is None:
            continue
        output_knp_lines = self._add_pas_analysis(document_pred.knp_string.split('\n'), document_pred)
        output_string = '\n'.join(output_knp_lines) + '\n'
        if isinstance(destination, Path):
            output_basename = did + '.knp'
            with destination.joinpath(output_basename).open('w') as writer:
                writer.write(output_string)
        elif isinstance(destination, io.TextIOBase):
            destination.write(output_string)

    return documents_pred
def draw_tree(
    document: Document,
    sid: str,
    cases: List[str],
    bridging: bool = False,
    coreference: bool = False,
    fh: Optional[TextIO] = None,
    html: bool = False,
) -> None:
    """Write the predicate-argument structures and coreference relations of the sentence
    specified by sid to fh in tree format.

    Args:
        document (Document): the document that contains the sentence sid
        sid (str): the ID of the sentence to draw
        cases (List[str]): cases to display
        bridging (bool): whether to display bridging anaphora relations as well
        coreference (bool): whether to display coreference relations as well
        fh (Optional[TextIO]): output stream
        html (bool): whether to write in HTML format
    """
    blist: BList = document.sid2sentence[sid].blist
    with io.StringIO() as string:
        blist.draw_tag_tree(fh=string, show_pos=False)
        tree_strings = string.getvalue().rstrip('\n').split('\n')
    assert len(tree_strings) == len(blist.tag_list())
    all_targets = [m.core for m in document.mentions.values()]
    tid2mention = {mention.tid: mention for mention in document.mentions.values() if mention.sid == sid}
    for bp in document[sid].bps:
        tree_strings[bp.tid] += ' '
        if is_pas_target(bp, verbal=True, nominal=True):
            arguments = document.get_arguments(bp)
            for case in cases:
                args: List[BaseArgument] = arguments.get(case, [])
                targets = set()
                for arg in args:
                    target = str(arg)
                    if all_targets.count(str(arg)) > 1 and isinstance(arg, Argument):
                        target += str(arg.dtid)
                    targets.add(target)
                if html:
                    color = 'black' if targets else 'gray'
                    tree_strings[bp.tid] += f'<font color="{color}">{case}:{",".join(targets)}</font> '
                else:
                    tree_strings[bp.tid] += f'{case}:{",".join(targets)} '
        if bridging and is_bridging_target(bp):
            args = document.get_arguments(bp).get('ノ', [])
            targets = set()
            for arg in args:
                target = str(arg)
                if all_targets.count(str(arg)) > 1 and isinstance(arg, Argument):
                    target += str(arg.dtid)
                targets.add(target)
            if html:
                color = 'black' if targets else 'gray'
                tree_strings[bp.tid] += f'<font color="{color}">ノ:{",".join(targets)}</font> '
            else:
                tree_strings[bp.tid] += f'ノ:{",".join(targets)} '
        if coreference and is_coreference_target(bp):
            if bp.tid in tid2mention:
                src_mention = tid2mention[bp.tid]
                tgt_mentions = [tgt for tgt in document.get_siblings(src_mention)
                                if tgt.dtid < src_mention.dtid]
                targets = set()
                for tgt_mention in tgt_mentions:
                    target = tgt_mention.core
                    if all_targets.count(tgt_mention.core) > 1:
                        target += str(tgt_mention.dtid)
                    targets.add(target)
                for eid in src_mention.eids:
                    entity = document.entities[eid]
                    if entity.is_special:
                        targets.add(entity.exophor)
            else:
                targets = set()
            if html:
                color = 'black' if targets else 'gray'
                tree_strings[bp.tid] += f'<font color="{color}">=:{",".join(targets)}</font>'
            else:
                tree_strings[bp.tid] += f'=:{",".join(targets)}'
    print('\n'.join(tree_strings), file=fh)
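# A hypothetical usage sketch (the KNP file path and the document/sentence IDs are
# illustrative; KyotoReader here stands for whatever loader the project uses):
#   import sys
#   from kyoto_reader import KyotoReader
#   document = KyotoReader('example.knp').process_document('example')
#   draw_tree(document, document.sentences[0].sid, cases=['ガ', 'ヲ', 'ニ'],
#             bridging=True, coreference=True, fh=sys.stdout)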
def _draw_tree(self,
               sid: str,
               predicates: List[BasePhrase],
               mentions: List[BasePhrase],
               anaphors: List[BasePhrase],
               document: Document,
               fh: Optional[TextIO] = None,
               html: bool = True) -> None:
    """Write the predicate-argument structures, coreference relations, and bridging
    anaphora relations of the specified sentence in tree format.

    Args:
        sid (str): the ID of the sentence to write
        predicates (List[BasePhrase]): all predicates in the document
        mentions (List[BasePhrase]): all mentions in the document
        anaphors (List[BasePhrase]): all bridging anaphors in the document
        document (Document): the document containing the target sentence
        fh (Optional[TextIO]): output stream
        html (bool): whether to write in HTML format
    """
    result2color = {anal: 'blue' for anal in Scorer.DEPTYPE2ANALYSIS.values()}
    result2color.update({'overt': 'green', 'wrong': 'red', None: 'gray'})
    result2color_coref = {'correct': 'blue', 'wrong': 'red', None: 'gray'}
    blist: BList = document.sid2sentence[sid].blist
    with io.StringIO() as string:
        blist.draw_tag_tree(fh=string, show_pos=False)
        tree_strings = string.getvalue().rstrip('\n').split('\n')
    assert len(tree_strings) == len(blist.tag_list())
    all_targets = [m.core for m in document.mentions.values()]
    tid2predicate: Dict[int, BasePhrase] = {predicate.tid: predicate for predicate in predicates
                                            if predicate.sid == sid}
    tid2mention: Dict[int, BasePhrase] = {mention.tid: mention for mention in mentions if mention.sid == sid}
    tid2bridging: Dict[int, BasePhrase] = {anaphor.tid: anaphor for anaphor in anaphors if anaphor.sid == sid}
    for tid in range(len(tree_strings)):
        tree_strings[tid] += ' '
        if tid in tid2predicate:
            predicate = tid2predicate[tid]
            arguments = document.get_arguments(predicate)
            for case in self.cases:
                args = arguments[case]
                if case == 'ガ':
                    args = args + arguments['判ガ']  # avoid mutating the list held by arguments
                targets = set()
                for arg in args:
                    target = str(arg)
                    if all_targets.count(str(arg)) > 1 and isinstance(arg, Argument):
                        target += str(arg.dtid)
                    targets.add(target)
                result = self.comp_result.get((document.doc_id, predicate.dtid, case), None)
                if html:
                    tree_strings[tid] += f'<font color="{result2color[result]}">{case}:{",".join(targets)}</font> '
                else:
                    tree_strings[tid] += f'{case}:{",".join(targets)} '
        if self.bridging and tid in tid2bridging:
            anaphor = tid2bridging[tid]
            arguments = document.get_arguments(anaphor)
            args = arguments['ノ'] + arguments['ノ?']
            targets = set()
            for arg in args:
                target = str(arg)
                if all_targets.count(str(arg)) > 1 and isinstance(arg, Argument):
                    target += str(arg.dtid)
                targets.add(target)
            result = self.comp_result.get((document.doc_id, anaphor.dtid, 'ノ'), None)
            if html:
                tree_strings[tid] += f'<font color="{result2color[result]}">ノ:{",".join(targets)}</font> '
            else:
                tree_strings[tid] += f'ノ:{",".join(targets)} '
        if self.coreference and tid in tid2mention:
            targets = set()
            src_dtid = tid2mention[tid].dtid
            if src_dtid in document.mentions:
                src_mention = document.mentions[src_dtid]
                tgt_mentions_relaxed = SubScorer.filter_mentions(
                    document.get_siblings(src_mention, relax=True), src_mention)
                for tgt_mention in tgt_mentions_relaxed:
                    target: str = tgt_mention.core
                    if all_targets.count(target) > 1:
                        target += str(tgt_mention.dtid)
                    targets.add(target)
                for eid in src_mention.eids:
                    entity = document.entities[eid]
                    if entity.exophor in self.relax_exophors:
                        targets.add(entity.exophor)
            result = self.comp_result.get((document.doc_id, src_dtid, '='), None)
            if html:
                tree_strings[tid] += f'<font color="{result2color_coref[result]}">=:{",".join(targets)}</font>'
            else:
                tree_strings[tid] += '=:' + ','.join(targets)
    print('\n'.join(tree_strings), file=fh)
def write(
    self,
    arguments_sets: List[List[List[int]]],
    destination: Union[Path, TextIO, None],
    skip_untagged: bool = True,
    add_pas_tag: bool = True,
) -> List[Document]:
    """Write final predictions to the file.

    Args:
        arguments_sets (List[List[List[int]]]): model outputs
        destination (Union[Path, TextIO, None]): output destination for the analyzed documents
        skip_untagged (bool): whether to skip documents that failed to be analyzed (default: True)
        add_pas_tag (bool): whether to add <述語項構造> tags to the analysis results (default: True)

    Returns:
        List[Document]: analyzed documents
    """
    if isinstance(destination, Path):
        self.logger.info(f'Writing predictions to: {destination}')
        destination.mkdir(exist_ok=True)
    elif not (destination is None or isinstance(destination, io.TextIOBase)):
        self.logger.warning('invalid output destination')

    did2examples = {ex.doc_id: ex for ex in self.examples}
    did2arguments_sets = {ex.doc_id: arguments_set
                          for ex, arguments_set in zip(self.examples, arguments_sets)}

    did2knps: Dict[str, List[str]] = defaultdict(list)
    for document in self.documents:
        did = document.doc_id
        input_knp_lines = document.knp_string.strip().split('\n')
        if did in did2examples:
            # to extract overt arguments, this document should be the one case-analyzed after reparsing
            output_knp_lines = self._rewrite_rel(input_knp_lines, did2examples[did], did2arguments_sets[did],
                                                 document)
        else:
            if skip_untagged:
                continue
            assert all('<rel ' not in line for line in input_knp_lines)
            output_knp_lines = input_knp_lines
        knp_strings: List[str] = []
        buff = ''
        for knp_line in output_knp_lines:
            buff += knp_line + '\n'
            if knp_line.strip() == 'EOS':
                knp_strings.append(buff)
                buff = ''
        if self.kc:
            # merge documents
            orig_did, idx = did.split('-')
            if idx == '00':
                did2knps[orig_did] += knp_strings
            else:
                did2knps[orig_did].append(knp_strings[-1])
        else:
            did2knps[did] = knp_strings

    documents_pred: List[Document] = []  # for kc, documents merged back into their original form
    for did, knp_strings in did2knps.items():
        document_pred = Document(''.join(knp_strings),
                                 did,
                                 self.reader.target_cases,
                                 self.reader.target_corefs,
                                 self.reader.relax_cases,
                                 extract_nes=False,
                                 use_pas_tag=False)
        documents_pred.append(document_pred)
        if destination is None:
            continue
        output_knp_lines = document_pred.knp_string.strip().split('\n')
        if add_pas_tag:
            output_knp_lines = self._add_pas_tag(output_knp_lines, document_pred)
        output_string = '\n'.join(output_knp_lines) + '\n'
        if isinstance(destination, Path):
            output_basename = did + '.knp'
            with destination.joinpath(output_basename).open('w') as writer:
                writer.write(output_string)
        elif isinstance(destination, io.TextIOBase):
            destination.write(output_string)

    return documents_pred
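# A hypothetical usage sketch (the writer construction and output path are illustrative;
# PredictionKNPWriter stands for whatever class these write methods belong to):
#   writer = PredictionKNPWriter(dataset, logger)
#   documents_pred = writer.write(arguments_sets, Path('result/knp'), skip_untagged=False)
#   # passing destination=None returns the analyzed documents without writing anything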
def _evaluate_coref(self, doc_id: str, document_pred: Document, document_gold: Document):
    dtid2mention_pred: Dict[int, Mention] = {bp.dtid: document_pred.mentions[bp.dtid]
                                             for bp in self.did2mentions_pred[doc_id]
                                             if bp.dtid in document_pred.mentions}
    dtid2mention_gold: Dict[int, Mention] = {bp.dtid: document_gold.mentions[bp.dtid]
                                             for bp in self.did2mentions_gold[doc_id]
                                             if bp.dtid in document_gold.mentions}
    for dtid in range(len(document_pred.bp_list())):
        if dtid in dtid2mention_pred:
            src_mention_pred = dtid2mention_pred[dtid]
            tgt_mentions_pred = \
                self._filter_mentions(document_pred.get_siblings(src_mention_pred), src_mention_pred)
            exophors_pred = [e.exophor for e in map(document_pred.entities.get, src_mention_pred.eids)
                             if e.is_special]
        else:
            tgt_mentions_pred = exophors_pred = []

        if dtid in dtid2mention_gold:
            src_mention_gold = dtid2mention_gold[dtid]
            tgt_mentions_gold = \
                self._filter_mentions(document_gold.get_siblings(src_mention_gold, relax=False),
                                      src_mention_gold)
            tgt_mentions_gold_relaxed = \
                self._filter_mentions(document_gold.get_siblings(src_mention_gold, relax=True),
                                      src_mention_gold)
            exophors_gold = [e.exophor for e in map(document_gold.entities.get, src_mention_gold.eids)
                             if e.is_special and e.exophor in self.relax_exophors.values()]
            exophors_gold_relaxed = [e.exophor for e in map(document_gold.entities.get, src_mention_gold.all_eids)
                                     if e.is_special and e.exophor in self.relax_exophors.values()]
        else:
            tgt_mentions_gold = tgt_mentions_gold_relaxed = exophors_gold = exophors_gold_relaxed = []

        key = (doc_id, dtid, '=')

        # calculate precision
        if tgt_mentions_pred or exophors_pred:
            if (set(tgt_mentions_pred) & set(tgt_mentions_gold_relaxed)) \
                    or (set(exophors_pred) & set(exophors_gold_relaxed)):
                self.comp_result[key] = 'correct'
                self.measure_coref.correct += 1
            else:
                self.comp_result[key] = 'wrong'
            self.measure_coref.denom_pred += 1

        # calculate recall
        if tgt_mentions_gold or exophors_gold or (self.comp_result.get(key, None) == 'correct'):
            if (set(tgt_mentions_pred) & set(tgt_mentions_gold_relaxed)) \
                    or (set(exophors_pred) & set(exophors_gold_relaxed)):
                assert self.comp_result[key] == 'correct'
            else:
                self.comp_result[key] = 'wrong'
            self.measure_coref.denom_gold += 1
def _evaluate_bridging(self, doc_id: str, document_pred: Document, document_gold: Document):
    dtid2anaphor_pred: Dict[int, Predicate] = {pred.dtid: pred for pred in self.did2bridgings_pred[doc_id]}
    dtid2anaphor_gold: Dict[int, Predicate] = {pred.dtid: pred for pred in self.did2bridgings_gold[doc_id]}

    for dtid in range(len(document_pred.bp_list())):
        if dtid in dtid2anaphor_pred:
            anaphor_pred = dtid2anaphor_pred[dtid]
            antecedents_pred: List[BaseArgument] = \
                self._filter_args(document_pred.get_arguments(anaphor_pred, relax=False)['ノ'], anaphor_pred)
        else:
            antecedents_pred = []
        # in bert_pas_analysis, exactly one argument is predicted per anaphor
        assert len(antecedents_pred) in (0, 1)

        if dtid in dtid2anaphor_gold:
            anaphor_gold: Predicate = dtid2anaphor_gold[dtid]
            antecedents_gold: List[BaseArgument] = \
                self._filter_args(document_gold.get_arguments(anaphor_gold, relax=False)['ノ'], anaphor_gold)
            antecedents_gold_relaxed: List[BaseArgument] = \
                self._filter_args(document_gold.get_arguments(anaphor_gold, relax=True)['ノ']
                                  + document_gold.get_arguments(anaphor_gold, relax=True)['ノ?'],
                                  anaphor_gold)
        else:
            antecedents_gold = antecedents_gold_relaxed = []

        key = (doc_id, dtid, 'ノ')

        # calculate precision
        if antecedents_pred:
            antecedent_pred = antecedents_pred[0]
            analysis = Scorer.DEPTYPE2ANALYSIS[antecedent_pred.dep_type]
            if antecedent_pred in antecedents_gold_relaxed:
                self.comp_result[key] = analysis
                self.measures_bridging[analysis].correct += 1
            else:
                self.comp_result[key] = 'wrong'
            self.measures_bridging[analysis].denom_pred += 1

        # calculate recall
        if antecedents_gold or (self.comp_result.get(key, None) in Scorer.DEPTYPE2ANALYSIS.values()):
            antecedent_gold = None
            for ant in antecedents_gold_relaxed:
                if ant in antecedents_pred:
                    antecedent_gold = ant  # prefer the antecedent that was actually predicted
                    break
            if antecedent_gold is not None:
                analysis = Scorer.DEPTYPE2ANALYSIS[antecedent_gold.dep_type]
                assert self.comp_result[key] == analysis
            else:
                analysis = Scorer.DEPTYPE2ANALYSIS[antecedents_gold[0].dep_type]
                if antecedents_pred:
                    assert self.comp_result[key] == 'wrong'
                else:
                    self.comp_result[key] = 'wrong'
            self.measures_bridging[analysis].denom_gold += 1
def _evaluate_pas(self, doc_id: str, document_pred: Document, document_gold: Document):
    """calculate PAS analysis scores"""
    dtid2predicate_pred: Dict[int, Predicate] = {pred.dtid: pred for pred in self.did2predicates_pred[doc_id]}
    dtid2predicate_gold: Dict[int, Predicate] = {pred.dtid: pred for pred in self.did2predicates_gold[doc_id]}

    for dtid in range(len(document_pred.bp_list())):
        if dtid in dtid2predicate_pred:
            predicate_pred = dtid2predicate_pred[dtid]
            arguments_pred = document_pred.get_arguments(predicate_pred, relax=False)
        else:
            arguments_pred = None

        if dtid in dtid2predicate_gold:
            predicate_gold = dtid2predicate_gold[dtid]
            arguments_gold = document_gold.get_arguments(predicate_gold, relax=False)
            arguments_gold_relaxed = document_gold.get_arguments(predicate_gold, relax=True)
        else:
            predicate_gold = arguments_gold = arguments_gold_relaxed = None

        for case in self.cases:
            args_pred: List[BaseArgument] = arguments_pred[case] if arguments_pred is not None else []
            # in bert_pas_analysis, exactly one argument is predicted per predicate
            assert len(args_pred) in (0, 1)

            if predicate_gold is not None:
                args_gold = self._filter_args(arguments_gold[case], predicate_gold)
                args_gold_relaxed = self._filter_args(
                    arguments_gold_relaxed[case] + (arguments_gold_relaxed['判ガ'] if case == 'ガ' else []),
                    predicate_gold)
            else:
                args_gold = args_gold_relaxed = []

            key = (doc_id, dtid, case)

            # calculate precision
            if args_pred:
                arg = args_pred[0]
                analysis = Scorer.DEPTYPE2ANALYSIS[arg.dep_type]
                if arg in args_gold_relaxed:
                    self.comp_result[key] = analysis
                    self.measures[case][analysis].correct += 1
                else:
                    self.comp_result[key] = 'wrong'  # lowers precision
                self.measures[case][analysis].denom_pred += 1

            # calculate recall
            # If there are multiple gold arguments and one of them was predicted, adopt it as
            # the gold answer; otherwise, pick one of the non-relaxed gold arguments.
            if args_gold or (self.comp_result.get(key, None) in Scorer.DEPTYPE2ANALYSIS.values()):
                arg_gold = None
                for arg in args_gold_relaxed:
                    if arg in args_pred:
                        arg_gold = arg  # prefer the argument that was actually predicted
                        break
                if arg_gold is not None:
                    analysis = Scorer.DEPTYPE2ANALYSIS[arg_gold.dep_type]
                    assert self.comp_result[key] == analysis
                else:
                    analysis = Scorer.DEPTYPE2ANALYSIS[args_gold[0].dep_type]
                    if args_pred:
                        assert self.comp_result[key] == 'wrong'
                    else:
                        self.comp_result[key] = 'wrong'  # lowers recall
                self.measures[case][analysis].denom_gold += 1
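# The three counters updated above (correct, denom_pred, denom_gold) are enough to recover
# precision, recall, and F1. A minimal sketch of that arithmetic, assuming a Measure class
# along these lines (the actual class in this project may differ):
class Measure:
    def __init__(self, denom_pred: int = 0, denom_gold: int = 0, correct: int = 0):
        self.denom_pred = denom_pred  # number of predicted arguments
        self.denom_gold = denom_gold  # number of gold arguments
        self.correct = correct        # number of correctly predicted arguments

    @property
    def precision(self) -> float:
        return self.correct / self.denom_pred if self.denom_pred else 0.0

    @property
    def recall(self) -> float:
        return self.correct / self.denom_gold if self.denom_gold else 0.0

    @property
    def f1(self) -> float:
        # 2PR / (P + R) simplifies to 2 * correct / (denom_pred + denom_gold)
        denom = self.denom_pred + self.denom_gold
        return 2 * self.correct / denom if denom else 0.0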