def _rewrite_rel(
    self,
    knp_lines: List[str],
    example: PasExample,
    arguments_set: List[List[int]],  # (max_seq_len, cases)
    document: Document,  # annotated with case-analysis (<格解析>) results
) -> List[str]:
    """Insert rel tags into the base-phrase ('+ ') lines of KNP-format text.

    Lines that do not start with ``'+ '`` are copied through unchanged.
    For every ``'+ '`` line that matches ``self.TAG_PAT``, the string built
    by ``self._rel_string`` for the corresponding base phrase is spliced in
    right after the matched prefix. A ``'+ '`` line that does not match is
    logged as a warning and copied unchanged; it still consumes one base
    phrase index so later lines stay aligned.

    Args:
        knp_lines: Input lines; they must not already contain ``'<rel '``
            on ``'+ '`` lines.
        example: Forwarded to ``_rel_string``.
        arguments_set: Forwarded to ``_rel_string``.
        document: Source of the base phrases (``document.bp_list()``) and
            of the overt information extracted by ``self._extract_overt``.

    Returns:
        A new list of lines with the rel strings inserted.

    Raises:
        AssertionError: If a ``'+ '`` line already contains ``'<rel '``.
    """
    overts = self._extract_overt(document)
    # Fetch the base phrases once rather than on every '+ ' line.
    # NOTE(review): assumes bp_list() is a side-effect-free accessor — confirm.
    base_phrases = document.bp_list()

    output_knp_lines = []
    dtid = 0  # running index of '+ ' lines; indexes both base_phrases and overts
    for line in knp_lines:
        if not line.startswith('+ '):
            output_knp_lines.append(line)
            continue

        assert '<rel ' not in line
        match = self.TAG_PAT.match(line)
        if match is not None:
            rel_string = self._rel_string(base_phrases[dtid], example, arguments_set, document, overts[dtid])
            rel_idx = match.end()
            output_knp_lines.append(line[:rel_idx] + rel_string + line[rel_idx:])
        else:
            self.logger.warning(f'invalid format line: {line}')
            output_knp_lines.append(line)

        dtid += 1

    return output_knp_lines
def _rel_string(
    self,
    bp: BasePhrase,
    example: PasExample,
    arguments_set: List[List[int]],  # (max_seq_len, cases)
    document: Document,
    overt_dict: Dict[str, int],
) -> str:
    """Build the concatenated rel-tag string for one base phrase.

    For every morpheme of ``bp`` and every relation in ``self.relations``
    that is a target for that morpheme, one ``RelTag`` is produced from
    either the overt entry, a special (exophor) index, or the predicted
    token index. The tags are serialized with ``to_string()`` and joined.

    Args:
        bp: Base phrase whose morphemes are processed.
        example: Provides ``arguments_set``, ``orig_to_tok_index`` and
            ``tok_to_orig_index``.
        arguments_set: Predicted argument index per token and relation.
        document: Provides ``bp_list()`` and ``mrph2dmid``.
        overt_dict: Relation -> dmid; used instead of the prediction when
            ``self.use_knp_overt`` is truthy.

    Returns:
        The joined rel-tag string (empty when no tag was produced).
    """
    # Map each document-wide morpheme id to the base phrase containing it.
    dmid2bp = {}
    for phrase in document.bp_list():
        for morpheme in phrase.mrph_list():
            dmid2bp[document.mrph2dmid[morpheme]] = phrase
    assert len(example.arguments_set) == len(dmid2bp)

    tags: List[RelTag] = []
    for morpheme in bp.mrph_list():
        dmid: int = document.mrph2dmid[morpheme]
        tok_idx: int = example.orig_to_tok_index[dmid]
        predicted: List[int] = arguments_set[tok_idx]
        # For morphemes that are not analysis targets (e.g. particles) the gold
        # arguments are empty; at inference time target morphemes hold ['NULL'].
        targeted: Dict[str, bool] = {
            name: bool(gold) for name, gold in example.arguments_set[dmid].items()
        }
        assert len(self.relations) == len(predicted)

        for relation, argument in zip(self.relations, predicted):
            if not targeted[relation]:
                continue

            use_overt = self.use_knp_overt and relation in overt_dict
            if not use_overt and argument in self.index_to_special:
                # Special index: only exophors yield a tag ([NULL] and [NA] are dropped).
                special = self.index_to_special[argument]
                if special in self.exophors:
                    tags.append(RelTag(relation, special, None, None))
                continue

            if use_overt:
                target_dmid = overt_dict[relation]
            else:
                target_dmid = example.tok_to_orig_index[argument]
                if target_dmid is None:
                    # The prediction points at [SEP] or [CLS].
                    self.logger.warning(
                        "Choose [SEP] as an argument. Tentatively, change it to NULL."
                    )
                    continue

            target_bp: BasePhrase = dmid2bp[target_dmid]
            tags.append(RelTag(relation, target_bp.core, target_bp.sid, target_bp.tid))

    return ''.join([tag.to_string() for tag in tags])
def __init__(self,
             document_pred: Document,
             document_gold: Document,
             cases: List[str],
             bridging: bool,
             coreference: bool,
             relax_exophors: Dict[str, str],
             pas_target: str):
    """Store the settings and collect the target base phrases of both documents.

    Args:
        document_pred: Predicted document; must share ``doc_id`` with the gold one.
        document_gold: Gold document.
        cases: Stored as ``self.cases``.
        bridging: When truthy, bridging targets are collected.
        coreference: When truthy, coreference targets are collected.
        relax_exophors: Stored as ``self.relax_exophors``.
        pas_target: ``'pred'``/``'all'`` enable verbal targets and
            ``'noun'``/``'all'`` enable nominal targets; ``self.pas`` is
            False only for the empty string.

    Raises:
        AssertionError: If the two documents have different ``doc_id``.
    """
    assert document_pred.doc_id == document_gold.doc_id
    self.doc_id: str = document_gold.doc_id
    self.document_pred: Document = document_pred
    self.document_gold: Document = document_gold
    self.cases: List[str] = cases
    self.pas: bool = pas_target != ''
    self.bridging: bool = bridging
    self.coreference: bool = coreference
    self.comp_result: Dict[tuple, str] = {}
    self.relax_exophors: Dict[str, str] = relax_exophors

    want_verbal = pas_target in ('pred', 'all')
    want_nominal = pas_target in ('noun', 'all')

    def collect(document, predicates, bridgings, mentions):
        # Sort every base phrase of `document` into the three target lists.
        for phrase in document.bp_list():
            if is_pas_target(phrase, verbal=want_verbal, nominal=want_nominal):
                predicates.append(phrase)
            if self.bridging and is_bridging_target(phrase):
                bridgings.append(phrase)
            if self.coreference and is_coreference_target(phrase):
                mentions.append(phrase)

    self.predicates_pred: List[BasePhrase] = []
    self.bridgings_pred: List[BasePhrase] = []
    self.mentions_pred: List[BasePhrase] = []
    collect(document_pred, self.predicates_pred, self.bridgings_pred, self.mentions_pred)

    self.predicates_gold: List[BasePhrase] = []
    self.bridgings_gold: List[BasePhrase] = []
    self.mentions_gold: List[BasePhrase] = []
    collect(document_gold, self.predicates_gold, self.bridgings_gold, self.mentions_gold)
def _evaluate_coref(self, doc_id: str, document_pred: Document, document_gold: Document):
    """Update the coreference counters and per-phrase results for one document.

    For every base-phrase index of ``document_pred`` the predicted and gold
    coreferring mentions and exophors are compared. The verdict
    (``'correct'`` / ``'wrong'``) is written to
    ``self.comp_result[(doc_id, dtid, '=')]`` and ``self.measure_coref``
    (``correct``, ``denom_pred``, ``denom_gold``) is incremented.

    Args:
        doc_id: Key into ``self.did2mentions_pred`` / ``self.did2mentions_gold``.
        document_pred: Predicted document.
        document_gold: Gold document.
    """

    def overlaps(pred_mentions, gold_mentions, pred_exophors, gold_exophors) -> bool:
        # True when prediction and gold share at least one mention or one exophor.
        return bool(set(pred_mentions) & set(gold_mentions)) \
            or bool(set(pred_exophors) & set(gold_exophors))

    pred_by_dtid: Dict[int, Mention] = {}
    for bp in self.did2mentions_pred[doc_id]:
        if bp.dtid in document_pred.mentions:
            pred_by_dtid[bp.dtid] = document_pred.mentions[bp.dtid]

    gold_by_dtid: Dict[int, Mention] = {}
    for bp in self.did2mentions_gold[doc_id]:
        if bp.dtid in document_gold.mentions:
            gold_by_dtid[bp.dtid] = document_gold.mentions[bp.dtid]

    for dtid in range(len(document_pred.bp_list())):
        # Prediction side: empty lists when this phrase has no predicted mention.
        targets_pred, exos_pred = [], []
        if dtid in pred_by_dtid:
            source_pred = pred_by_dtid[dtid]
            targets_pred = self._filter_mentions(document_pred.get_siblings(source_pred), source_pred)
            exos_pred = [
                entity.exophor
                for entity in map(document_pred.entities.get, source_pred.eids)
                if entity.is_special
            ]

        # Gold side: strict and relaxed variants of both mentions and exophors.
        targets_gold, targets_gold_relaxed, exos_gold, exos_gold_relaxed = [], [], [], []
        if dtid in gold_by_dtid:
            source_gold = gold_by_dtid[dtid]
            targets_gold = self._filter_mentions(
                document_gold.get_siblings(source_gold, relax=False), source_gold)
            targets_gold_relaxed = self._filter_mentions(
                document_gold.get_siblings(source_gold, relax=True), source_gold)
            exos_gold = [
                entity.exophor
                for entity in map(document_gold.entities.get, source_gold.eids)
                if entity.is_special and entity.exophor in self.relax_exophors.values()
            ]
            exos_gold_relaxed = [
                entity.exophor
                for entity in map(document_gold.entities.get, source_gold.all_eids)
                if entity.is_special and entity.exophor in self.relax_exophors.values()
            ]

        key = (doc_id, dtid, '=')

        # Precision: counted whenever something was predicted for this phrase.
        if targets_pred or exos_pred:
            if overlaps(targets_pred, targets_gold_relaxed, exos_pred, exos_gold_relaxed):
                self.comp_result[key] = 'correct'
                self.measure_coref.correct += 1
            else:
                self.comp_result[key] = 'wrong'
            self.measure_coref.denom_pred += 1

        # Recall: counted when a strict gold answer exists or the prediction was judged correct.
        if not (targets_gold or exos_gold or self.comp_result.get(key) == 'correct'):
            continue
        if overlaps(targets_pred, targets_gold_relaxed, exos_pred, exos_gold_relaxed):
            assert self.comp_result[key] == 'correct'
        else:
            self.comp_result[key] = 'wrong'
        self.measure_coref.denom_gold += 1
def _evaluate_bridging(self, doc_id: str, document_pred: Document, document_gold: Document):
    """Update the bridging counters and per-phrase results for one document.

    For every base-phrase index of ``document_pred`` the predicted 'ノ'
    antecedent is compared with the gold antecedents. The verdict (an
    analysis label from ``Scorer.DEPTYPE2ANALYSIS`` or ``'wrong'``) is
    written to ``self.comp_result[(doc_id, dtid, 'ノ')]`` and the matching
    entry of ``self.measures_bridging`` is incremented.

    Args:
        doc_id: Key into ``self.did2bridgings_pred`` / ``self.did2bridgings_gold``.
        document_pred: Predicted document.
        document_gold: Gold document.

    Raises:
        AssertionError: If more than one antecedent is predicted for a phrase.
    """
    pred_by_dtid: Dict[int, Predicate] = {}
    for anaphor in self.did2bridgings_pred[doc_id]:
        pred_by_dtid[anaphor.dtid] = anaphor
    gold_by_dtid: Dict[int, Predicate] = {}
    for anaphor in self.did2bridgings_gold[doc_id]:
        gold_by_dtid[anaphor.dtid] = anaphor

    for dtid in range(len(document_pred.bp_list())):
        antecedents_pred: List[BaseArgument] = []
        if dtid in pred_by_dtid:
            anaphor_pred = pred_by_dtid[dtid]
            antecedents_pred = self._filter_args(
                document_pred.get_arguments(anaphor_pred, relax=False)['ノ'], anaphor_pred)
        # in bert_pas_analysis, predict one argument for one predicate
        assert len(antecedents_pred) in (0, 1)

        antecedents_gold: List[BaseArgument] = []
        antecedents_gold_relaxed: List[BaseArgument] = []
        if dtid in gold_by_dtid:
            anaphor_gold: Predicate = gold_by_dtid[dtid]
            antecedents_gold = self._filter_args(
                document_gold.get_arguments(anaphor_gold, relax=False)['ノ'], anaphor_gold)
            # The relaxed pool is the relaxed 'ノ' arguments followed by the 'ノ?' ones.
            relaxed_pool = document_gold.get_arguments(anaphor_gold, relax=True)['ノ'] \
                + document_gold.get_arguments(anaphor_gold, relax=True)['ノ?']
            antecedents_gold_relaxed = self._filter_args(relaxed_pool, anaphor_gold)

        key = (doc_id, dtid, 'ノ')

        # Precision: counted whenever an antecedent was predicted.
        if antecedents_pred:
            predicted = antecedents_pred[0]
            analysis = Scorer.DEPTYPE2ANALYSIS[predicted.dep_type]
            hit = predicted in antecedents_gold_relaxed
            self.comp_result[key] = analysis if hit else 'wrong'
            if hit:
                self.measures_bridging[analysis].correct += 1
            self.measures_bridging[analysis].denom_pred += 1

        # Recall: counted when a strict gold antecedent exists or the prediction was judged correct.
        if not (antecedents_gold
                or self.comp_result.get(key) in Scorer.DEPTYPE2ANALYSIS.values()):
            continue

        # Prefer a gold antecedent that was actually predicted.
        matched = next(
            (candidate for candidate in antecedents_gold_relaxed if candidate in antecedents_pred),
            None,
        )
        if matched is not None:
            analysis = Scorer.DEPTYPE2ANALYSIS[matched.dep_type]
            assert self.comp_result[key] == analysis
        else:
            # Nothing matched: fall back to the first strict gold antecedent.
            analysis = Scorer.DEPTYPE2ANALYSIS[antecedents_gold[0].dep_type]
            if antecedents_pred:
                assert self.comp_result[key] == 'wrong'
            else:
                self.comp_result[key] = 'wrong'
        self.measures_bridging[analysis].denom_gold += 1
def _evaluate_pas(self, doc_id: str, document_pred: Document, document_gold: Document):
    """Calculate PAS analysis scores for one document.

    For every base-phrase index of ``document_pred`` and every case in
    ``self.cases`` the predicted argument is compared with the gold
    arguments. The verdict (an analysis label from
    ``Scorer.DEPTYPE2ANALYSIS`` or ``'wrong'``) is written to
    ``self.comp_result[(doc_id, dtid, case)]`` and the matching entry of
    ``self.measures[case]`` is incremented.

    Args:
        doc_id: Key into ``self.did2predicates_pred`` / ``self.did2predicates_gold``.
        document_pred: Predicted document.
        document_gold: Gold document.

    Raises:
        AssertionError: If more than one argument is predicted for a case.
    """
    pred_by_dtid: Dict[int, Predicate] = {}
    for predicate in self.did2predicates_pred[doc_id]:
        pred_by_dtid[predicate.dtid] = predicate
    gold_by_dtid: Dict[int, Predicate] = {}
    for predicate in self.did2predicates_gold[doc_id]:
        gold_by_dtid[predicate.dtid] = predicate

    for dtid in range(len(document_pred.bp_list())):
        arguments_pred = None
        if dtid in pred_by_dtid:
            arguments_pred = document_pred.get_arguments(pred_by_dtid[dtid], relax=False)

        predicate_gold = arguments_gold = arguments_gold_relaxed = None
        if dtid in gold_by_dtid:
            predicate_gold = gold_by_dtid[dtid]
            arguments_gold = document_gold.get_arguments(predicate_gold, relax=False)
            arguments_gold_relaxed = document_gold.get_arguments(predicate_gold, relax=True)

        for case in self.cases:
            if arguments_pred is not None:
                args_pred: List[BaseArgument] = arguments_pred[case]
            else:
                args_pred = []
            # in bert_pas_analysis, predict one argument for one predicate
            assert len(args_pred) in (0, 1)

            args_gold, args_gold_relaxed = [], []
            if predicate_gold is not None:
                args_gold = self._filter_args(arguments_gold[case], predicate_gold)
                # For the 'ガ' case the '判ガ' arguments are accepted as well.
                relaxed_pool = arguments_gold_relaxed[case] \
                    + (arguments_gold_relaxed['判ガ'] if case == 'ガ' else [])
                args_gold_relaxed = self._filter_args(relaxed_pool, predicate_gold)

            key = (doc_id, dtid, case)

            # Precision: counted whenever an argument was predicted.
            if args_pred:
                predicted = args_pred[0]
                analysis = Scorer.DEPTYPE2ANALYSIS[predicted.dep_type]
                hit = predicted in args_gold_relaxed
                # A miss lowers precision.
                self.comp_result[key] = analysis if hit else 'wrong'
                if hit:
                    self.measures[case][analysis].correct += 1
                self.measures[case][analysis].denom_pred += 1

            # Recall: when several gold arguments exist and one of them was
            # predicted, that one is adopted as the gold answer; if none was
            # predicted, one of the non-relaxed arguments is adopted instead.
            if not (args_gold
                    or self.comp_result.get(key) in Scorer.DEPTYPE2ANALYSIS.values()):
                continue

            # Prefer a gold argument that was actually predicted.
            matched = next(
                (candidate for candidate in args_gold_relaxed if candidate in args_pred),
                None,
            )
            if matched is not None:
                analysis = Scorer.DEPTYPE2ANALYSIS[matched.dep_type]
                assert self.comp_result[key] == analysis
            else:
                analysis = Scorer.DEPTYPE2ANALYSIS[args_gold[0].dep_type]
                if args_pred:
                    assert self.comp_result[key] == 'wrong'
                else:
                    self.comp_result[key] = 'wrong'  # lowers recall
            self.measures[case][analysis].denom_gold += 1