def _get_mentions(
    self,
    bp: BasePhrase,
    document: Document,
    relax_exophors: Dict[str, str],
    candidates: List[int],
) -> List[str]:
    """Return the coreference labels of the mention anchored at ``bp``.

    The labels are, in order:

    * ``str(dmid)`` of every sibling mention (``document.get_siblings(..., relax=False)``)
      whose ``dmid`` is contained in ``candidates``; other siblings are skipped with a
      debug log;
    * ``relax_exophors[exophor]`` for each special entity of the mention whose exophor
      is a key of ``relax_exophors`` (e.g. 不特定:人1 -> 不特定:人).

    Args:
        bp: base phrase whose mention is looked up by ``dtid`` in ``document.mentions``.
        document: document containing ``bp``.
        relax_exophors: maps an exophor name to the label emitted for it; exophors
            that are not keys of this mapping are dropped.
        candidates: ``dmid`` values allowed as targets.

    Returns:
        The collected labels, or ``['NA']`` when ``bp`` is not a mention or nothing
        was collected.
    """
    if bp.dtid not in document.mentions:
        return ['NA']

    src_mention = document.mentions[bp.dtid]
    siblings = document.get_siblings(src_mention, relax=False)
    special_exophors = [
        entity.exophor
        for entity in (document.entities[eid] for eid in src_mention.eids)
        if entity.is_special
    ]

    labels: List[str] = []
    for sibling in siblings:
        if sibling.dmid in candidates:
            labels.append(str(sibling.dmid))
        else:
            logger.debug(
                f'mention: {sibling} in {self.doc_id} is not in candidates and ignored'
            )
    labels.extend(
        relax_exophors[exophor]  # 不特定:人1 -> 不特定:人
        for exophor in special_exophors
        if exophor in relax_exophors
    )
    # force cataphor to point [NA]: an empty label list falls back to NA
    return labels or ['NA']
Beispiel #2
0
    def _draw_tree(self,
                   sid: str,
                   predicates: List[BasePhrase],
                   mentions: List[BasePhrase],
                   anaphors: List[BasePhrase],
                   document: Document,
                   fh: Optional[TextIO] = None,
                   html: bool = True) -> None:
        """Write the predicate-argument structures, coreference relations, and bridging anaphora relations of the
        specified sentence in tree format.

        Each line of the base-phrase tree of sentence ``sid`` is followed by ``case:targets`` labels for
        predicates (cases in ``self.cases``; ``判ガ`` arguments are shown under ``ガ``), ``ノ:targets`` for
        bridging anaphors (when ``self.bridging``) and ``=:targets`` for mentions (when ``self.coreference``).
        Targets are sorted so that the output is identical across runs. A target whose surface string is
        shared by several mentions of the document gets its dtid appended. In HTML mode each label is wrapped
        in a ``<font>`` tag coloured according to ``self.comp_result``.

        Args:
            sid (str): ID of the sentence to output.
            predicates (List[BasePhrase]): all predicates in ``document``.
            mentions (List[BasePhrase]): all mentions in ``document``.
            anaphors (List[BasePhrase]): all bridging anaphors in ``document``.
            document (Document): document containing the sentence.
            fh (Optional[TextIO]): output stream passed to ``print`` (``None`` means stdout).
            html (bool): whether to emit HTML markup.
        """
        from collections import Counter

        result2color = {
            anal: 'blue'
            for anal in Scorer.DEPTYPE2ANALYSIS.values()
        }
        result2color.update({'overt': 'green', 'wrong': 'red', None: 'gray'})
        result2color_coref = {'correct': 'blue', 'wrong': 'red', None: 'gray'}
        blist: BList = document.sid2sentence[sid].blist
        with io.StringIO() as string:
            blist.draw_tag_tree(fh=string, show_pos=False)
            tree_strings = string.getvalue().rstrip('\n').split('\n')
        assert len(tree_strings) == len(blist.tag_list())
        # Number of mentions per surface string (computed once instead of list.count per target).
        core_counts = Counter(m.core for m in document.mentions.values())
        tid2predicate: Dict[int, BasePhrase] = {
            predicate.tid: predicate
            for predicate in predicates if predicate.sid == sid
        }
        tid2mention: Dict[int, BasePhrase] = {
            mention.tid: mention
            for mention in mentions if mention.sid == sid
        }
        tid2bridging: Dict[int, BasePhrase] = {
            anaphor.tid: anaphor
            for anaphor in anaphors if anaphor.sid == sid
        }

        def format_args(args):
            """Return the sorted surface strings of ``args``, with the dtid appended to ambiguous ones."""
            targets = set()
            for arg in args:
                target = str(arg)
                if core_counts[target] > 1 and isinstance(arg, Argument):
                    target += str(arg.dtid)
                targets.add(target)
            # Sorted because str hashing is randomized per process, so raw set order is not stable.
            return sorted(targets)

        def render(text, palette, result):
            """Wrap ``text`` in a <font> tag coloured by ``palette[result]`` when emitting HTML."""
            if html:
                return f'<font color="{palette[result]}">{text}</font>'
            return text

        for tid in range(len(tree_strings)):
            tree_strings[tid] += '  '
            if tid in tid2predicate:
                predicate = tid2predicate[tid]
                arguments = document.get_arguments(predicate)
                for case in self.cases:
                    # .get() (as draw_tree does) tolerates a missing case key, and the unpacking builds a
                    # new list rather than extending the list held by ``arguments`` in place.
                    args = [*arguments.get(case, [])]
                    if case == 'ガ':
                        args += arguments.get('判ガ', [])
                    targets = format_args(args)
                    result = self.comp_result.get(
                        (document.doc_id, predicate.dtid, case), None)
                    tree_strings[tid] += render(f'{case}:{",".join(targets)}', result2color, result) + ' '

            if self.bridging and tid in tid2bridging:
                anaphor = tid2bridging[tid]
                arguments = document.get_arguments(anaphor)
                targets = format_args([*arguments.get('ノ', []), *arguments.get('ノ?', [])])
                result = self.comp_result.get(
                    (document.doc_id, anaphor.dtid, 'ノ'), None)
                tree_strings[tid] += render(f'ノ:{",".join(targets)}', result2color, result) + ' '

            if self.coreference and tid in tid2mention:
                coref_targets = set()
                src_dtid = tid2mention[tid].dtid
                if src_dtid in document.mentions:
                    src_mention = document.mentions[src_dtid]
                    tgt_mentions_relaxed = SubScorer.filter_mentions(
                        document.get_siblings(src_mention, relax=True),
                        src_mention)
                    for tgt_mention in tgt_mentions_relaxed:
                        target = tgt_mention.core
                        if core_counts[target] > 1:
                            target += str(tgt_mention.dtid)
                        coref_targets.add(target)
                    for eid in src_mention.eids:
                        entity = document.entities[eid]
                        if entity.exophor in self.relax_exophors:
                            coref_targets.add(entity.exophor)
                result = self.comp_result.get((document.doc_id, src_dtid, '='),
                                              None)
                tree_strings[tid] += render('=:' + ','.join(sorted(coref_targets)), result2color_coref, result)

        print('\n'.join(tree_strings), file=fh)
Beispiel #3
0
def draw_tree(
    document: Document,
    sid: str,
    cases: List[str],
    bridging: bool = False,
    coreference: bool = False,
    fh: Optional[TextIO] = None,
    html: bool = False,
) -> None:
    """Write the predicate-argument structures and coreference relations of sentence ``sid`` to ``fh``
    in tree format.

    Every base phrase of the sentence is printed as one line of the tree, followed by
    ``case:targets`` for PAS targets, ``ノ:targets`` for bridging targets (when ``bridging``) and
    ``=:targets`` for coreference targets (when ``coreference``). The coreference targets are the
    preceding sibling mentions and the exophors of special entities. Targets are sorted, so the
    output is reproducible across runs. A target whose surface string is shared by several mentions
    in the document gets its dtid appended.

    Args:
        document (Document): document that contains ``sid``.
        sid (str): ID of the sentence to output.
        cases (List[str]): cases to display.
        bridging (bool): whether to also display bridging anaphora relations.
        coreference (bool): whether to also display coreference relations.
        fh (Optional[TextIO]): output stream passed to ``print`` (``None`` means stdout).
        html (bool): whether to output HTML (labels with no target are drawn gray, others black).
    """
    from collections import Counter

    blist: BList = document.sid2sentence[sid].blist
    with io.StringIO() as string:
        blist.draw_tag_tree(fh=string, show_pos=False)
        tree_strings = string.getvalue().rstrip('\n').split('\n')
    assert len(tree_strings) == len(blist.tag_list())
    # Number of mentions per surface string (computed once instead of list.count per target).
    core_counts = Counter(m.core for m in document.mentions.values())
    tid2mention = {
        mention.tid: mention
        for mention in document.mentions.values() if mention.sid == sid
    }

    def format_args(args) -> List[str]:
        """Return the sorted surface strings of ``args``, with the dtid appended to ambiguous ones."""
        targets = set()
        for arg in args:
            target = str(arg)
            if core_counts[target] > 1 and isinstance(arg, Argument):
                target += str(arg.dtid)
            targets.add(target)
        # Sorted because str hashing is randomized per process, so raw set order is not stable.
        return sorted(targets)

    def render(label: str, targets: List[str]) -> str:
        """Format ``label:t1,t2`` and, in HTML mode, colour it black (has targets) or gray (none)."""
        text = f'{label}:{",".join(targets)}'
        if html:
            color = 'black' if targets else 'gray'
            return f'<font color="{color}">{text}</font>'
        return text

    for bp in document[sid].bps:
        tree_strings[bp.tid] += '  '
        if is_pas_target(bp, verbal=True, nominal=True):
            arguments = document.get_arguments(bp)
            for case in cases:
                tree_strings[bp.tid] += render(case, format_args(arguments.get(case, []))) + ' '
        if bridging and is_bridging_target(bp):
            args = document.get_arguments(bp).get('ノ', [])
            tree_strings[bp.tid] += render('ノ', format_args(args)) + ' '
        if coreference and is_coreference_target(bp):
            coref_targets = set()
            if bp.tid in tid2mention:
                src_mention = tid2mention[bp.tid]
                for tgt_mention in document.get_siblings(src_mention):
                    if tgt_mention.dtid < src_mention.dtid:  # only preceding mentions
                        target = tgt_mention.core
                        if core_counts[tgt_mention.core] > 1:
                            target += str(tgt_mention.dtid)
                        coref_targets.add(target)
                for eid in src_mention.eids:
                    entity = document.entities[eid]
                    if entity.is_special:
                        coref_targets.add(entity.exophor)
            tree_strings[bp.tid] += render('=', sorted(coref_targets))
    print('\n'.join(tree_strings), file=fh)
Beispiel #4
0
    def _evaluate_coref(self, doc_id: str, document_pred: Document,
                        document_gold: Document):
        """Accumulate coreference precision/recall counts for one document.

        Every base-phrase index ``dtid`` of ``document_pred`` is judged under the key
        ``(doc_id, dtid, '=')`` in ``self.comp_result``:

        * precision: when the prediction links the phrase to any mention or special-entity
          exophor, ``measure_coref.denom_pred`` is incremented. The link is 'correct' (and
          ``measure_coref.correct`` is incremented) when it shares a mention with the relaxed
          gold siblings or an exophor with the relaxed gold exophors. Otherwise it is 'wrong'.
        * recall: when the gold annotation has a non-relaxed link, or the prediction was just
          judged 'correct', ``measure_coref.denom_gold`` is incremented. The key is set to
          'wrong' unless the same overlap test succeeds.

        Gold exophors only count if they are values of ``self.relax_exophors``.
        """
        def index_mentions(bps, document):
            # dtid -> Mention for the listed base phrases that are registered mentions
            return {
                bp.dtid: document.mentions[bp.dtid]
                for bp in bps
                if bp.dtid in document.mentions
            }

        def gold_exophors(eids):
            # exophors of special gold entities restricted to the relaxed exophor labels
            return [
                entity.exophor
                for entity in map(document_gold.entities.get, eids)
                if entity.is_special and entity.exophor in self.relax_exophors.values()
            ]

        def overlaps(mentions_pred, mentions_gold, exos_pred, exos_gold):
            # truthy iff the prediction shares a mention or an exophor with the gold side
            return (set(mentions_pred) & set(mentions_gold)) \
                or (set(exos_pred) & set(exos_gold))

        dtid2mention_pred: Dict[int, Mention] = index_mentions(
            self.did2mentions_pred[doc_id], document_pred)
        dtid2mention_gold: Dict[int, Mention] = index_mentions(
            self.did2mentions_gold[doc_id], document_gold)

        for dtid in range(len(document_pred.bp_list())):
            if dtid not in dtid2mention_pred:
                pred_mentions, pred_exophors = [], []
            else:
                pred_src = dtid2mention_pred[dtid]
                pred_mentions = self._filter_mentions(
                    document_pred.get_siblings(pred_src), pred_src)
                pred_exophors = [
                    entity.exophor
                    for entity in map(document_pred.entities.get, pred_src.eids)
                    if entity.is_special
                ]

            if dtid not in dtid2mention_gold:
                gold_mentions, gold_mentions_relaxed = [], []
                gold_exos, gold_exos_relaxed = [], []
            else:
                gold_src = dtid2mention_gold[dtid]
                gold_mentions = self._filter_mentions(
                    document_gold.get_siblings(gold_src, relax=False), gold_src)
                gold_mentions_relaxed = self._filter_mentions(
                    document_gold.get_siblings(gold_src, relax=True), gold_src)
                gold_exos = gold_exophors(gold_src.eids)
                gold_exos_relaxed = gold_exophors(gold_src.all_eids)

            key = (doc_id, dtid, '=')

            # precision: judge every predicted link
            if pred_mentions or pred_exophors:
                if not overlaps(pred_mentions, gold_mentions_relaxed,
                                pred_exophors, gold_exos_relaxed):
                    self.comp_result[key] = 'wrong'
                else:
                    self.comp_result[key] = 'correct'
                    self.measure_coref.correct += 1
                self.measure_coref.denom_pred += 1

            # recall: every gold link (or relaxed hit) must have been found
            needs_recall = gold_mentions or gold_exos or \
                self.comp_result.get(key, None) == 'correct'
            if needs_recall:
                if not overlaps(pred_mentions, gold_mentions_relaxed,
                                pred_exophors, gold_exos_relaxed):
                    self.comp_result[key] = 'wrong'
                else:
                    assert self.comp_result[key] == 'correct'
                self.measure_coref.denom_gold += 1