Esempio n. 1
0
    def load(
        self,
        document: Document,
        cases: List[str],
        exophors: List[str],
        coreference: bool,
        bridging: bool,
        relations: List[str],
        kc: bool,
        pas_targets: List[str],
        tokenizer: BertTokenizer,
    ) -> None:
        """Fill this instance's per-morpheme lists from ``document``.

        For every morpheme of every base phrase, one entry is appended to each
        of ``self.words``, ``self.dtids``, ``self.ddeps``,
        ``self.arguments_set``, ``self.arg_candidates_set`` and
        ``self.ment_candidates_set``.  Relations and candidates are only
        computed for the morpheme whose ``document.mrph2dmid`` entry equals
        the base phrase's ``dmid``, and only in sentences selected for
        processing; every other morpheme gets empty values.

        Args:
            document: Source document.
            cases: Case names resolved for PAS targets.
            exophors: Exophor names accepted as arguments/referents.
            coreference: Whether to collect coreference ('=') mentions.
            bridging: Whether to collect bridging ('ノ') arguments.
            relations: Keys used to initialise each per-morpheme relation map.
            kc: When True, only documents whose id ends in '-00' are processed
                in full; for other documents only the last sentence is
                processed.
            pas_targets: Contains 'pred' and/or 'noun' to select verbal and/or
                nominal predicates.
            tokenizer: Not used by this method.
        """
        self.doc_id = document.doc_id
        process_all = (kc is False) or (document.doc_id.split('-')[-1] == '00')
        last_sent = document.sentences[-1] if len(document) > 0 else None

        # Map every accepted exophor (and the numbered variants of the
        # "unspecified" ones, e.g. '不特定:人1') onto its canonical name.
        relax_exophors = {}
        for exophor in exophors:
            relax_exophors[exophor] = exophor
            if exophor in ('不特定:人', '不特定:物', '不特定:状況'):
                for n in '123456789':
                    relax_exophors[exophor + n] = exophor

        dmid2arguments: Dict[int, Dict[str, List[BaseArgument]]] = {
            pred.dmid: document.get_arguments(pred)
            for pred in document.get_predicates()
        }

        # These flags do not change inside the loops below.
        want_verbal = 'pred' in pas_targets
        want_nominal = 'noun' in pas_targets

        head_dmids = []
        for sentence in document:
            process: bool = process_all or (sentence is last_sent)
            # Candidates accumulate across sentences, so a phrase can refer to
            # anything up to and including its own sentence.
            head_dmids.extend(bp.dmid for bp in sentence.bps)
            for bp in sentence.bps:
                for mrph in bp.mrph_list():
                    self.words.append(mrph.midasi)
                    self.dtids.append(bp.dtid)
                    self.ddeps.append(
                        bp.parent.dtid if bp.parent is not None else -1)
                    arguments = OrderedDict((rel, []) for rel in relations)
                    # Two distinct lists: the original bound both names to one
                    # shared list, which was then stored in both result sets.
                    arg_candidates: List[int] = []
                    ment_candidates: List[int] = []
                    if document.mrph2dmid[mrph] == bp.dmid and process:
                        if is_pas_target(bp,
                                         verbal=want_verbal,
                                         nominal=want_nominal):
                            arg_candidates = [
                                x for x in head_dmids if x != bp.dmid
                            ]
                            for case in cases:
                                dmid2args = {
                                    dmid: args_by_case[case]
                                    for dmid, args_by_case in
                                    dmid2arguments.items()
                                }
                                arguments[case] = self._get_args(
                                    bp.dmid, dmid2args, relax_exophors,
                                    arg_candidates)

                        if bridging and is_bridging_target(bp):
                            arg_candidates = [
                                x for x in head_dmids if x != bp.dmid
                            ]
                            dmid2args = {
                                dmid: args_by_case['ノ']
                                for dmid, args_by_case in
                                dmid2arguments.items()
                            }
                            arguments['ノ'] = self._get_args(
                                bp.dmid, dmid2args, relax_exophors,
                                arg_candidates)

                        if coreference and is_coreference_target(bp):
                            # Only earlier phrases: cataphora is not solved.
                            ment_candidates = [
                                x for x in head_dmids if x < bp.dmid
                            ]
                            arguments['='] = self._get_mentions(
                                bp, document, relax_exophors, ment_candidates)

                    self.arguments_set.append(arguments)
                    self.arg_candidates_set.append(arg_candidates)
                    self.ment_candidates_set.append(ment_candidates)
Esempio n. 2
0
def coverage(doc: Document) -> RetValue:
    """Walk over the predicates of ``doc`` and inspect their gold arguments.

    NOTE(review): this excerpt appears to be truncated -- ``ret``, ``ex`` and
    ``is_pred_gold`` are computed but never used or returned in the visible
    lines, so the rest of the contract cannot be documented from here.
    """
    ret = RetValue()
    for predicate in doc.get_predicates():
        ex = Example(doc, predicate)
        arguments = doc.get_arguments(predicate)
        # True when at least one case listed in PRED_CASES has a non-empty
        # argument list for this predicate.
        is_pred_gold = any(arguments[case] for case in PRED_CASES)
Esempio n. 3
0
def _dedup_argument_labels(args, core_counts) -> List[str]:
    """Return display labels for ``args``, de-duplicated in first-seen order.

    A label is ``str(arg)``; when that surface form occurs more than once
    among the document's mention cores and ``arg`` is an ``Argument``, its
    ``dtid`` is appended to disambiguate it.
    """
    labels = {}  # dict used as an insertion-ordered set
    for arg in args:
        label = str(arg)
        if core_counts[label] > 1 and isinstance(arg, Argument):
            label += str(arg.dtid)
        labels[label] = None
    return list(labels)


def _format_relation(name: str, targets: List[str], html: bool) -> str:
    """Render ``name:target,target`` as plain text or a coloured <font> tag."""
    joined = ','.join(targets)
    if html:
        # Relations without any target are greyed out.
        color = 'black' if targets else 'gray'
        return f'<font color="{color}">{name}:{joined}</font>'
    return f'{name}:{joined}'


def draw_tree(
    document: Document,
    sid: str,
    cases: List[str],
    bridging: bool = False,
    coreference: bool = False,
    fh: Optional[TextIO] = None,
    html: bool = False,
) -> None:
    """Write the predicate-argument structures and coreference relations of
    the sentence identified by ``sid`` to ``fh`` in tree format.

    Args:
        document (Document): Document that contains the sentence ``sid``.
        sid (str): ID of the sentence to output.
        cases (List[str]): Cases to display.
        bridging (bool): Whether to also display bridging anaphora relations.
        coreference (bool): Whether to also display coreference relations.
        fh (Optional[TextIO]): Output stream, passed to ``print`` as ``file``.
        html (bool): Whether to emit HTML (<font> tags) instead of plain text.

    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    blist: BList = document.sid2sentence[sid].blist
    with io.StringIO() as string:
        blist.draw_tag_tree(fh=string, show_pos=False)
        tree_strings = string.getvalue().rstrip('\n').split('\n')
    assert len(tree_strings) == len(blist.tag_list())

    # Count mention cores once so the ambiguity check is a dict lookup
    # instead of a linear scan for every argument.
    core_counts = Counter(m.core for m in document.mentions.values())
    tid2mention = {
        mention.tid: mention
        for mention in document.mentions.values() if mention.sid == sid
    }

    for bp in document[sid].bps:
        pieces = ['  ']

        if is_pas_target(bp, verbal=True, nominal=True):
            arguments = document.get_arguments(bp)
            for case in cases:
                targets = _dedup_argument_labels(arguments.get(case, []),
                                                 core_counts)
                pieces.append(_format_relation(case, targets, html) + ' ')

        if bridging and is_bridging_target(bp):
            targets = _dedup_argument_labels(
                document.get_arguments(bp).get('ノ', []), core_counts)
            pieces.append(_format_relation('ノ', targets, html) + ' ')

        if coreference and is_coreference_target(bp):
            coref_targets = {}  # insertion-ordered set of labels
            src_mention = tid2mention.get(bp.tid)
            if src_mention is not None:
                for tgt in document.get_siblings(src_mention):
                    # Only mentions that precede the source are shown.
                    if tgt.dtid >= src_mention.dtid:
                        continue
                    label = tgt.core
                    if core_counts[label] > 1:
                        label += str(tgt.dtid)
                    coref_targets[label] = None
                for eid in src_mention.eids:
                    entity = document.entities[eid]
                    if entity.is_special:
                        coref_targets[entity.exophor] = None
            pieces.append(_format_relation('=', list(coref_targets), html))

        tree_strings[bp.tid] += ''.join(pieces)

    print('\n'.join(tree_strings), file=fh)
Esempio n. 4
0
    def _draw_tree(self,
                   sid: str,
                   predicates: List[BasePhrase],
                   mentions: List[BasePhrase],
                   anaphors: List[BasePhrase],
                   document: Document,
                   fh: Optional[TextIO] = None,
                   html: bool = True) -> None:
        """Write the predicate-argument structures, coreference relations, and bridging anaphora relations of the
        specified sentence in tree format.

        In HTML mode each relation is coloured according to the entry stored in ``self.comp_result`` for
        ``(document.doc_id, dtid, relation)``.

        Args:
            sid (str): ID of the sentence to output.
            predicates (List[BasePhrase]): All predicates contained in the document.
            mentions (List[BasePhrase]): All mentions contained in the document.
            anaphors (List[BasePhrase]): All bridging anaphors contained in the document.
            document (Document): Document that contains the sentence to output.
            fh (Optional[TextIO]): Output stream, passed to ``print`` as ``file``.
            html (bool): Whether to emit HTML (<font> tags) instead of plain text.
        """
        # Function-scope import keeps this block self-contained.
        from collections import Counter

        result2color = {
            anal: 'blue'
            for anal in Scorer.DEPTYPE2ANALYSIS.values()
        }
        result2color.update({'overt': 'green', 'wrong': 'red', None: 'gray'})
        result2color_coref = {'correct': 'blue', 'wrong': 'red', None: 'gray'}

        blist: BList = document.sid2sentence[sid].blist
        with io.StringIO() as string:
            blist.draw_tag_tree(fh=string, show_pos=False)
            tree_strings = string.getvalue().rstrip('\n').split('\n')
        assert len(tree_strings) == len(blist.tag_list())

        # Count mention cores once so the ambiguity check is a dict lookup
        # instead of a linear scan for every argument.
        core_counts = Counter(m.core for m in document.mentions.values())

        tid2predicate: Dict[int, BasePhrase] = {
            predicate.tid: predicate
            for predicate in predicates if predicate.sid == sid
        }
        tid2mention: Dict[int, BasePhrase] = {
            mention.tid: mention
            for mention in mentions if mention.sid == sid
        }
        tid2bridging: Dict[int, BasePhrase] = {
            anaphor.tid: anaphor
            for anaphor in anaphors if anaphor.sid == sid
        }

        def join_argument_labels(args) -> str:
            """Comma-join the labels of ``args``, de-duplicated in first-seen order."""
            labels = {}  # dict used as an insertion-ordered set
            for arg in args:
                label = str(arg)
                if core_counts[label] > 1 and isinstance(arg, Argument):
                    label += str(arg.dtid)
                labels[label] = None
            return ','.join(labels)

        def render(relation: str, joined: str, palette: dict, result) -> str:
            """Format one relation; the colour is only looked up in HTML mode."""
            if html:
                return f'<font color="{palette[result]}">{relation}:{joined}</font>'
            return f'{relation}:{joined}'

        for tid in range(len(tree_strings)):
            pieces = ['  ']

            if tid in tid2predicate:
                predicate = tid2predicate[tid]
                arguments = document.get_arguments(predicate)
                for case in self.cases:
                    args = arguments[case]
                    if case == 'ガ':
                        # Concatenate into a new list: ``+=`` would extend the
                        # list object held in ``arguments`` in place.
                        args = args + arguments['判ガ']
                    result = self.comp_result.get(
                        (document.doc_id, predicate.dtid, case), None)
                    pieces.append(
                        render(case, join_argument_labels(args), result2color, result) + ' ')

            if self.bridging and tid in tid2bridging:
                anaphor = tid2bridging[tid]
                arguments = document.get_arguments(anaphor)
                args = arguments['ノ'] + arguments['ノ?']
                result = self.comp_result.get(
                    (document.doc_id, anaphor.dtid, 'ノ'), None)
                pieces.append(
                    render('ノ', join_argument_labels(args), result2color, result) + ' ')

            if self.coreference and tid in tid2mention:
                coref_targets = {}  # insertion-ordered set of labels
                src_dtid = tid2mention[tid].dtid
                if src_dtid in document.mentions:
                    src_mention = document.mentions[src_dtid]
                    tgt_mentions_relaxed = SubScorer.filter_mentions(
                        document.get_siblings(src_mention, relax=True),
                        src_mention)
                    for tgt_mention in tgt_mentions_relaxed:
                        target: str = tgt_mention.core
                        if core_counts[target] > 1:
                            target += str(tgt_mention.dtid)
                        coref_targets[target] = None
                    for eid in src_mention.eids:
                        entity = document.entities[eid]
                        if entity.exophor in self.relax_exophors:
                            coref_targets[entity.exophor] = None
                result = self.comp_result.get((document.doc_id, src_dtid, '='),
                                              None)
                pieces.append(
                    render('=', ','.join(coref_targets), result2color_coref, result))

            tree_strings[tid] += ''.join(pieces)

        print('\n'.join(tree_strings), file=fh)
Esempio n. 5
0
    def _evaluate_bridging(self, doc_id: str, document_pred: Document,
                           document_gold: Document):
        """Update bridging-anaphora ('ノ') precision/recall counters for one document.

        For every base-phrase index of ``document_pred`` the system antecedent
        (at most one) is compared with the gold antecedents.  The outcome is
        written to ``self.comp_result[(doc_id, dtid, 'ノ')]`` and the
        ``correct`` / ``denom_pred`` / ``denom_gold`` counters in
        ``self.measures_bridging`` are incremented.
        """
        system_anaphors: Dict[int, Predicate] = {}
        for anaphor in self.did2bridgings_pred[doc_id]:
            system_anaphors[anaphor.dtid] = anaphor
        gold_anaphors: Dict[int, Predicate] = {}
        for anaphor in self.did2bridgings_gold[doc_id]:
            gold_anaphors[anaphor.dtid] = anaphor

        deptype2analysis = Scorer.DEPTYPE2ANALYSIS

        for dtid in range(len(document_pred.bp_list())):
            # --- system side -------------------------------------------------
            anaphor_pred = system_anaphors.get(dtid)
            if anaphor_pred is None:
                system_ants: List[BaseArgument] = []
            else:
                system_ants = self._filter_args(
                    document_pred.get_arguments(anaphor_pred, relax=False)['ノ'],
                    anaphor_pred)
            # in bert_pas_analysis, predict one argument for one predicate
            assert len(system_ants) in (0, 1)

            # --- gold side ---------------------------------------------------
            anaphor_gold = gold_anaphors.get(dtid)
            if anaphor_gold is None:
                gold_ants: List[BaseArgument] = []
                gold_ants_relaxed: List[BaseArgument] = []
            else:
                gold_ants = self._filter_args(
                    document_gold.get_arguments(anaphor_gold, relax=False)['ノ'],
                    anaphor_gold)
                relaxed_arguments = document_gold.get_arguments(anaphor_gold,
                                                                relax=True)
                gold_ants_relaxed = self._filter_args(
                    relaxed_arguments['ノ'] + relaxed_arguments['ノ?'],
                    anaphor_gold)

            key = (doc_id, dtid, 'ノ')

            # --- precision ---------------------------------------------------
            if system_ants:
                predicted = system_ants[0]
                label = deptype2analysis[predicted.dep_type]
                is_hit = predicted in gold_ants_relaxed
                self.comp_result[key] = label if is_hit else 'wrong'
                if is_hit:
                    self.measures_bridging[label].correct += 1
                self.measures_bridging[label].denom_pred += 1

            # --- recall ------------------------------------------------------
            # Counted when there is a strict gold antecedent, or when the
            # prediction was already judged correct against the relaxed set.
            if (not gold_ants and self.comp_result.get(key, None)
                    not in deptype2analysis.values()):
                continue

            # Prefer a gold antecedent that the system actually predicted.
            matched = next(
                (ant for ant in gold_ants_relaxed if ant in system_ants), None)
            if matched is not None:
                label = deptype2analysis[matched.dep_type]
                assert self.comp_result[key] == label
            else:
                label = deptype2analysis[gold_ants[0].dep_type]
                if system_ants:
                    assert self.comp_result[key] == 'wrong'
                else:
                    self.comp_result[key] = 'wrong'
            self.measures_bridging[label].denom_gold += 1
Esempio n. 6
0
    def _evaluate_pas(self, doc_id: str, document_pred: Document,
                      document_gold: Document):
        """Update PAS-analysis precision/recall counters for one document.

        For every base-phrase index of ``document_pred`` and every case in
        ``self.cases``, the system argument (at most one) is compared with the
        gold arguments.  The outcome is written to
        ``self.comp_result[(doc_id, dtid, case)]`` and the ``correct`` /
        ``denom_pred`` / ``denom_gold`` counters in ``self.measures[case]``
        are incremented.
        """
        system_predicates: Dict[int, Predicate] = {}
        for predicate in self.did2predicates_pred[doc_id]:
            system_predicates[predicate.dtid] = predicate
        gold_predicates: Dict[int, Predicate] = {}
        for predicate in self.did2predicates_gold[doc_id]:
            gold_predicates[predicate.dtid] = predicate

        deptype2analysis = Scorer.DEPTYPE2ANALYSIS

        for dtid in range(len(document_pred.bp_list())):
            predicate_pred = system_predicates.get(dtid)
            if predicate_pred is None:
                arguments_pred = None
            else:
                arguments_pred = document_pred.get_arguments(predicate_pred,
                                                             relax=False)

            predicate_gold = gold_predicates.get(dtid)
            if predicate_gold is None:
                strict_gold = relaxed_gold = None
            else:
                strict_gold = document_gold.get_arguments(predicate_gold,
                                                          relax=False)
                relaxed_gold = document_gold.get_arguments(predicate_gold,
                                                           relax=True)

            for case in self.cases:
                if arguments_pred is None:
                    args_pred: List[BaseArgument] = []
                else:
                    args_pred = arguments_pred[case]
                # in bert_pas_analysis, predict one argument for one predicate
                assert len(args_pred) in (0, 1)

                if predicate_gold is None:
                    args_gold, args_gold_relaxed = [], []
                else:
                    args_gold = self._filter_args(strict_gold[case],
                                                  predicate_gold)
                    base_relaxed = relaxed_gold[case]
                    # For 'ガ', the '判ガ' arguments are accepted as well.
                    extra_relaxed = relaxed_gold['判ガ'] if case == 'ガ' else []
                    args_gold_relaxed = self._filter_args(
                        base_relaxed + extra_relaxed, predicate_gold)

                key = (doc_id, dtid, case)

                # --- precision -----------------------------------------------
                if args_pred:
                    predicted = args_pred[0]
                    label = deptype2analysis[predicted.dep_type]
                    is_hit = predicted in args_gold_relaxed
                    # A miss here lowers precision.
                    self.comp_result[key] = label if is_hit else 'wrong'
                    if is_hit:
                        self.measures[case][label].correct += 1
                    self.measures[case][label].denom_pred += 1

                # --- recall --------------------------------------------------
                # When there are several gold arguments and one of them was
                # predicted, that one is adopted as the gold argument.
                # Otherwise one of the non-relaxed arguments is adopted.
                if (not args_gold and self.comp_result.get(key, None)
                        not in deptype2analysis.values()):
                    continue

                # Prefer a gold argument that the system actually predicted.
                matched = next(
                    (arg for arg in args_gold_relaxed if arg in args_pred),
                    None)
                if matched is not None:
                    label = deptype2analysis[matched.dep_type]
                    assert self.comp_result[key] == label
                else:
                    label = deptype2analysis[args_gold[0].dep_type]
                    if args_pred:
                        assert self.comp_result[key] == 'wrong'
                    else:
                        # A missed gold argument lowers recall.
                        self.comp_result[key] = 'wrong'
                self.measures[case][label].denom_gold += 1