Example #1
0
 def _get_mentions(
     self,
     bp: BasePhrase,
     document: Document,
     relax_exophors: Dict[str, str],
     candidates: List[int],
 ) -> List[str]:
     """Collect the coreference targets of ``bp`` as a list of strings.

     Args:
         bp: Base phrase whose mention is looked up through ``bp.dtid``.
         document: Document providing ``mentions``, ``entities`` and
             ``get_siblings``.
         relax_exophors: Maps an exophor name to the name emitted for it.
         candidates: ``dmid`` values that are acceptable as targets.

     Returns:
         The ``dmid`` (as a string) of every sibling mention contained in
         ``candidates``, followed by the mapped name of every special
         entity's exophor that is a key of ``relax_exophors``. ``['NA']``
         is returned when nothing qualifies or ``bp`` has no mention.
     """
     # No mention registered for this base phrase: nothing to point at.
     if bp.dtid not in document.mentions:
         return ['NA']

     source = document.mentions[bp.dtid]
     siblings = document.get_siblings(source, relax=False)
     special_exophors = [
         document.entities[eid].exophor for eid in source.eids
         if document.entities[eid].is_special
     ]

     collected: List[str] = []
     for sibling in siblings:
         if sibling.dmid in candidates:
             collected.append(str(sibling.dmid))
         else:
             logger.debug(
                 f'mention: {sibling} in {self.doc_id} is not in candidates and ignored'
             )

     # Numbered exophors are mapped to their base form,
     # e.g. 不特定:人1 -> 不特定:人
     collected.extend(relax_exophors[name] for name in special_exophors
                      if name in relax_exophors)

     # An empty result forces a cataphor to point at [NA].
     return collected or ['NA']
    def _rewrite_rel(
            self,
            knp_lines: List[str],
            example: PasExample,
            arguments_set: List[List[int]],  # (max_seq_len, cases)
            document: Document,  # document carrying <格解析> (case analysis) tags
    ) -> List[str]:
        """Insert predicted relation tags into the base-phrase lines of a KNP text.

        Lines that do not start with ``'+ '`` are copied through unchanged.
        For each ``'+ '`` line matched by ``self.TAG_PAT``, the string returned
        by ``self._rel_string`` is spliced in at the end of the match. A line
        that does not match is logged as a warning and copied unchanged, but it
        still advances the base-phrase counter.

        Args:
            knp_lines: Lines of the KNP text; ``'+ '`` lines must not already
                contain ``'<rel '`` (checked with an assertion).
            example: Example forwarded to ``self._rel_string``.
            arguments_set: Predictions forwarded to ``self._rel_string``.
            document: Document supplying the base phrases and, through
                ``self._extract_overt``, the per-phrase overt information.

        Returns:
            The rewritten lines, one output line per input line.
        """
        overts = self._extract_overt(document)
        # The base-phrase list does not change while rewriting; fetch it once
        # instead of once per tag line.
        base_phrases = document.bp_list()

        output_knp_lines = []
        dtid = 0  # index of the current '+ ' line within the document
        for line in knp_lines:
            if not line.startswith('+ '):
                output_knp_lines.append(line)
                continue

            assert '<rel ' not in line
            match = self.TAG_PAT.match(line)
            if match is None:
                self.logger.warning(f'invalid format line: {line}')
                output_knp_lines.append(line)
            else:
                rel_string = self._rel_string(base_phrases[dtid], example,
                                              arguments_set, document,
                                              overts[dtid])
                rel_idx = match.end()
                output_knp_lines.append(line[:rel_idx] + rel_string +
                                        line[rel_idx:])

            dtid += 1

        return output_knp_lines
Example #3
0
    def _rel_string(
        self,
        bp: BasePhrase,
        example: PasExample,
        arguments_set: List[List[int]],  # (max_seq_len, cases)
        document: Document,
        overt_dict: Dict[str, int],
    ) -> str:
        """Render the predicted relations of one base phrase as tag strings.

        Args:
            bp: Base phrase whose morphemes are examined.
            example: Supplies ``arguments_set``, ``orig_to_tok_index`` and
                ``tok_to_orig_index``.
            arguments_set: Predicted argument index per token and relation.
            document: Document that contains ``bp``.
            overt_dict: Relation -> morpheme id; takes precedence over the
                prediction when ``self.use_knp_overt`` is truthy.

        Returns:
            ``to_string()`` of every generated ``RelTag``, concatenated
            without a separator.
        """
        # Map each morpheme id of the document to its containing base phrase.
        dmid2bp = {}
        for phrase in document.bp_list():
            for morpheme in phrase.mrph_list():
                dmid2bp[document.mrph2dmid[morpheme]] = phrase
        assert len(example.arguments_set) == len(dmid2bp)

        rels: List[RelTag] = []
        for mrph in bp.mrph_list():
            dmid: int = document.mrph2dmid[mrph]
            predicted: List[int] = arguments_set[example.orig_to_tok_index[dmid]]
            # For morphemes that are not analysis targets (e.g. particles) the
            # gold arguments are empty; at inference time target morphemes
            # carry ['NULL'].
            is_targets: Dict[str, bool] = {}
            for rel, gold_args in example.arguments_set[dmid].items():
                is_targets[rel] = bool(gold_args)
            assert len(self.relations) == len(predicted)

            for relation, argument in zip(self.relations, predicted):
                if not is_targets[relation]:
                    continue

                if self.use_knp_overt and relation in overt_dict:
                    # overt argument
                    target_dmid = overt_dict[relation]
                elif argument in self.index_to_special:
                    # special token: only exophors produce a tag,
                    # [NULL] and [NA] are dropped
                    special = self.index_to_special[argument]
                    if special in self.exophors:
                        rels.append(RelTag(relation, special, None, None))
                    continue
                else:
                    # ordinary token
                    target_dmid = example.tok_to_orig_index[argument]
                    if target_dmid is None:
                        # [SEP] or [CLS]
                        self.logger.warning(
                            "Choose [SEP] as an argument. Tentatively, change it to NULL."
                        )
                        continue

                target_bp: BasePhrase = dmid2bp[target_dmid]
                rels.append(
                    RelTag(relation, target_bp.core, target_bp.sid,
                           target_bp.tid))

        return ''.join([tag.to_string() for tag in rels])
Example #4
0
    def __init__(self, document_pred: Document, document_gold: Document,
                 cases: List[str], bridging: bool, coreference: bool,
                 relax_exophors: Dict[str, str], pas_target: str):
        """Store the system/gold document pair and collect their evaluation targets.

        Args:
            document_pred: System output; must share ``doc_id`` with
                ``document_gold`` (checked with an assertion).
            document_gold: Gold document.
            cases: Cases to evaluate.
            bridging: Whether bridging targets are collected.
            coreference: Whether coreference targets are collected.
            relax_exophors: Stored as ``self.relax_exophors``.
            pas_target: ``'pred'``/``'all'`` enable verbal targets and
                ``'noun'``/``'all'`` enable nominal targets; ``self.pas`` is
                False only for the empty string.
        """
        assert document_pred.doc_id == document_gold.doc_id
        self.doc_id: str = document_gold.doc_id
        self.document_pred: Document = document_pred
        self.document_gold: Document = document_gold
        self.cases: List[str] = cases
        self.pas: bool = pas_target != ''
        self.bridging: bool = bridging
        self.coreference: bool = coreference
        self.comp_result: Dict[tuple, str] = {}
        self.relax_exophors: Dict[str, str] = relax_exophors

        verbal = pas_target in ('pred', 'all')
        nominal = pas_target in ('noun', 'all')

        def collect(document):
            # Sort the base phrases of one document into the three target lists.
            predicates, bridgings, mentions = [], [], []
            for phrase in document.bp_list():
                if is_pas_target(phrase, verbal=verbal, nominal=nominal):
                    predicates.append(phrase)
                if self.bridging and is_bridging_target(phrase):
                    bridgings.append(phrase)
                if self.coreference and is_coreference_target(phrase):
                    mentions.append(phrase)
            return predicates, bridgings, mentions

        predicates, bridgings, mentions = collect(document_pred)
        self.predicates_pred: List[BasePhrase] = predicates
        self.bridgings_pred: List[BasePhrase] = bridgings
        self.mentions_pred: List[BasePhrase] = mentions

        predicates, bridgings, mentions = collect(document_gold)
        self.predicates_gold: List[BasePhrase] = predicates
        self.bridgings_gold: List[BasePhrase] = bridgings
        self.mentions_gold: List[BasePhrase] = mentions
Example #5
0
 def _pas_string(
     self,
     pas: Pas,
     cfid: str,
     document: Document,
 ) -> str:
     """Format one predicate-argument structure as a ``<述語項構造:...>`` tag.

     For every case in ``self.cases`` (plus ``'ノ'`` when ``self.bridging``
     is set) a slash-separated record of six fields is produced: case,
     flag, surface string, sentence distance, tag id and entity id. Only
     the first argument of each case is used; a case without arguments
     gets the flag ``'U'`` and ``'-'`` in the remaining fields.

     Args:
         pas: Structure providing ``arguments`` and ``sid``.
         cfid: Identifier placed right after the tag name.
         document: Document providing ``sid2sentence`` and ``get_entities``.

     Returns:
         The tag string with the records joined by ``';'``.
     """
     flag_of = {
         'overt': 'C',
         'dep': 'N',
         'intra': 'O',
         'inter': 'O',
         'exo': 'E'
     }
     # position of each sentence id within the document
     sid2index: Dict[str, int] = {}
     for position, sid in enumerate(document.sid2sentence.keys()):
         sid2index[sid] = position

     records = []
     for case in self.cases + (['ノ'] * self.bridging):
         candidates = pas.arguments[case]
         if not candidates:
             records.append('/'.join([case, 'U', '-', '-', '-', '-']))
             continue

         arg: BaseArgument = candidates[0]
         flag = flag_of[arg.dep_type]  # flag (C/N/O/D/E/U)
         surface = str(arg)  # surface form
         if isinstance(arg, Argument):
             # how many sentences before the predicate the argument occurs
             distance = str(sid2index[pas.sid] - sid2index[arg.sid])
             tag_id = str(arg.tid)
             entity_id = str(document.get_entities(arg)[0].eid)
         else:
             assert isinstance(arg, SpecialArgument)
             distance = str(-1)
             tag_id = str(-1)
             entity_id = str(arg.eid)
         records.append('/'.join(
             [case, flag, surface, distance, tag_id, entity_id]))

     return f"<述語項構造:{cfid}:{';'.join(records)}>"
Example #6
0
    def _add_pas_analysis(
        self,
        knp_lines: List[str],
        document: Document,
    ) -> List[str]:
        """Append a PAS tag to every base-phrase line that has a structure.

        ``'+ '`` lines are counted from zero; when the count is the ``dtid``
        of an entry in ``document.pas_list()``, the string built by
        ``self._pas_string`` (with the id ``'dummy:dummy'``) is appended to
        that line. All other lines are copied unchanged.

        Args:
            knp_lines: Lines of the KNP text.
            document: Document supplying the structures.

        Returns:
            The output lines, one per input line.
        """
        dtid2pas = {}
        for pas in document.pas_list():
            dtid2pas[pas.dtid] = pas

        result = []
        tag_index = 0
        for line in knp_lines:
            is_tag_line = line.startswith('+ ')
            if is_tag_line and tag_index in dtid2pas:
                line = line + self._pas_string(dtid2pas[tag_index],
                                               'dummy:dummy', document)
            result.append(line)
            if is_tag_line:
                tag_index += 1

        return result
Example #7
0
    def load(
        self,
        document: Document,
        cases: List[str],
        exophors: List[str],
        coreference: bool,
        bridging: bool,
        relations: List[str],
        kc: bool,
        pas_targets: List[str],
        tokenizer: BertTokenizer,
    ) -> None:
        """Fill the per-morpheme fields of this example from ``document``.

        One entry per morpheme is appended to ``self.words``, ``self.dtids``,
        ``self.ddeps``, ``self.arguments_set``, ``self.arg_candidates_set``
        and ``self.ment_candidates_set``. Gold arguments and candidates are
        only computed for the morpheme whose id equals its base phrase's
        ``dmid``, and only in sentences selected for processing.

        Args:
            document: Source document.
            cases: Cases whose arguments are collected for PAS targets.
            exophors: Exophor names; the three ``不特定:*`` names additionally
                accept the numbered variants ``1``-``9``.
            coreference: Whether coreference targets get ``'='`` mentions.
            bridging: Whether bridging targets get ``'ノ'`` arguments.
            relations: Keys of every per-morpheme argument dict.
            kc: When true, only documents whose id ends in ``-00`` are
                processed completely; otherwise only the last sentence is.
            pas_targets: ``'pred'`` enables verbal and ``'noun'`` nominal
                targets.
            tokenizer: Not used in this method.
        """
        self.doc_id = document.doc_id
        process_all = (kc is False) or (document.doc_id.split('-')[-1] == '00')
        last_sent = document.sentences[-1] if len(document) > 0 else None

        relax_exophors = {}
        for exophor in exophors:
            relax_exophors[exophor] = exophor
            if exophor in ('不特定:人', '不特定:物', '不特定:状況'):
                for n in '123456789':
                    relax_exophors[exophor + n] = exophor

        dmid2arguments: Dict[int, Dict[str, List[BaseArgument]]] = {
            pred.dmid: document.get_arguments(pred)
            for pred in document.get_predicates()
        }
        head_dmids = []
        for sentence in document:
            process: bool = process_all or (sentence is last_sent)
            head_dmids += [bp.dmid for bp in sentence.bps]
            for bp in sentence.bps:
                for mrph in bp.mrph_list():
                    self.words.append(mrph.midasi)
                    self.dtids.append(bp.dtid)
                    self.ddeps.append(
                        bp.parent.dtid if bp.parent is not None else -1)
                    arguments = OrderedDict((rel, []) for rel in relations)
                    # Two separate lists: the entries stored in
                    # arg_candidates_set and ment_candidates_set must not be
                    # the same object, or mutating one would change the other.
                    arg_candidates: List[int] = []
                    ment_candidates: List[int] = []
                    if document.mrph2dmid[mrph] == bp.dmid and process:
                        if is_pas_target(bp,
                                         verbal=('pred' in pas_targets),
                                         nominal=('noun' in pas_targets)):
                            arg_candidates = [
                                x for x in head_dmids if x != bp.dmid
                            ]
                            for case in cases:
                                dmid2args = {
                                    dmid: pred_args[case]
                                    for dmid, pred_args in
                                    dmid2arguments.items()
                                }
                                arguments[case] = self._get_args(
                                    bp.dmid, dmid2args, relax_exophors,
                                    arg_candidates)

                        if bridging and is_bridging_target(bp):
                            arg_candidates = [
                                x for x in head_dmids if x != bp.dmid
                            ]
                            dmid2args = {
                                dmid: pred_args['ノ']
                                for dmid, pred_args in dmid2arguments.items()
                            }
                            arguments['ノ'] = self._get_args(
                                bp.dmid, dmid2args, relax_exophors,
                                arg_candidates)

                        if coreference and is_coreference_target(bp):
                            ment_candidates = [
                                x for x in head_dmids if x < bp.dmid
                            ]  # do not solve cataphora
                            arguments['='] = self._get_mentions(
                                bp, document, relax_exophors, ment_candidates)

                    self.arguments_set.append(arguments)
                    self.arg_candidates_set.append(arg_candidates)
                    self.ment_candidates_set.append(ment_candidates)
Example #8
0
def coverage(doc: Document) -> RetValue:
    """Walk over every predicate of ``doc``.

    NOTE(review): only the opening of this function is visible here. No
    ``return`` statement is shown and ``ret``, ``ex`` and ``is_pred_gold``
    are not used yet, so the rest of the contract cannot be described from
    this excerpt; confirm against the full source.
    """
    ret = RetValue()
    for predicate in doc.get_predicates():
        ex = Example(doc, predicate)
        arguments = doc.get_arguments(predicate)
        # True when at least one case in PRED_CASES has a truthy entry
        is_pred_gold = any(arguments[case] for case in PRED_CASES)
Example #9
0
    def write(
        self,
        arguments_sets: List[List[List[int]]],
        destination: Union[Path, TextIO, None],
        skip_untagged: bool = True,
    ) -> List[Document]:
        """Write final predictions to the file.

        Args:
            arguments_sets: Model output, aligned with ``self.examples``.
            destination: Output target. A ``Path`` is treated as a directory
                (created together with missing parents) that receives one
                ``<doc id>.knp`` file per document; a text stream receives
                all documents; ``None`` writes nothing. Any other value is
                reported with a warning and nothing is written.
            skip_untagged: When true, documents without a matching example
                are left out. Otherwise they are kept with the tags matched
                by ``self.REL_PAT`` removed from their ``'+ '`` lines.

        Returns:
            The predicted documents. With ``self.kc`` set, the ``<id>-<idx>``
            pieces are merged back into one document per ``<id>``.
        """
        if isinstance(destination, Path):
            self.logger.info(f'Writing predictions to: {destination}')
            # parents=True so a nested output directory can be created.
            destination.mkdir(parents=True, exist_ok=True)
        elif not (destination is None
                  or isinstance(destination, io.TextIOBase)):
            self.logger.warning('invalid output destination')

        did2examples = {ex.doc_id: ex for ex in self.examples}
        did2arguments_sets = {
            ex.doc_id: arguments_set
            for ex, arguments_set in zip(self.examples, arguments_sets)
        }

        did2knps: Dict[str, List[str]] = defaultdict(list)
        for document in self.documents:
            did = document.doc_id
            input_knp_lines = document.knp_string.strip().split('\n')
            if did in did2examples:
                output_knp_lines = self._rewrite_rel(input_knp_lines,
                                                     did2examples[did],
                                                     did2arguments_sets[did],
                                                     document)
            else:
                if skip_untagged:
                    continue
                output_knp_lines = []
                for line in input_knp_lines:
                    if line.startswith('+ '):
                        line = self.REL_PAT.sub('', line)  # remove gold data
                        assert '<rel ' not in line
                    output_knp_lines.append(line)

            # Cut the text into per-sentence chunks, each ending at its EOS line.
            knp_strings: List[str] = []
            chunk: List[str] = []
            for knp_line in output_knp_lines:
                chunk.append(knp_line + '\n')
                if knp_line.strip() == 'EOS':
                    knp_strings.append(''.join(chunk))
                    chunk = []
            if self.kc:
                # Split on the last hyphen only, so that ids which themselves
                # contain hyphens are still handled.
                orig_did, idx = did.rsplit('-', 1)
                if idx == '00':
                    did2knps[orig_did] += knp_strings
                else:
                    # later pieces contribute only their last sentence
                    did2knps[orig_did].append(knp_strings[-1])
            else:
                did2knps[did] = knp_strings

        # For kc, this is the list of documents merged back to their original form.
        documents_pred: List[Document] = []
        for did, knp_strings in did2knps.items():
            document_pred = Document(''.join(knp_strings),
                                     did,
                                     self.reader.target_cases,
                                     self.reader.target_corefs,
                                     self.reader.relax_cases,
                                     extract_nes=False,
                                     use_pas_tag=False)
            documents_pred.append(document_pred)
            if destination is None:
                continue
            # Strip before splitting so a trailing newline does not become an
            # empty last line, i.e. an extra blank line in the output.
            output_knp_lines = self._add_pas_analysis(
                document_pred.knp_string.strip().split('\n'), document_pred)
            output_string = '\n'.join(output_knp_lines) + '\n'
            if isinstance(destination, Path):
                output_basename = did + '.knp'
                with destination.joinpath(output_basename).open('w') as writer:
                    writer.write(output_string)
            elif isinstance(destination, io.TextIOBase):
                destination.write(output_string)

        return documents_pred
Example #10
0
def draw_tree(
    document: Document,
    sid: str,
    cases: List[str],
    bridging: bool = False,
    coreference: bool = False,
    fh: Optional[TextIO] = None,
    html: bool = False,
) -> None:
    """Print the tag tree of sentence ``sid`` annotated with its relations.

    Each tree line gets, for its base phrase, the arguments of every case in
    ``cases`` and optionally the bridging (``ノ``) and coreference (``=``)
    targets. A target whose surface form occurs more than once among the
    document's mentions is suffixed with its ``dtid``.

    Args:
        document (Document): Document containing the sentence.
        sid (str): Id of the sentence to print.
        cases (List[str]): Cases to display.
        bridging (bool): Whether to display bridging anaphora relations too.
        coreference (bool): Whether to display coreference relations too.
        fh (Optional[TextIO]): Output stream passed to ``print``.
        html (bool): Whether to wrap each entry in an HTML ``<font>`` tag.

    """
    blist: BList = document.sid2sentence[sid].blist
    with io.StringIO() as buffer:
        blist.draw_tag_tree(fh=buffer, show_pos=False)
        tree_strings = buffer.getvalue().rstrip('\n').split('\n')
    assert len(tree_strings) == len(blist.tag_list())

    all_targets = [m.core for m in document.mentions.values()]
    tid2mention = {}
    for mention in document.mentions.values():
        if mention.sid == sid:
            tid2mention[mention.tid] = mention

    def label_of(arg) -> str:
        # Surface form of an argument, disambiguated by dtid when ambiguous.
        text = str(arg)
        if all_targets.count(text) > 1 and isinstance(arg, Argument):
            text += str(arg.dtid)
        return text

    def render(name: str, targets: set, suffix: str) -> str:
        # One "name:target,target" entry, optionally coloured for HTML.
        entry = f'{name}:{",".join(targets)}'
        if not html:
            return entry + suffix
        color = 'black' if targets else 'gray'
        return f'<font color="{color}">{entry}</font>' + suffix

    for bp in document[sid].bps:
        tree_strings[bp.tid] += '  '

        if is_pas_target(bp, verbal=True, nominal=True):
            arguments = document.get_arguments(bp)
            for case in cases:
                case_targets = {label_of(arg) for arg in arguments.get(case, [])}
                tree_strings[bp.tid] += render(case, case_targets, ' ')

        if bridging and is_bridging_target(bp):
            bridging_args = document.get_arguments(bp).get('ノ', [])
            bridging_targets = {label_of(arg) for arg in bridging_args}
            tree_strings[bp.tid] += render('ノ', bridging_targets, ' ')

        if coreference and is_coreference_target(bp):
            coref_targets = set()
            if bp.tid in tid2mention:
                src_mention = tid2mention[bp.tid]
                # only mentions that precede the source are shown
                preceding = [
                    tgt for tgt in document.get_siblings(src_mention)
                    if tgt.dtid < src_mention.dtid
                ]
                for tgt_mention in preceding:
                    label = tgt_mention.core
                    if all_targets.count(tgt_mention.core) > 1:
                        label += str(tgt_mention.dtid)
                    coref_targets.add(label)
                for eid in src_mention.eids:
                    entity = document.entities[eid]
                    if entity.is_special:
                        coref_targets.add(entity.exophor)
            tree_strings[bp.tid] += render('=', coref_targets, '')

    print('\n'.join(tree_strings), file=fh)
Example #11
0
    def _draw_tree(self,
                   sid: str,
                   predicates: List[BasePhrase],
                   mentions: List[BasePhrase],
                   anaphors: List[BasePhrase],
                   document: Document,
                   fh: Optional[TextIO] = None,
                   html: bool = True) -> None:
        """Write the predicate-argument structures, coreference relations, and bridging anaphora relations of the
        specified sentence in tree format.

        In HTML mode each entry is coloured by the evaluation result stored in
        ``self.comp_result``: blue for the values of ``Scorer.DEPTYPE2ANALYSIS``
        (and for ``'correct'`` coreference), green for ``'overt'``, red for
        ``'wrong'`` and gray when there is no result.

        Args:
            sid (str): Id of the sentence to output.
            predicates (List[BasePhrase]): All predicates contained in ``document``.
            mentions (List[BasePhrase]): All mentions contained in ``document``.
            anaphors (List[BasePhrase]): All bridging anaphors contained in ``document``.
            document (Document): Document containing the sentence.
            fh (Optional[TextIO]): Output stream passed to ``print``.
            html (bool): Whether to output in HTML format.
        """
        result2color = {
            anal: 'blue'
            for anal in Scorer.DEPTYPE2ANALYSIS.values()
        }
        result2color.update({'overt': 'green', 'wrong': 'red', None: 'gray'})
        result2color_coref = {'correct': 'blue', 'wrong': 'red', None: 'gray'}
        blist: BList = document.sid2sentence[sid].blist
        with io.StringIO() as string:
            blist.draw_tag_tree(fh=string, show_pos=False)
            tree_strings = string.getvalue().rstrip('\n').split('\n')
        assert len(tree_strings) == len(blist.tag_list())
        # surface forms of all mentions, used to detect ambiguous targets
        all_targets = [m.core for m in document.mentions.values()]
        tid2predicate: Dict[int, BasePhrase] = {
            predicate.tid: predicate
            for predicate in predicates if predicate.sid == sid
        }
        tid2mention: Dict[int, BasePhrase] = {
            mention.tid: mention
            for mention in mentions if mention.sid == sid
        }
        tid2bridging: Dict[int, BasePhrase] = {
            anaphor.tid: anaphor
            for anaphor in anaphors if anaphor.sid == sid
        }
        for tid in range(len(tree_strings)):
            tree_strings[tid] += '  '
            if tid in tid2predicate:
                predicate = tid2predicate[tid]
                arguments = document.get_arguments(predicate)
                for case in self.cases:
                    if case == 'ガ':
                        # Build a new list: extending arguments[case] in place
                        # would modify the list handed out by the document.
                        args = arguments[case] + arguments['判ガ']
                    else:
                        args = arguments[case]
                    targets = set()
                    for arg in args:
                        target = str(arg)
                        if all_targets.count(str(arg)) > 1 and isinstance(
                                arg, Argument):
                            target += str(arg.dtid)
                        targets.add(target)
                    result = self.comp_result.get(
                        (document.doc_id, predicate.dtid, case), None)
                    if html:
                        tree_strings[
                            tid] += f'<font color="{result2color[result]}">{case}:{",".join(targets)}</font> '
                    else:
                        tree_strings[tid] += f'{case}:{",".join(targets)} '

            if self.bridging and tid in tid2bridging:
                anaphor = tid2bridging[tid]
                arguments = document.get_arguments(anaphor)
                args = arguments['ノ'] + arguments['ノ?']
                targets = set()
                for arg in args:
                    target = str(arg)
                    if all_targets.count(str(arg)) > 1 and isinstance(
                            arg, Argument):
                        target += str(arg.dtid)
                    targets.add(target)
                result = self.comp_result.get(
                    (document.doc_id, anaphor.dtid, 'ノ'), None)
                if html:
                    tree_strings[
                        tid] += f'<font color="{result2color[result]}">ノ:{",".join(targets)}</font> '
                else:
                    tree_strings[tid] += f'ノ:{",".join(targets)} '

            if self.coreference and tid in tid2mention:
                targets = set()
                src_dtid = tid2mention[tid].dtid
                if src_dtid in document.mentions:
                    src_mention = document.mentions[src_dtid]
                    tgt_mentions_relaxed = SubScorer.filter_mentions(
                        document.get_siblings(src_mention, relax=True),
                        src_mention)
                    for tgt_mention in tgt_mentions_relaxed:
                        target: str = tgt_mention.core
                        if all_targets.count(target) > 1:
                            target += str(tgt_mention.dtid)
                        targets.add(target)
                    for eid in src_mention.eids:
                        entity = document.entities[eid]
                        if entity.exophor in self.relax_exophors:
                            targets.add(entity.exophor)
                result = self.comp_result.get((document.doc_id, src_dtid, '='),
                                              None)
                if html:
                    tree_strings[
                        tid] += f'<font color="{result2color_coref[result]}">=:{",".join(targets)}</font>'
                else:
                    tree_strings[tid] += '=:' + ','.join(targets)

        print('\n'.join(tree_strings), file=fh)
    def write(
        self,
        arguments_sets: List[List[List[int]]],
        destination: Union[Path, TextIO, None],
        skip_untagged: bool = True,
        add_pas_tag: bool = True,
    ) -> List[Document]:
        """Write final predictions to the file.

        Args:
            arguments_sets (List[List[List[int]]]): Model output, aligned with ``self.examples``.
            destination (Union[Path, TextIO, None]): Where the analysed documents go. A ``Path`` is treated as
                a directory (created together with missing parents) receiving one ``<doc id>.knp`` file per
                document; a text stream receives all documents; ``None`` writes nothing. Any other value is
                reported with a warning and nothing is written.
            skip_untagged (bool): Whether to leave out documents that have no matching example (default: True)
            add_pas_tag (bool): Whether to add the <述語項構造 > tag to the analysis result (default: True)
        Returns:
            List[Document]: The analysed documents. With ``self.kc`` set, the ``<id>-<idx>`` pieces are merged
                back into one document per ``<id>``.
        """

        if isinstance(destination, Path):
            self.logger.info(f'Writing predictions to: {destination}')
            # parents=True so a nested output directory can be created.
            destination.mkdir(parents=True, exist_ok=True)
        elif not (destination is None
                  or isinstance(destination, io.TextIOBase)):
            self.logger.warning('invalid output destination')

        did2examples = {ex.doc_id: ex for ex in self.examples}
        did2arguments_sets = {
            ex.doc_id: arguments_set
            for ex, arguments_set in zip(self.examples, arguments_sets)
        }

        did2knps: Dict[str, List[str]] = defaultdict(list)
        for document in self.documents:
            did = document.doc_id
            input_knp_lines = document.knp_string.strip().split('\n')
            if did in did2examples:
                # To extract overt arguments, this document should preferably
                # be the one case-analysed after re-parsing.
                output_knp_lines = self._rewrite_rel(
                    input_knp_lines, did2examples[did],
                    did2arguments_sets[did],
                    document)
            else:
                if skip_untagged:
                    continue
                assert all('<rel ' not in line for line in input_knp_lines)
                output_knp_lines = input_knp_lines

            # Cut the text into per-sentence chunks, each ending at its EOS line.
            knp_strings: List[str] = []
            chunk: List[str] = []
            for knp_line in output_knp_lines:
                chunk.append(knp_line + '\n')
                if knp_line.strip() == 'EOS':
                    knp_strings.append(''.join(chunk))
                    chunk = []
            if self.kc:
                # merge documents; split on the last hyphen only, so that ids
                # which themselves contain hyphens are still handled
                orig_did, idx = did.rsplit('-', 1)
                if idx == '00':
                    did2knps[orig_did] += knp_strings
                else:
                    # later pieces contribute only their last sentence
                    did2knps[orig_did].append(knp_strings[-1])
            else:
                did2knps[did] = knp_strings

        # For kc, this is the list of documents merged back to their original form.
        documents_pred: List[Document] = []
        for did, knp_strings in did2knps.items():
            document_pred = Document(''.join(knp_strings),
                                     did,
                                     self.reader.target_cases,
                                     self.reader.target_corefs,
                                     self.reader.relax_cases,
                                     extract_nes=False,
                                     use_pas_tag=False)
            documents_pred.append(document_pred)
            if destination is None:
                continue
            output_knp_lines = document_pred.knp_string.strip().split('\n')
            if add_pas_tag:
                output_knp_lines = self._add_pas_tag(output_knp_lines,
                                                     document_pred)
            output_string = '\n'.join(output_knp_lines) + '\n'
            if isinstance(destination, Path):
                output_basename = did + '.knp'
                with destination.joinpath(output_basename).open('w') as writer:
                    writer.write(output_string)
            elif isinstance(destination, io.TextIOBase):
                destination.write(output_string)

        return documents_pred
Example #13
0
    def _evaluate_coref(self, doc_id: str, document_pred: Document,
                        document_gold: Document) -> None:
        """Score the coreference predictions of one document.

        For every base-phrase index of ``document_pred`` the predicted
        targets (sibling mentions and exophors) are compared with the gold
        ones. The outcome is stored in ``self.comp_result`` under the key
        ``(doc_id, dtid, '=')`` as ``'correct'`` or ``'wrong'``, and the
        counters ``correct``, ``denom_pred`` and ``denom_gold`` of
        ``self.measure_coref`` are updated.

        Args:
            doc_id: Key into ``self.did2mentions_pred`` / ``self.did2mentions_gold``.
            document_pred: Document holding the system output.
            document_gold: Document holding the gold annotation.
        """
        # Evaluation targets that actually have a mention in each document.
        dtid2mention_pred: Dict[int, Mention] = {
            bp.dtid: document_pred.mentions[bp.dtid]
            for bp in self.did2mentions_pred[doc_id]
            if bp.dtid in document_pred.mentions
        }
        dtid2mention_gold: Dict[int, Mention] = {
            bp.dtid: document_gold.mentions[bp.dtid]
            for bp in self.did2mentions_gold[doc_id]
            if bp.dtid in document_gold.mentions
        }

        for dtid in range(len(document_pred.bp_list())):
            # system side: sibling mentions plus exophors of special entities
            if dtid in dtid2mention_pred:
                src_mention_pred = dtid2mention_pred[dtid]
                tgt_mentions_pred = \
                    self._filter_mentions(document_pred.get_siblings(src_mention_pred), src_mention_pred)
                exophors_pred = [
                    e.exophor for e in map(document_pred.entities.get,
                                           src_mention_pred.eids)
                    if e.is_special
                ]
            else:
                tgt_mentions_pred = exophors_pred = []

            # gold side: a strict set (relax=False siblings, eids) and a
            # relaxed set (relax=True siblings, all_eids); exophors are kept
            # only if they are among the values of self.relax_exophors
            if dtid in dtid2mention_gold:
                src_mention_gold = dtid2mention_gold[dtid]
                tgt_mentions_gold = \
                    self._filter_mentions(document_gold.get_siblings(src_mention_gold, relax=False), src_mention_gold)
                tgt_mentions_gold_relaxed = \
                    self._filter_mentions(document_gold.get_siblings(src_mention_gold, relax=True), src_mention_gold)
                exophors_gold = [
                    e.exophor for e in map(document_gold.entities.get,
                                           src_mention_gold.eids) if
                    e.is_special and e.exophor in self.relax_exophors.values()
                ]
                exophors_gold_relaxed = [
                    e.exophor for e in map(document_gold.entities.get,
                                           src_mention_gold.all_eids) if
                    e.is_special and e.exophor in self.relax_exophors.values()
                ]
            else:
                tgt_mentions_gold = tgt_mentions_gold_relaxed = exophors_gold = exophors_gold_relaxed = []

            key = (doc_id, dtid, '=')

            # calculate precision
            # A prediction counts as correct when it shares at least one
            # mention or exophor with the relaxed gold set.
            if tgt_mentions_pred or exophors_pred:
                if (set(tgt_mentions_pred) & set(tgt_mentions_gold_relaxed)) \
                        or (set(exophors_pred) & set(exophors_gold_relaxed)):
                    self.comp_result[key] = 'correct'
                    self.measure_coref.correct += 1
                else:
                    self.comp_result[key] = 'wrong'
                self.measure_coref.denom_pred += 1

            # calculate recall
            # The gold denominator counts phrases with a strict gold target,
            # and also phrases already judged correct above (i.e. matched
            # only through the relaxed gold set).
            if tgt_mentions_gold or exophors_gold or (self.comp_result.get(
                    key, None) == 'correct'):
                if (set(tgt_mentions_pred) & set(tgt_mentions_gold_relaxed)) \
                        or (set(exophors_pred) & set(exophors_gold_relaxed)):
                    # same overlap test as for precision, so the result must
                    # already have been recorded as correct
                    assert self.comp_result[key] == 'correct'
                else:
                    self.comp_result[key] = 'wrong'
                self.measure_coref.denom_gold += 1
Example #14
0
    def _evaluate_bridging(self, doc_id: str, document_pred: Document,
                           document_gold: Document):
        """Score bridging anaphora (the 'ノ' relation) for one document.

        For every base-phrase index of ``document_pred`` the predicted
        antecedent is compared against the gold antecedents. The verdict is
        written to ``self.comp_result[(doc_id, dtid, 'ノ')]`` and the
        ``correct`` / ``denom_pred`` / ``denom_gold`` counters of
        ``self.measures_bridging[analysis]`` are incremented.

        Args:
            doc_id: Key used to look up anaphors in ``self.did2bridgings_pred``
                and ``self.did2bridgings_gold``.
            document_pred: Document queried for the predicted arguments.
            document_gold: Document queried for the gold arguments.
        """
        pred_by_dtid: Dict[int, Predicate] = {
            anaphor.dtid: anaphor
            for anaphor in self.did2bridgings_pred[doc_id]
        }
        gold_by_dtid: Dict[int, Predicate] = {
            anaphor.dtid: anaphor
            for anaphor in self.did2bridgings_gold[doc_id]
        }

        for dtid in range(len(document_pred.bp_list())):
            # Predicted side: strict 'ノ' arguments of the anaphor, if there is one.
            antecedents_pred: List[BaseArgument] = []
            anaphor_pred = pred_by_dtid.get(dtid)
            if anaphor_pred is not None:
                strict_pred = document_pred.get_arguments(anaphor_pred, relax=False)['ノ']
                antecedents_pred = self._filter_args(strict_pred, anaphor_pred)
            # in bert_pas_analysis, one argument is predicted per predicate
            assert len(antecedents_pred) <= 1

            # Gold side: strict 'ノ' arguments, and a relaxed pool of 'ノ' plus 'ノ?'.
            antecedents_gold: List[BaseArgument] = []
            antecedents_gold_relaxed: List[BaseArgument] = []
            anaphor_gold = gold_by_dtid.get(dtid)
            if anaphor_gold is not None:
                strict_gold = document_gold.get_arguments(anaphor_gold, relax=False)['ノ']
                antecedents_gold = self._filter_args(strict_gold, anaphor_gold)
                relaxed_gold = (document_gold.get_arguments(anaphor_gold, relax=True)['ノ'] +
                                document_gold.get_arguments(anaphor_gold, relax=True)['ノ?'])
                antecedents_gold_relaxed = self._filter_args(relaxed_gold, anaphor_gold)

            key = (doc_id, dtid, 'ノ')

            # Precision: judge the single predicted antecedent against the relaxed gold pool.
            if antecedents_pred:
                predicted = antecedents_pred[0]
                analysis = Scorer.DEPTYPE2ANALYSIS[predicted.dep_type]
                if predicted not in antecedents_gold_relaxed:
                    self.comp_result[key] = 'wrong'
                else:
                    self.comp_result[key] = analysis
                    self.measures_bridging[analysis].correct += 1
                self.measures_bridging[analysis].denom_pred += 1

            # Recall: counted when a strict gold antecedent exists, or when the
            # precision step above already recorded a correct prediction.
            counts_for_recall = bool(antecedents_gold) or (
                self.comp_result.get(key) in Scorer.DEPTYPE2ANALYSIS.values())
            if not counts_for_recall:
                continue

            # Prefer a gold antecedent that was actually predicted.
            matched = next(
                (ant for ant in antecedents_gold_relaxed if ant in antecedents_pred),
                None)
            if matched is None:
                # Nothing predicted matches: fall back to the first strict gold antecedent.
                analysis = Scorer.DEPTYPE2ANALYSIS[antecedents_gold[0].dep_type]
                if not antecedents_pred:
                    self.comp_result[key] = 'wrong'
                else:
                    assert self.comp_result[key] == 'wrong'
            else:
                analysis = Scorer.DEPTYPE2ANALYSIS[matched.dep_type]
                assert self.comp_result[key] == analysis
            self.measures_bridging[analysis].denom_gold += 1
Example #15
0
    def _evaluate_pas(self, doc_id: str, document_pred: Document,
                      document_gold: Document):
        """Score predicate-argument structure analysis for one document.

        For every base-phrase index of ``document_pred`` and every case in
        ``self.cases``, the predicted argument is compared against the gold
        arguments. The verdict is written to
        ``self.comp_result[(doc_id, dtid, case)]`` and the ``correct`` /
        ``denom_pred`` / ``denom_gold`` counters of
        ``self.measures[case][analysis]`` are incremented.

        Args:
            doc_id: Key used to look up predicates in ``self.did2predicates_pred``
                and ``self.did2predicates_gold``.
            document_pred: Document queried for the predicted arguments.
            document_gold: Document queried for the gold arguments.
        """
        pred_by_dtid: Dict[int, Predicate] = {
            predicate.dtid: predicate
            for predicate in self.did2predicates_pred[doc_id]
        }
        gold_by_dtid: Dict[int, Predicate] = {
            predicate.dtid: predicate
            for predicate in self.did2predicates_gold[doc_id]
        }

        for dtid in range(len(document_pred.bp_list())):
            # Predicted side: strict arguments of the predicate at this index, if any.
            arguments_pred = None
            if dtid in pred_by_dtid:
                arguments_pred = document_pred.get_arguments(pred_by_dtid[dtid], relax=False)

            # Gold side: both the strict and the relaxed argument tables.
            predicate_gold = gold_by_dtid.get(dtid)
            if predicate_gold is None:
                arguments_gold = arguments_gold_relaxed = None
            else:
                arguments_gold = document_gold.get_arguments(predicate_gold, relax=False)
                arguments_gold_relaxed = document_gold.get_arguments(predicate_gold, relax=True)

            for case in self.cases:
                if arguments_pred is None:
                    args_pred: List[BaseArgument] = []
                else:
                    args_pred = arguments_pred[case]
                # in bert_pas_analysis, one argument is predicted per predicate
                assert len(args_pred) <= 1

                if predicate_gold is None:
                    args_gold = args_gold_relaxed = []
                else:
                    args_gold = self._filter_args(arguments_gold[case], predicate_gold)
                    relaxed_pool = arguments_gold_relaxed[case]
                    # For 'ガ', the relaxed pool also accepts the '判ガ' arguments.
                    if case == 'ガ':
                        extra = arguments_gold_relaxed['判ガ']
                    else:
                        extra = []
                    args_gold_relaxed = self._filter_args(relaxed_pool + extra, predicate_gold)

                key = (doc_id, dtid, case)

                # Precision: judge the single predicted argument against the relaxed gold pool.
                if args_pred:
                    predicted = args_pred[0]
                    analysis = Scorer.DEPTYPE2ANALYSIS[predicted.dep_type]
                    if predicted not in args_gold_relaxed:
                        self.comp_result[key] = 'wrong'  # lowers precision
                    else:
                        self.comp_result[key] = analysis
                        self.measures[case][analysis].correct += 1
                    self.measures[case][analysis].denom_pred += 1

                # Recall: when several gold arguments exist and one of them was
                # predicted, that one is adopted as the gold argument. When none
                # was predicted, one of the non-relaxed arguments is adopted.
                counts_for_recall = bool(args_gold) or (
                    self.comp_result.get(key) in Scorer.DEPTYPE2ANALYSIS.values())
                if not counts_for_recall:
                    continue

                # Prefer a gold argument that was actually predicted.
                matched = next(
                    (gold_arg for gold_arg in args_gold_relaxed if gold_arg in args_pred),
                    None)
                if matched is None:
                    analysis = Scorer.DEPTYPE2ANALYSIS[args_gold[0].dep_type]
                    if not args_pred:
                        self.comp_result[key] = 'wrong'  # lowers recall
                    else:
                        assert self.comp_result[key] == 'wrong'
                else:
                    analysis = Scorer.DEPTYPE2ANALYSIS[matched.dep_type]
                    assert self.comp_result[key] == analysis
                self.measures[case][analysis].denom_gold += 1