コード例 #1
0
    def _rewrite_rel(
            self,
            knp_lines: List[str],
            example: PasExample,
            arguments_set: List[List[int]],  # (max_seq_len, cases)
            document: Document,  # with <格解析> annotations
    ) -> List[str]:
        """Insert predicted rel strings into the ``'+ '`` lines of a KNP dump.

        Lines that do not start with ``'+ '`` are copied through unchanged.
        The i-th ``'+ '`` line is paired with the i-th element of
        ``document.bp_list()`` and of the list returned by
        ``self._extract_overt(document)``; the string built by
        ``self._rel_string`` is spliced in right after the prefix matched by
        ``self.TAG_PAT``. A ``'+ '`` line that ``self.TAG_PAT`` does not match
        is logged as a warning and copied through unchanged (it still
        consumes one base-phrase index).

        Args:
            knp_lines: Lines of the KNP output to rewrite.
            example: Passed through to ``self._rel_string``.
            arguments_set: Passed through to ``self._rel_string``.
            document: Document supplying the base phrases for the ``'+ '`` lines.

        Returns:
            A new list of lines; ``knp_lines`` itself is not modified.

        Raises:
            AssertionError: If a ``'+ '`` line already contains ``'<rel '``.
        """
        overts = self._extract_overt(document)
        # bp_list() does not depend on the loop; fetch it once instead of
        # once per '+ ' line.
        base_phrases = document.bp_list()

        output_knp_lines = []
        dtid = 0  # running index over the '+ ' lines seen so far
        for line in knp_lines:
            if not line.startswith('+ '):
                output_knp_lines.append(line)
                continue

            assert '<rel ' not in line
            match = self.TAG_PAT.match(line)
            if match is not None:
                rel_string = self._rel_string(base_phrases[dtid],
                                              example, arguments_set, document,
                                              overts[dtid])
                rel_idx = match.end()
                output_knp_lines.append(line[:rel_idx] + rel_string +
                                        line[rel_idx:])
            else:
                self.logger.warning(f'invalid format line: {line}')
                output_knp_lines.append(line)

            dtid += 1

        return output_knp_lines
コード例 #2
0
    def _rel_string(
        self,
        bp: BasePhrase,
        example: PasExample,
        arguments_set: List[List[int]],  # (max_seq_len, cases)
        document: Document,
        overt_dict: Dict[str, int],
    ) -> str:
        """Build the concatenated rel-tag string for one base phrase.

        For every morpheme of ``bp`` and every relation in ``self.relations``
        that is a target for that morpheme, one of the following happens:

        * if ``self.use_knp_overt`` is set and ``overt_dict`` has the relation,
          the dmid from ``overt_dict`` is used as the argument;
        * else if the predicted index is in ``self.index_to_special``, a tag
          is emitted only when the special token is in ``self.exophors``;
        * otherwise the predicted token index is mapped back through
          ``example.tok_to_orig_index``; a ``None`` result is logged and
          skipped.

        Args:
            bp: Base phrase whose morphemes are processed.
            example: Supplies ``arguments_set``, ``orig_to_tok_index`` and
                ``tok_to_orig_index``.
            arguments_set: Predicted argument indices, indexed by token index
                and then aligned with ``self.relations``.
            document: Supplies ``mrph2dmid`` and ``bp_list()``.
            overt_dict: Relation name -> dmid, consulted when
                ``self.use_knp_overt`` is truthy.

        Returns:
            The ``to_string()`` outputs of all collected tags, joined with no
            separator.
        """
        # Index every morpheme's dmid to the base phrase that contains it.
        phrase_of_dmid = {}
        for phrase in document.bp_list():
            for morpheme in phrase.mrph_list():
                phrase_of_dmid[document.mrph2dmid[morpheme]] = phrase
        assert len(example.arguments_set) == len(phrase_of_dmid)

        tags: List[RelTag] = []
        for morpheme in bp.mrph_list():
            dmid: int = document.mrph2dmid[morpheme]
            tok_idx: int = example.orig_to_tok_index[dmid]
            predicted: List[int] = arguments_set[tok_idx]
            # For morphemes that are not analysis targets (e.g. particles),
            # gold_args is empty.
            # At inference time, target morphemes have ['NULL'].
            targeted: Dict[str, bool] = {
                name: bool(gold)
                for name, gold in example.arguments_set[dmid].items()
            }
            assert len(self.relations) == len(predicted)
            for relation, argument in zip(self.relations, predicted):
                if not targeted[relation]:
                    continue

                use_overt = self.use_knp_overt and relation in overt_dict
                if not use_overt and argument in self.index_to_special:
                    # special token: only exophors produce a tag
                    # ([NULL] and [NA] are excluded)
                    special = self.index_to_special[argument]
                    if special in self.exophors:
                        tags.append(RelTag(relation, special, None, None))
                    continue

                if use_overt:
                    antecedent_dmid = overt_dict[relation]
                else:
                    # ordinary token: map it back to a morpheme id
                    antecedent_dmid = example.tok_to_orig_index[argument]
                    if antecedent_dmid is None:
                        # [SEP] or [CLS]
                        self.logger.warning(
                            "Choose [SEP] as an argument. Tentatively, change it to NULL."
                        )
                        continue

                antecedent: BasePhrase = phrase_of_dmid[antecedent_dmid]
                tags.append(
                    RelTag(relation, antecedent.core, antecedent.sid,
                           antecedent.tid))

        return ''.join([tag.to_string() for tag in tags])
コード例 #3
0
ファイル: scorer.py プロジェクト: nobu-g/cohesion-analysis
    def __init__(self, document_pred: Document, document_gold: Document,
                 cases: List[str], bridging: bool, coreference: bool,
                 relax_exophors: Dict[str, str], pas_target: str):
        """Store the evaluation settings and collect target base phrases.

        Args:
            document_pred: System output document.
            document_gold: Gold document; must have the same ``doc_id`` as
                ``document_pred``.
            cases: Stored as ``self.cases``.
            bridging: Whether bridging targets are collected.
            coreference: Whether coreference targets are collected.
            relax_exophors: Stored as ``self.relax_exophors``.
            pas_target: ``''`` sets ``self.pas`` to False. ``'pred'`` or
                ``'all'`` passes ``verbal=True`` to ``is_pas_target``;
                ``'noun'`` or ``'all'`` passes ``nominal=True``.

        Raises:
            AssertionError: If the two documents have different ``doc_id``.
        """
        assert document_pred.doc_id == document_gold.doc_id
        self.doc_id: str = document_gold.doc_id
        self.document_pred: Document = document_pred
        self.document_gold: Document = document_gold
        self.cases: List[str] = cases
        self.pas: bool = pas_target != ''
        self.bridging: bool = bridging
        self.coreference: bool = coreference
        self.comp_result: Dict[tuple, str] = {}
        self.relax_exophors: Dict[str, str] = relax_exophors

        # These flags do not depend on the base phrase; compute them once.
        verbal = pas_target in ('pred', 'all')
        nominal = pas_target in ('noun', 'all')

        def collect_targets(document):
            """Return (predicates, bridging anaphors, mentions) of *document*."""
            predicates, bridgings, mentions = [], [], []
            for bp in document.bp_list():
                if is_pas_target(bp, verbal=verbal, nominal=nominal):
                    predicates.append(bp)
                if self.bridging and is_bridging_target(bp):
                    bridgings.append(bp)
                if self.coreference and is_coreference_target(bp):
                    mentions.append(bp)
            return predicates, bridgings, mentions

        # The same selection is applied to both documents.
        predicates, bridgings, mentions = collect_targets(document_pred)
        self.predicates_pred: List[BasePhrase] = predicates
        self.bridgings_pred: List[BasePhrase] = bridgings
        self.mentions_pred: List[BasePhrase] = mentions

        predicates, bridgings, mentions = collect_targets(document_gold)
        self.predicates_gold: List[BasePhrase] = predicates
        self.bridgings_gold: List[BasePhrase] = bridgings
        self.mentions_gold: List[BasePhrase] = mentions
コード例 #4
0
    def _evaluate_coref(self, doc_id: str, document_pred: Document,
                        document_gold: Document):
        """Score coreference for one document.

        For every base-phrase index of ``document_pred``, the predicted
        target mentions and exophors are compared with the relaxed gold ones.
        The outcome (``'correct'`` or ``'wrong'``) is written to
        ``self.comp_result`` under ``(doc_id, dtid, '=')`` and the counters
        of ``self.measure_coref`` are updated.

        Args:
            doc_id: Key into ``self.did2mentions_pred`` /
                ``self.did2mentions_gold`` and first element of the result key.
            document_pred: System output document.
            document_gold: Gold document.
        """

        def has_overlap(mentions_sys, mentions_ref, exophors_sys, exophors_ref):
            """True when either the mentions or the exophors share an element."""
            if set(mentions_sys) & set(mentions_ref):
                return True
            return bool(set(exophors_sys) & set(exophors_ref))

        # dtid -> mention, restricted to targets that actually are mentions
        pred_mention_of: Dict[int, Mention] = {}
        for bp in self.did2mentions_pred[doc_id]:
            if bp.dtid in document_pred.mentions:
                pred_mention_of[bp.dtid] = document_pred.mentions[bp.dtid]
        gold_mention_of: Dict[int, Mention] = {}
        for bp in self.did2mentions_gold[doc_id]:
            if bp.dtid in document_gold.mentions:
                gold_mention_of[bp.dtid] = document_gold.mentions[bp.dtid]

        for dtid in range(len(document_pred.bp_list())):
            # --- system side ---
            if dtid not in pred_mention_of:
                tgt_mentions_pred = []
                exophors_pred = []
            else:
                src_pred = pred_mention_of[dtid]
                tgt_mentions_pred = self._filter_mentions(
                    document_pred.get_siblings(src_pred), src_pred)
                exophors_pred = []
                for eid in src_pred.eids:
                    entity = document_pred.entities.get(eid)
                    if entity.is_special:
                        exophors_pred.append(entity.exophor)

            # --- gold side (strict and relaxed) ---
            if dtid not in gold_mention_of:
                tgt_mentions_gold = []
                tgt_mentions_gold_relaxed = []
                exophors_gold = []
                exophors_gold_relaxed = []
            else:
                src_gold = gold_mention_of[dtid]
                tgt_mentions_gold = self._filter_mentions(
                    document_gold.get_siblings(src_gold, relax=False),
                    src_gold)
                tgt_mentions_gold_relaxed = self._filter_mentions(
                    document_gold.get_siblings(src_gold, relax=True),
                    src_gold)
                # only exophors listed among the relax_exophors values count
                exophors_gold = []
                for eid in src_gold.eids:
                    entity = document_gold.entities.get(eid)
                    if entity.is_special and entity.exophor in self.relax_exophors.values():
                        exophors_gold.append(entity.exophor)
                exophors_gold_relaxed = []
                for eid in src_gold.all_eids:
                    entity = document_gold.entities.get(eid)
                    if entity.is_special and entity.exophor in self.relax_exophors.values():
                        exophors_gold_relaxed.append(entity.exophor)

            key = (doc_id, dtid, '=')

            # precision: counted whenever the system predicted something
            if tgt_mentions_pred or exophors_pred:
                verdict = 'wrong'
                if has_overlap(tgt_mentions_pred, tgt_mentions_gold_relaxed,
                               exophors_pred, exophors_gold_relaxed):
                    verdict = 'correct'
                self.comp_result[key] = verdict
                if verdict == 'correct':
                    self.measure_coref.correct += 1
                self.measure_coref.denom_pred += 1

            # recall: counted when there is a strict gold target, or when the
            # prediction was already judged correct against the relaxed gold
            already_correct = self.comp_result.get(key) == 'correct'
            if tgt_mentions_gold or exophors_gold or already_correct:
                if has_overlap(tgt_mentions_pred, tgt_mentions_gold_relaxed,
                               exophors_pred, exophors_gold_relaxed):
                    assert self.comp_result[key] == 'correct'
                else:
                    self.comp_result[key] = 'wrong'
                self.measure_coref.denom_gold += 1
コード例 #5
0
    def _evaluate_bridging(self, doc_id: str, document_pred: Document,
                           document_gold: Document):
        """Score bridging anaphora (the 'ノ' relation) for one document.

        For every base-phrase index of ``document_pred``, the predicted 'ノ'
        antecedent is compared with the relaxed gold antecedents ('ノ' plus
        'ノ?'). The outcome is written to ``self.comp_result`` under
        ``(doc_id, dtid, 'ノ')`` and the counters in
        ``self.measures_bridging`` are updated per analysis type.

        Args:
            doc_id: Key into ``self.did2bridgings_pred`` /
                ``self.did2bridgings_gold`` and first element of the result key.
            document_pred: System output document.
            document_gold: Gold document.

        Raises:
            AssertionError: If more than one antecedent is predicted for an
                anaphor, or if the recall step disagrees with the result the
                precision step recorded.
        """
        dtid2anaphor_pred: Dict[int, Predicate] = {
            pred.dtid: pred
            for pred in self.did2bridgings_pred[doc_id]
        }
        dtid2anaphor_gold: Dict[int, Predicate] = {
            pred.dtid: pred
            for pred in self.did2bridgings_gold[doc_id]
        }

        for dtid in range(len(document_pred.bp_list())):
            if dtid in dtid2anaphor_pred:
                anaphor_pred = dtid2anaphor_pred[dtid]
                antecedents_pred: List[BaseArgument] = \
                    self._filter_args(document_pred.get_arguments(anaphor_pred, relax=False)['ノ'], anaphor_pred)
            else:
                antecedents_pred = []
            assert len(antecedents_pred) in (
                0, 1
            )  # in bert_pas_analysis, predict one argument for one predicate

            if dtid in dtid2anaphor_gold:
                anaphor_gold: Predicate = dtid2anaphor_gold[dtid]
                antecedents_gold: List[BaseArgument] = \
                    self._filter_args(document_gold.get_arguments(anaphor_gold, relax=False)['ノ'], anaphor_gold)
                # Fetch the relaxed arguments once and read both cases from
                # the same result (as _evaluate_pas does) instead of calling
                # get_arguments twice.
                # NOTE(review): assumes get_arguments has no side effects — confirm.
                relaxed_arguments = document_gold.get_arguments(anaphor_gold, relax=True)
                antecedents_gold_relaxed: List[BaseArgument] = \
                    self._filter_args(relaxed_arguments['ノ'] + relaxed_arguments['ノ?'], anaphor_gold)
            else:
                antecedents_gold = antecedents_gold_relaxed = []

            key = (doc_id, dtid, 'ノ')

            # calculate precision
            if antecedents_pred:
                antecedent_pred = antecedents_pred[0]
                analysis = Scorer.DEPTYPE2ANALYSIS[antecedent_pred.dep_type]
                if antecedent_pred in antecedents_gold_relaxed:
                    self.comp_result[key] = analysis
                    self.measures_bridging[analysis].correct += 1
                else:
                    self.comp_result[key] = 'wrong'
                self.measures_bridging[analysis].denom_pred += 1

            # calculate recall
            # Counted when there is a strict gold antecedent, or when the
            # prediction was already judged correct against the relaxed gold.
            if antecedents_gold or (self.comp_result.get(key, None)
                                    in Scorer.DEPTYPE2ANALYSIS.values()):
                antecedent_gold = None
                for ant in antecedents_gold_relaxed:
                    if ant in antecedents_pred:
                        # prefer the predicted antecedent as the gold antecedent
                        antecedent_gold = ant
                        break
                if antecedent_gold is not None:
                    analysis = Scorer.DEPTYPE2ANALYSIS[
                        antecedent_gold.dep_type]
                    assert self.comp_result[key] == analysis
                else:
                    # nothing matched: charge the miss to the first strict
                    # gold antecedent's analysis type
                    analysis = Scorer.DEPTYPE2ANALYSIS[
                        antecedents_gold[0].dep_type]
                    if antecedents_pred:
                        assert self.comp_result[key] == 'wrong'
                    else:
                        self.comp_result[key] = 'wrong'
                self.measures_bridging[analysis].denom_gold += 1
コード例 #6
0
    def _evaluate_pas(self, doc_id: str, document_pred: Document,
                      document_gold: Document):
        """Score predicate-argument structure analysis for one document.

        For every base-phrase index of ``document_pred`` and every case in
        ``self.cases``, the predicted argument is compared with the relaxed
        gold arguments (for 'ガ', the '判ガ' arguments are also accepted).
        The outcome is written to ``self.comp_result`` under
        ``(doc_id, dtid, case)`` and the counters in ``self.measures`` are
        updated per case and analysis type.

        Args:
            doc_id: Key into ``self.did2predicates_pred`` /
                ``self.did2predicates_gold`` and first element of the result key.
            document_pred: System output document.
            document_gold: Gold document.
        """
        pred_of: Dict[int, Predicate] = {}
        for predicate in self.did2predicates_pred[doc_id]:
            pred_of[predicate.dtid] = predicate
        gold_of: Dict[int, Predicate] = {}
        for predicate in self.did2predicates_gold[doc_id]:
            gold_of[predicate.dtid] = predicate

        for dtid in range(len(document_pred.bp_list())):
            if dtid not in pred_of:
                arguments_pred = None
            else:
                arguments_pred = document_pred.get_arguments(pred_of[dtid],
                                                             relax=False)

            if dtid not in gold_of:
                predicate_gold = None
                arguments_gold = None
                arguments_gold_relaxed = None
            else:
                predicate_gold = gold_of[dtid]
                arguments_gold = document_gold.get_arguments(predicate_gold,
                                                             relax=False)
                arguments_gold_relaxed = document_gold.get_arguments(
                    predicate_gold, relax=True)

            for case in self.cases:
                if arguments_pred is None:
                    args_pred: List[BaseArgument] = []
                else:
                    args_pred = arguments_pred[case]
                # in bert_pas_analysis, predict one argument for one predicate
                assert len(args_pred) in (0, 1)

                if predicate_gold is None:
                    args_gold = args_gold_relaxed = []
                else:
                    args_gold = self._filter_args(arguments_gold[case],
                                                  predicate_gold)
                    relaxed_candidates = arguments_gold_relaxed[case]
                    extra_candidates = arguments_gold_relaxed['判ガ'] if case == 'ガ' else []
                    args_gold_relaxed = self._filter_args(
                        relaxed_candidates + extra_candidates, predicate_gold)

                key = (doc_id, dtid, case)

                # precision: counted whenever an argument was predicted
                if args_pred:
                    predicted = args_pred[0]
                    analysis = Scorer.DEPTYPE2ANALYSIS[predicted.dep_type]
                    if predicted not in args_gold_relaxed:
                        self.comp_result[key] = 'wrong'  # lowers precision
                    else:
                        self.comp_result[key] = analysis
                        self.measures[case][analysis].correct += 1
                    self.measures[case][analysis].denom_pred += 1

                # recall
                # When there are several gold arguments and one of them was
                # predicted, that one is adopted as the gold argument.
                # If none was predicted, one of the non-relaxed arguments is
                # adopted as the gold argument.
                judged_correct = (self.comp_result.get(key)
                                  in Scorer.DEPTYPE2ANALYSIS.values())
                if not (args_gold or judged_correct):
                    continue

                # prefer a gold argument that was actually predicted
                arg_gold = next(
                    (cand for cand in args_gold_relaxed if cand in args_pred),
                    None)
                if arg_gold is None:
                    analysis = Scorer.DEPTYPE2ANALYSIS[args_gold[0].dep_type]
                    if args_pred:
                        assert self.comp_result[key] == 'wrong'
                    else:
                        self.comp_result[key] = 'wrong'  # lowers recall
                else:
                    analysis = Scorer.DEPTYPE2ANALYSIS[arg_gold.dep_type]
                    assert self.comp_result[key] == analysis
                self.measures[case][analysis].denom_gold += 1