Example #1
def process(input_path: Path, output_path: Path, corpus: str) -> int:
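    """Read every document under input_path with KyotoReader and pickle it into output_path; return the document count."""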
    output_path.mkdir(exist_ok=True)
    reader = KyotoReader(input_path, extract_nes=False)
    for document in tqdm(reader.process_all_documents(backend="multiprocessing"), desc=corpus, total=len(reader)):
        with output_path.joinpath(document.doc_id + '.pkl').open(mode='wb') as f:
            cPickle.dump(document, f)
    return len(reader)
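A minimal usage sketch for process() above; the directory names and the corpus label are placeholders, not values from the original project:

n_docs = process(
    input_path=Path('data/knp'),   # directory containing the corpus (placeholder)
    output_path=Path('data/pkl'),  # pickled documents are written here (placeholder)
    corpus='example',              # used only as the tqdm progress-bar label
)
print(f'pickled {n_docs} documents')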
Example #2
    def __init__(
        self,
        path: Union[str, Path],
        cases: List[str],
        exophors: List[str],
        coreference: bool,
        bridging: bool,
        max_seq_length: int,
        bert_path: Union[str, Path],
        training: bool,
        kc: bool,
        train_targets: List[str],
        pas_targets: List[str],
        n_jobs: int = -1,
        logger=None,
        gold_path: Optional[str] = None,
    ) -> None:
        self.path = Path(path)
        self.reader = KyotoReader(self.path, extract_nes=False, n_jobs=n_jobs)
        self.target_cases: List[str] = [
            c for c in cases if c in self.reader.target_cases and c != 'ノ'
        ]
        self.target_exophors: List[str] = [
            e for e in exophors if e in ALL_EXOPHORS
        ]
        self.coreference: bool = coreference
        self.bridging: bool = bridging
        self.relations = (self.target_cases
                          + ['ノ'] * bridging
                          + ['='] * coreference)
        self.kc: bool = kc
        self.train_targets: List[str] = [
            t if t != 'case' else 'dep' for t in train_targets
        ]  # backward compatibility
        self.pas_targets: List[str] = pas_targets
        self.logger: Logger = logger or logging.getLogger(__file__)
        special_tokens = self.target_exophors + ['NULL'] + (
            ['NA'] if coreference else [])
        self.special_to_index: Dict[str, int] = {
            token: max_seq_length - i - 1
            for i, token in enumerate(reversed(special_tokens))
        }
        self.tokenizer = BertTokenizer.from_pretrained(
            bert_path, do_lower_case=False, tokenize_chinese_chars=False)
        self.max_seq_length: int = max_seq_length
        self.bert_path: Path = Path(bert_path)
        documents = list(self.reader.process_all_documents())

        if not training:
            self.documents: Optional[List[Document]] = documents
            if gold_path is not None:
                reader = KyotoReader(Path(gold_path),
                                     extract_nes=False,
                                     n_jobs=n_jobs)
                self.gold_documents = list(reader.process_all_documents())

        self.examples = self._load(documents, str(path))
Example #3
    def __init__(
        self,
        path: Union[str, Path],
        cases: List[str],
        exophors: List[str],
        coreference: bool,
        bridging: bool,
        max_seq_length: int,
        model_name: str,
        pretrained_path: Union[str, Path],
        training: bool,
        kc: bool,
        train_targets: List[str],
        pas_targets: List[str],
        logger=None,
        kc_joined_path: Optional[str] = None,
    ) -> None:
        self.reader = KyotoReader(Path(path), extract_nes=False)
        self.target_cases: List[str] = [
            c for c in cases if c in self.reader.target_cases and c != 'ノ'
        ]
        self.target_exophors: List[str] = [
            e for e in exophors if e in ALL_EXOPHORS
        ]
        self.coreference: bool = coreference
        self.bridging: bool = bridging
        self.kc: bool = kc
        self.train_targets: List[str] = train_targets
        self.pas_targets: List[str] = pas_targets
        self.logger: Logger = logger or logging.getLogger(__file__)
        special_tokens = self.target_exophors + ['NULL'] + (
            ['NA'] if coreference else [])
        self.special_to_index: Dict[str, int] = {
            token: max_seq_length - i - 1
            for i, token in enumerate(reversed(special_tokens))
        }
        tokenizer_cls = MODEL2TOKENIZER[model_name]
        self.tokenizer = tokenizer_cls.from_pretrained(
            pretrained_path, do_lower_case=False, tokenize_chinese_chars=False)
        self.max_seq_length: int = max_seq_length
        self.pretrained_path: Path = Path(pretrained_path)
        documents = list(
            self.reader.process_all_documents(backend="multiprocessing"))
        self.documents: Optional[
            List[Document]] = documents if not training else None

        if self.kc and not training:
            assert kc_joined_path is not None
            reader = KyotoReader(Path(kc_joined_path), extract_nes=False)
            self.joined_documents = list(
                reader.process_all_documents(backend="multiprocessing"))

        self.examples = self._load(documents, str(path))
Example #4
def test_pas_relax(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060560')
    predicates: List[Predicate] = document.get_predicates()
    arguments = document.get_arguments(predicates[9], relax=True)
    sid1 = 'w201106-0000060560-1'
    sid2 = 'w201106-0000060560-2'
    sid3 = 'w201106-0000060560-3'
    assert predicates[9].core == 'ご協力'
    assert len([_ for args in arguments.values() for _ in args]) == 6
    args = sorted(arguments['ガ'], key=lambda a: a.dtid)
    arg = args[0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ドクター', 7, 7, sid1, 'inter', 'AND')
    arg = args[1]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ドクター', 2, 11, sid2, 'inter', 'AND')
    arg = args[2]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ドクター', 0, 16, sid3, 'intra', 'AND')
    arg = args[3]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('皆様', 1, 17, sid3, 'intra', '')
    args = sorted(arguments['ニ'], key=lambda a: str(a))
    arg = args[0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('コーナー', 5, 14, sid2, 'inter', '?')
    arg = args[1]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('著者', 5, 'exo', '')
Example #5
def test_dep_type(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0002000028')
    predicate: Predicate = document.get_predicates()[4]
    assert predicate.core == '同じ'
    arg = document.get_arguments(predicate, relax=True)['ガ'][1]
    assert isinstance(arg, Argument)
    assert arg.core == 'フランス'
    assert arg.dtid == 10
    assert arg.dep_type == 'dep'
Example #6
def process(
    input_path: Path,
    output_path: Path,
    corpus: str,
    do_reparse: bool,
    n_jobs: int,
    bertknp: Optional[str],
    knp: KNP,
    keep_dep: bool = False,
    split: bool = False,
    max_subword_length: Optional[int] = None,
    tokenizer: Optional[BertTokenizer] = None,
) -> int:
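    """Optionally reparse and/or split the corpus, then pickle every document into output_path; return the document count."""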
    with tempfile.TemporaryDirectory() as tmp_dir1, \
            tempfile.TemporaryDirectory() as tmp_dir2:
        tmp_dir1, tmp_dir2 = Path(tmp_dir1), Path(tmp_dir2)
        if do_reparse:
            reparse(input_path,
                    tmp_dir1,
                    knp,
                    bertknp=bertknp,
                    n_jobs=n_jobs,
                    keep_dep=keep_dep)
            input_path = tmp_dir1
        if split:
            # Documents in the Kyoto Corpus are very long, so split each into multiple documents
            # so that the last sentence of every split document has as many preceding sentences as possible.
            print('splitting corpus...')
            split_kc(input_path, tmp_dir2, max_subword_length, tokenizer)
            input_path = tmp_dir2

        output_path.mkdir(exist_ok=True)
        reader = KyotoReader(input_path,
                             extract_nes=False,
                             did_from_sid=(not split),
                             n_jobs=n_jobs)
        for document in tqdm(reader.process_all_documents(),
                             desc=corpus,
                             total=len(reader)):
            with output_path.joinpath(document.doc_id +
                                      '.pkl').open(mode='wb') as f:
                cPickle.dump(document, f)

    return len(reader)
Example #7
def test_coref_link3(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060877')
    for entity in document.entities.values():
        for mention in entity.mentions:
            assert entity.eid in mention.eids
        for mention in entity.mentions_unc:
            assert entity.eid in mention.eids_unc
    for mention in document.mentions.values():
        for eid in mention.eids:
            assert mention in document.entities[eid].mentions
        for eid in mention.eids_unc:
            assert mention in document.entities[eid].mentions_unc
Example #8
def main():
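    """Aggregate coverage statistics over all documents and print precision, recall, and F1 for the 'pred' and 'noun' measures."""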
    reader = KyotoReader(sys.argv[1])
    ret = RetValue()
    for doc in reader.process_all_documents():
        ret += coverage(doc)
    print('pred:')
    print(
        f'  precision: {ret.measure_pred.precision:.4f} ({ret.measure_pred.correct}/{ret.measure_pred.denom_pred})'
    )
    print(
        f'  recall   : {ret.measure_pred.recall:.4f} ({ret.measure_pred.correct}/{ret.measure_pred.denom_gold})'
    )
    print(f'  F        : {ret.measure_pred.f1:.4f}')
    print('noun:')
    print(
        f'  precision: {ret.measure_noun.precision:.4f} ({ret.measure_noun.correct}/{ret.measure_noun.denom_pred})'
    )
    print(
        f'  recall   : {ret.measure_noun.recall:.4f} ({ret.measure_noun.correct}/{ret.measure_noun.denom_gold})'
    )
    print(f'  F        : {ret.measure_noun.f1:.4f}')
Example #9
def process_kc(input_path: Path,
               output_path: Path,
               max_subword_length: int,
               tokenizer: TokenizeHandlerMeta,
               split: bool = False
               ) -> int:
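    """Optionally split the Kyoto Corpus into shorter documents, then pickle every document into output_path."""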
    with tempfile.TemporaryDirectory() as tmp_dir:
        if split:
            tmp_dir = Path(tmp_dir)
            # Documents in the Kyoto Corpus are long, so split each into multiple
            # documents so that every split contains as much context as possible.
            print('splitting kc...')
            split_kc(input_path, tmp_dir, max_subword_length, tokenizer)
            input_path = tmp_dir

        print(list(input_path.iterdir()))
        output_path.mkdir(exist_ok=True)
        reader = KyotoReader(input_path, extract_nes=False, did_from_sid=False)
        for document in tqdm(reader.process_all_documents(backend="multiprocessing"), desc='kc', total=len(reader)):
            with output_path.joinpath(document.doc_id + '.pkl').open(mode='wb') as f:
                cPickle.dump(document, f)

    return len(reader)
Example #10
def reparse(
    input_dir: Path,
    output_dir: Path,
    knp: KNP,
    bertknp: Optional[str] = None,
    n_jobs: int = 0,
    keep_dep: bool = False,
) -> None:
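    """Reparse every *.knp file under input_dir with KNP (or BERTKNP if given) and write the results to output_dir."""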
    if bertknp is None:
        args_iter = ((path, output_dir, knp, keep_dep)
                     for path in input_dir.glob('*.knp'))
        if n_jobs > 0:
            with Pool(n_jobs) as pool:
                pool.starmap(reparse_knp, args_iter)
        else:
            for args in args_iter:
                reparse_knp(*args)
        return

    assert keep_dep is False, 'If you use BERTKNP, you cannot keep dependency labels.'

    buff = ''
    for knp_file in input_dir.glob('*.knp'):
        with knp_file.open() as fin:
            for line in fin:
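                # For bunsetsu ('*') and tag ('+') lines, keep only the marker character,
                # dropping the original dependency target and features before the text is fed to BERTKNP.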
                if line.startswith('+') or line.startswith('*'):
                    buff += line[0] + '\n'
                else:
                    buff += line

    out = subprocess.run([
        bertknp,
        '-p',
        Path(bertknp).parents[1] / '.venv/bin/python',
        '-O',
        str(Path(__file__).parent.joinpath('bertknp_options.txt').resolve()),
        '-tab',
    ],
                         input=buff,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         encoding='utf-8')
    logger.warning(out.stderr)
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        tmp_dir.joinpath('tmp.knp').write_text(out.stdout)
        for did, knp_string in KyotoReader(tmp_dir / 'tmp.knp').did2knps.items():
            output_dir.joinpath(f'{did}.knp').write_text(knp_string)
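A minimal sketch of calling reparse() above with plain KNP (no BERTKNP), assuming pyknp is installed and the knp binary is on PATH; the directory names are placeholders:

from pyknp import KNP

knp = KNP()  # assumes the default knp command is available
# Reparse every *.knp file under in_dir and write the results to out_dir, using four worker processes.
reparse(Path('in_dir'), Path('out_dir'), knp, n_jobs=4)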
Example #11
def test_ne(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060877')
    nes = document.named_entities
    assert len(nes) == 2
    ne = nes[0]
    assert (ne.category, ne.name, ne.dmid_range) == ('ORGANIZATION', '柏市ひまわり園',
                                                     range(5, 9))
    ne = nes[1]
    assert (ne.category, ne.name, ne.dmid_range) == ('DATE', '平成23年度',
                                                     range(11, 14))

    document = fixture_kyoto_reader.process_document('w201106-0000074273')
    nes = document.named_entities
    assert len(nes) == 3
    ne = nes[0]
    assert (ne.category, ne.name, ne.dmid_range) == ('LOCATION', 'ダーマ神殿',
                                                     range(15, 17))
    ne = nes[1]
    assert (ne.category, ne.name, ne.dmid_range) == ('ARTIFACT', '天の箱舟',
                                                     range(24, 27))
    ne = nes[2]
    assert (ne.category, ne.name, ne.dmid_range) == ('LOCATION', 'ナザム村',
                                                     range(39, 41))
Example #12
def test_coref2(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060560')
    entities: Dict[int, Entity] = document.entities
    assert len(entities) == 15

    entity: Entity = entities[14]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 4
    assert (mentions[0].core, mentions[0].dtid, mentions[0].eids) == ('ドクター',
                                                                      7, {4})
    assert (mentions[1].core, mentions[1].dtid, mentions[1].eids) == ('ドクター',
                                                                      11, {14})
    assert (mentions[2].core, mentions[2].dtid, mentions[2].eids) == ('ドクター',
                                                                      16, {14})
    assert (mentions[3].core, mentions[3].dtid, mentions[3].eids) == ('皆様', 17,
                                                                      {14})
Example #13
def main():
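    """Split KNP files under --input-dir by document ID and write one .knp file per document to --output-dir."""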
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-dir',
                        '-i',
                        default=None,
                        type=str,
                        help='path to input knp directory')
    parser.add_argument('--output-dir',
                        '-o',
                        default=None,
                        type=str,
                        help='path to output knp directory')
    args = parser.parse_args()

    docs: Dict[str, str] = {}
    for path in Path(args.input_dir).glob('**/*.knp'):
        docs.update(KyotoReader.read_knp(path, did_from_sid=True))

    for did, knp_string in docs.items():
        out_path = Path(args.output_dir) / f'{did}.knp'
        with out_path.open(mode='w') as f:
            f.write(knp_string)
Example #14
def show(args: argparse.Namespace):
    reader = KyotoReader(args.path, target_cases=args.cases)
    for document in reader.process_all_documents():
        document.draw_tree()
Example #15
def test_pas(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060050')
    predicates: List[Predicate] = document.get_predicates()
    assert len(predicates) == 12

    sid1 = 'w201106-0000060050-1'
    sid2 = 'w201106-0000060050-2'
    sid3 = 'w201106-0000060050-3'

    arguments = document.get_arguments(predicates[0])
    assert predicates[0].core == 'トス'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:人', 0, 'exo', '')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('コイン', 0, 0, sid1, 'dep', '')

    arguments = document.get_arguments(predicates[1])
    assert predicates[1].core == '行う'
    assert len([_ for args in arguments.values() for _ in args]) == 4
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:人', 2, 'exo', '')
    arg = arguments['ガ'][1]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('読者', 3, 'exo', '?')
    arg = arguments['ガ'][2]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('著者', 4, 'exo', '?')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('トス', 1, 1, sid1, 'overt', '')

    arguments = document.get_arguments(predicates[2])
    assert predicates[2].core == '表'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ノ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('コイン', 0, 0, sid1, 'inter', '')

    arguments = document.get_arguments(predicates[3])
    assert predicates[3].core == '出た'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['ガ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('表', 0, 4, sid2, 'overt', '')
    arg = arguments['外の関係'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('数', 2, 6, sid2, 'dep', '')

    arguments = document.get_arguments(predicates[4])
    assert predicates[4].core == '数'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ノ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('出た', 1, 5, sid2, 'dep', '')

    arguments = document.get_arguments(predicates[5])
    assert predicates[5].core == 'モンスター'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['修飾'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('フィールド上', 3, 7, sid2, 'dep', '')
    arg = arguments['修飾'][1]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('数', 2, 6, sid2, 'intra', 'AND')

    arguments = document.get_arguments(predicates[6])
    assert predicates[6].core == '破壊する'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:状況', 11, 'exo', '')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('モンスター', 4, 8, sid2, 'overt', '')

    arguments = document.get_arguments(predicates[7])
    assert predicates[7].core == '効果'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['トイウ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('破壊する', 5, 9, sid2, 'inter', '')

    arguments = document.get_arguments(predicates[8])
    assert predicates[8].core == '1度'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ニ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ターン', 3, 13, sid3, 'overt', '')

    arguments = document.get_arguments(predicates[9])
    assert predicates[9].core == 'メイン'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ガ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('フェイズ', 7, 17, sid3, 'dep', '')

    arguments = document.get_arguments(predicates[10])
    assert predicates[10].core == 'フェイズ'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ノ?'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('自分', 5, 15, sid3, 'overt', '')

    arguments = document.get_arguments(predicates[11])
    assert predicates[11].core == '使用する事ができる'
    assert len([_ for args in arguments.values() for _ in args]) == 5
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:人', 17, 'exo', '')
    arg = arguments['ガ'][1]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('著者', 4, 'exo', '?')
    arg = arguments['ガ'][2]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('読者', 3, 'exo', '?')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('効果', 1, 11, sid3, 'dep', '')
    arg = arguments['ニ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('フェイズ', 7, 17, sid3, 'overt', '')
Example #16
def show(args: argparse.Namespace):
    """Show the specified document in a tree format."""
    reader = KyotoReader(args.path, target_cases=args.cases)
    for document in reader.process_all_documents():
        document.draw_tree()
Example #17
def test_coref1(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060050')
    entities: Dict[int, Entity] = document.entities
    assert len(entities) == 19

    entity = entities[0]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[1]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('コイン', 0)
    assert mentions[0].eids == {1}

    entity = entities[2]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[3]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor == '読者'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}
    assert mentions[0].eids_unc == {3, 4, 15}

    entity = entities[4]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor == '著者'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}
    assert mentions[0].eids_unc == {3, 4, 15}

    entity = entities[5]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('トス', 1)
    assert mentions[0].eids == {5}

    entity = entities[6]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('表', 4)
    assert mentions[0].eids == {6}

    entity = entities[7]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('数', 6)
    assert mentions[0].eids == {7}

    entity = entities[8]
    assert (entity.taigen, entity.yougen) == (False, True)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('出た', 5)
    assert mentions[0].eids == {8}

    entity = entities[9]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('フィールド上', 7)
    assert mentions[0].eids == {9}

    entity = entities[10]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('モンスター', 8)
    assert mentions[0].eids == {10}

    entity = entities[11]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:状況'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[12]
    assert (entity.taigen, entity.yougen) == (False, True)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('破壊する', 9)
    assert mentions[0].eids == {12}

    entity = entities[13]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('ターン', 13)
    assert mentions[0].eids == {13}

    entity = entities[14]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}

    entity = entities[15]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}

    entity = entities[16]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('フェイズ', 17)
    assert mentions[0].eids == {16}

    entity = entities[17]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[18]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('効果', 11)
    assert mentions[0].eids == {18}
Example #18
def list_(args: argparse.Namespace):
    """List document IDs which specified path contains."""
    reader = KyotoReader(args.path)
    print('\n'.join(reader.doc_ids))
Example #19
def fixture_documents_gold():
    reader = KyotoReader(here / 'data' / 'gold',
                         target_cases=ALL_CASES,
                         target_corefs=ALL_COREFS)
    yield reader.process_all_documents()
Example #20
def list_(args: argparse.Namespace):
    reader = KyotoReader(args.path)
    print('\n'.join(reader.doc_ids))
Example #21
class PASDataset(Dataset):
    def __init__(
        self,
        path: Union[str, Path],
        cases: List[str],
        exophors: List[str],
        coreference: bool,
        bridging: bool,
        max_seq_length: int,
        bert_path: Union[str, Path],
        training: bool,
        kc: bool,
        train_targets: List[str],
        pas_targets: List[str],
        n_jobs: int = -1,
        logger=None,
        gold_path: Optional[str] = None,
    ) -> None:
        self.path = Path(path)
        self.reader = KyotoReader(self.path, extract_nes=False, n_jobs=n_jobs)
        self.target_cases: List[str] = [
            c for c in cases if c in self.reader.target_cases and c != 'ノ'
        ]
        self.target_exophors: List[str] = [
            e for e in exophors if e in ALL_EXOPHORS
        ]
        self.coreference: bool = coreference
        self.bridging: bool = bridging
        self.relations = (self.target_cases
                          + ['ノ'] * bridging
                          + ['='] * coreference)
        self.kc: bool = kc
        self.train_targets: List[str] = [
            t if t != 'case' else 'dep' for t in train_targets
        ]  # backward compatibility
        self.pas_targets: List[str] = pas_targets
        self.logger: Logger = logger or logging.getLogger(__file__)
        special_tokens = self.target_exophors + ['NULL'] + (
            ['NA'] if coreference else [])
        self.special_to_index: Dict[str, int] = {
            token: max_seq_length - i - 1
            for i, token in enumerate(reversed(special_tokens))
        }
        self.tokenizer = BertTokenizer.from_pretrained(
            bert_path, do_lower_case=False, tokenize_chinese_chars=False)
        self.max_seq_length: int = max_seq_length
        self.bert_path: Path = Path(bert_path)
        documents = list(self.reader.process_all_documents())

        if not training:
            self.documents: Optional[List[Document]] = documents
            if gold_path is not None:
                reader = KyotoReader(Path(gold_path),
                                     extract_nes=False,
                                     n_jobs=n_jobs)
                self.gold_documents = list(reader.process_all_documents())

        self.examples = self._load(documents, str(path))

    def _load(self, documents: List[Document], path: str) -> List[PasExample]:
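        """Convert each document into a PasExample, reusing an on-disk cache keyed by a hash of the dataset settings."""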
        examples: List[PasExample] = []
        load_cache: bool = ('BPA_DISABLE_CACHE' not in os.environ
                            and 'BPA_OVERWRITE_CACHE' not in os.environ)
        save_cache: bool = ('BPA_DISABLE_CACHE' not in os.environ)
        bpa_cache_dir: Path = Path(
            os.environ.get('BPA_CACHE_DIR',
                           f'/tmp/{os.environ["USER"]}/bpa_cache'))
        for document in tqdm(documents, desc='processing documents'):
            hash_ = self._hash(document, path, self.relations,
                               self.target_exophors, self.kc, self.pas_targets,
                               self.train_targets, str(self.bert_path))
            example_cache_path = bpa_cache_dir / hash_ / f'{document.doc_id}.pkl'
            if example_cache_path.exists() and load_cache:
                with example_cache_path.open('rb') as f:
                    example = cPickle.load(f)
            else:
                example = PasExample()
                example.load(document,
                             cases=self.target_cases,
                             exophors=self.target_exophors,
                             coreference=self.coreference,
                             bridging=self.bridging,
                             relations=self.relations,
                             kc=self.kc,
                             pas_targets=self.pas_targets,
                             tokenizer=self.tokenizer)
                if save_cache:
                    example_cache_path.parent.mkdir(exist_ok=True,
                                                    parents=True)
                    with example_cache_path.open('wb') as f:
                        cPickle.dump(example, f)

            # skip documents whose token sequence is too long
            if len(example.tokens) > self.max_seq_length - len(
                    self.special_to_index):
                continue

            examples.append(example)
        if len(examples) == 0:
            self.logger.error(
                'No examples to process. '
                f'Make sure that {self.path} contains documents and that they are not too long.'
            )
        return examples

    @staticmethod
    def _hash(document, *args) -> str:
        attrs = ('cases', 'corefs', 'relax_cases', 'extract_nes',
                 'use_pas_tag')
        assert set(attrs) <= set(vars(document).keys())
        vars_document = {k: v for k, v in vars(document).items() if k in attrs}
        string = repr(sorted(vars_document)) + ''.join(repr(a) for a in args)
        return hashlib.md5(string.encode()).hexdigest()

    def _convert_example_to_feature(
        self,
        example: PasExample,
    ) -> InputFeatures:
        """Loads a data file into a list of `InputBatch`s."""

        vocab_size = self.tokenizer.vocab_size
        max_seq_length = self.max_seq_length
        num_special_tokens = len(self.special_to_index)
        num_relations = len(self.relations)

        tokens = example.tokens
        tok_to_orig_index = example.tok_to_orig_index
        orig_to_tok_index = example.orig_to_tok_index

        arguments_set: List[List[List[int]]] = []
        candidates_set: List[List[List[int]]] = []
        overts_set: List[List[List[int]]] = []
        deps: List[List[int]] = []

        # subword loop
        for token, orig_index in zip(tokens, tok_to_orig_index):

            if orig_index is None:
                deps.append([0] * max_seq_length)
            else:
                ddep = example.ddeps[orig_index]  # dtid of the phrase that orig_index depends on
                # set 1 for every token belonging to the base phrase that orig_index depends on
                deps.append([
                    (0 if idx is None or ddep != example.dtids[idx] else 1)
                    for idx in tok_to_orig_index
                ])
                deps[-1] += [0] * (max_seq_length - len(tok_to_orig_index))

            # subsequent subword or [CLS] token or [SEP] token
            if token.startswith("##") or orig_index is None:
                arguments_set.append([[] for _ in range(num_relations)])
                overts_set.append([[] for _ in range(num_relations)])
                candidates_set.append([[] for _ in range(num_relations)])
                continue

            arguments: List[List[int]] = [[] for _ in range(num_relations)]
            overts: List[List[int]] = [[] for _ in range(num_relations)]
            for i, (rel, arg_strings) in enumerate(
                    example.arguments_set[orig_index].items()):
                for arg_string in arg_strings:
                    # arg_string: 著者, 8%C, 15%O, 2, NULL, ...
                    flag = None
                    if arg_string[-2:] in ('%C', '%N', '%O'):
                        flag = arg_string[-1]
                        arg_string = arg_string[:-2]
                    if arg_string in self.special_to_index:
                        tok_index = self.special_to_index[arg_string]
                    else:
                        tok_index = orig_to_tok_index[int(arg_string)]
                    if rel in self.target_cases:
                        if arg_string in self.target_exophors and 'zero' not in self.train_targets:
                            continue
                        if flag == 'C':
                            overts[i].append(tok_index)
                        if (flag == 'C' and 'overt' not in self.train_targets) or \
                           (flag == 'N' and 'dep' not in self.train_targets) or \
                           (flag == 'O' and 'zero' not in self.train_targets):
                            continue
                    arguments[i].append(tok_index)

            arguments_set.append(arguments)
            overts_set.append(overts)

            # add the special tokens as candidates even for tokens such as particles
            candidates: List[List[int]] = []
            for rel in self.relations:
                if rel != '=':
                    cands = [
                        orig_to_tok_index[dmid]
                        for dmid in example.arg_candidates_set[orig_index]
                    ]
                    specials = self.target_exophors + ['NULL']
                else:
                    cands = [
                        orig_to_tok_index[dmid]
                        for dmid in example.ment_candidates_set[orig_index]
                    ]
                    specials = self.target_exophors + ['NA']
                cands += [
                    self.special_to_index[special] for special in specials
                ]
                candidates.append(cands)
            candidates_set.append(candidates)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_mask = [True] * len(input_ids)

        # Zero-pad up to the sequence length
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(False)
            arguments_set.append([[]] * num_relations)
            overts_set.append([[]] * num_relations)
            candidates_set.append([[]] * num_relations)
            deps.append([0] * max_seq_length)

        # special tokens
        for i in range(num_special_tokens):
            pos = max_seq_length - num_special_tokens + i
            input_ids[pos] = vocab_size + i
            input_mask[pos] = True

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(arguments_set) == max_seq_length
        assert len(overts_set) == max_seq_length
        assert len(candidates_set) == max_seq_length
        assert len(deps) == max_seq_length

        feature = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=[0] * max_seq_length,
            arguments_set=[[[int(x in args) for x in range(max_seq_length)]
                            for args in arguments]
                           for arguments in arguments_set],
            overt_mask=[[[(x in overt) for x in range(max_seq_length)]
                         for overt in overts] for overts in overts_set],
            ng_token_mask=[[[(x in cands) for x in range(max_seq_length)]
                            for cands in candidates]
                           for candidates in candidates_set
                           ],  # False -> mask, True -> keep
            deps=deps,
        )

        return feature

    def stat(self) -> dict:
        n_mentions = 0
        pa: Dict[str, Union[int, dict]] = defaultdict(int)
        bar: Dict[str, Union[int, dict]] = defaultdict(int)
        cr: Dict[str, Union[int, dict]] = defaultdict(int)
        n_args_bar = defaultdict(int)
        n_args_pa = defaultdict(lambda: defaultdict(int))

        for arguments in (x for example in self.examples
                          for x in example.arguments_set):
            for case, args in arguments.items():
                if not args:
                    continue
                arg: str = args[0]
                if case == '=':
                    if arg == 'NA':
                        cr['na'] += 1
                        continue
                    n_mentions += 1
                    if arg in self.target_exophors:
                        cr['exo'] += 1
                    else:
                        cr['ana'] += 1
                else:
                    n_args = n_args_bar if case == 'ノ' else n_args_pa[case]
                    if arg == 'NULL':
                        n_args['null'] += 1
                        continue
                    n_args['all'] += 1
                    if arg in self.target_exophors:
                        n_args['exo'] += 1
                    elif '%C' in arg:
                        n_args['overt'] += 1
                    elif '%N' in arg:
                        n_args['dep'] += 1
                    elif '%O' in arg:
                        n_args['zero'] += 1

            arguments_: List[List[str]] = list(arguments.values())
            if self.coreference:
                if arguments_[-1]:
                    cr['mentions_all'] += 1
                if [arg for arg in arguments_[-1] if arg != 'NA']:
                    cr['mentions_tagged'] += 1
                arguments_ = arguments_[:-1]
            if self.bridging:
                if arguments_[-1]:
                    bar['preds_all'] += 1
                if [arg for arg in arguments_[-1] if arg != 'NULL']:
                    bar['preds_tagged'] += 1
                arguments_ = arguments_[:-1]
            if any(arguments_):
                pa['preds_all'] += 1
            if [arg for args in arguments_ for arg in args if arg != 'NULL']:
                pa['preds_tagged'] += 1

        n_args_pa_all = defaultdict(int)
        for case, ans in n_args_pa.items():
            for anal, num in ans.items():
                n_args_pa_all[anal] += num
        n_args_pa['all'] = n_args_pa_all
        pa['args'] = n_args_pa
        bar['args'] = n_args_bar
        cr['mentions'] = n_mentions

        return {
            'examples': len(self.examples),
            'pas': pa,
            'bridging': bar,
            'coreference': cr,
            'sentences': (sum(len(doc) for doc in self.gold_documents)
                          if self.gold_documents else None),
            'bps': (sum(len(doc.bp_list()) for doc in self.gold_documents)
                    if self.gold_documents else None),
            'tokens': sum(len(example.tokens) - 2 for example in self.examples),
        }

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx) -> tuple:
        feature = self._convert_example_to_feature(self.examples[idx])
        input_ids = np.array(feature.input_ids)  # (seq)
        attention_mask = np.array(feature.input_mask)  # (seq)
        segment_ids = np.array(feature.segment_ids)  # (seq)
        arguments_ids = np.array(feature.arguments_set)  # (seq, case, seq)
        overt_mask = np.array(feature.overt_mask)  # (seq, case, seq)
        ng_token_mask = np.array(feature.ng_token_mask)  # (seq, case, seq)
        deps = np.array(feature.deps)  # (seq, seq)
        task = np.array(TASK_ID['pa'])  # ()
        return input_ids, attention_mask, segment_ids, ng_token_mask, arguments_ids, deps, task, overt_mask
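A hedged construction sketch for the PASDataset class above; every argument value is illustrative (the case, exophor, and target lists mirror the defaults seen elsewhere in this collection), and bert_path must point at a local Japanese BERT model:

from torch.utils.data import DataLoader

# Illustrative only: the data path, BERT model location, and hyper-parameters are placeholders.
dataset = PASDataset(
    path='data/train',
    cases=['ガ', 'ヲ', 'ニ', 'ガ2'],
    exophors=['著者', '読者', '不特定:人', '不特定:物'],
    coreference=True,
    bridging=True,
    max_seq_length=128,
    bert_path='/path/to/japanese-bert',
    training=True,
    kc=False,
    train_targets=['overt', 'dep', 'zero'],
    pas_targets=['pred', 'noun'],
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)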
Example #22
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--prediction-dir',
        default=None,
        type=str,
        help=
        'path to directory where system output KWDLC files exist (default: None)'
    )
    parser.add_argument(
        '--gold-dir',
        default=None,
        type=str,
        help='path to directory where gold KWDLC files exist (default: None)')
    parser.add_argument('--coreference',
                        '--coref',
                        '--cr',
                        action='store_true',
                        default=False,
                        help='perform coreference resolution')
    parser.add_argument('--bridging',
                        '--brg',
                        '--bar',
                        action='store_true',
                        default=False,
                        help='perform bridging anaphora resolution')
    parser.add_argument('--case-string',
                        type=str,
                        default='ガ,ヲ,ニ,ガ2',
                        help='case strings separated by ","')
    parser.add_argument('--exophors',
                        '--exo',
                        type=str,
                        default='著者,読者,不特定:人,不特定:物',
                        help='exophor strings separated by ","')
    parser.add_argument(
        '--read-prediction-from-pas-tag',
        action='store_true',
        default=False,
        help='use <述語項構造:> tag instead of <rel > tag in prediction files')
    parser.add_argument(
        '--pas-target',
        choices=['', 'pred', 'noun', 'all'],
        default='pred',
        help=
        'PAS analysis evaluation target (pred: verbal predicates, noun: nominal predicates)'
    )
    parser.add_argument(
        '--result-html',
        default=None,
        type=str,
        help=
        'path to an HTML file to which the prediction result is exported (default: None)'
    )
    parser.add_argument(
        '--result-csv',
        default=None,
        type=str,
        help=
        'path to a CSV file to which the prediction result is exported (default: None)')
    args = parser.parse_args()

    reader_gold = KyotoReader(Path(args.gold_dir),
                              extract_nes=False,
                              use_pas_tag=False)
    reader_pred = KyotoReader(
        Path(args.prediction_dir),
        extract_nes=False,
        use_pas_tag=args.read_prediction_from_pas_tag,
    )
    documents_pred = reader_pred.process_all_documents()
    documents_gold = reader_gold.process_all_documents()

    assert set(args.case_string.split(',')) <= set(CASE2YOMI.keys())
    msg = '"ノ" found in case string. If you want to perform bridging anaphora resolution, specify "--bridging" ' \
          'option instead'
    assert 'ノ' not in args.case_string.split(','), msg
    scorer = Scorer(documents_pred,
                    documents_gold,
                    target_cases=args.case_string.split(','),
                    target_exophors=args.exophors.split(','),
                    coreference=args.coreference,
                    bridging=args.bridging,
                    pas_target=args.pas_target)
    result = scorer.run()
    if args.result_html:
        scorer.write_html(Path(args.result_html))
    if args.result_csv:
        result.export_csv(args.result_csv)
    result.export_txt(sys.stdout)
Example #23
def fixture_kyoto_reader():
    reader = KyotoReader(data_dir / 'knp',
                         target_cases=ALL_CASES,
                         target_corefs=ALL_COREFS)
    yield reader