def process(input_path: Path, output_path: Path, corpus: str) -> int:
    output_path.mkdir(exist_ok=True)
    reader = KyotoReader(input_path, extract_nes=False)
    for document in tqdm(reader.process_all_documents(backend="multiprocessing"),
                         desc=corpus, total=len(reader)):
        with output_path.joinpath(document.doc_id + '.pkl').open(mode='wb') as f:
            cPickle.dump(document, f)
    return len(reader)
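# A minimal usage sketch for process() above. The input/output paths and the
# corpus label are illustrative assumptions, not paths fixed by this repository.
if __name__ == '__main__':
    from pathlib import Path

    n_docs = process(Path('data/kwdlc/knp'), Path('data/kwdlc/pkl'), corpus='kwdlc')
    print(f'pickled {n_docs} documents')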
def __init__(
    self,
    path: Union[str, Path],
    cases: List[str],
    exophors: List[str],
    coreference: bool,
    bridging: bool,
    max_seq_length: int,
    model_name: str,
    pretrained_path: Union[str, Path],
    training: bool,
    kc: bool,
    train_targets: List[str],
    pas_targets: List[str],
    logger=None,
    kc_joined_path: Optional[str] = None,
) -> None:
    self.reader = KyotoReader(Path(path), extract_nes=False)
    self.target_cases: List[str] = [c for c in cases if c in self.reader.target_cases and c != 'ノ']
    self.target_exophors: List[str] = [e for e in exophors if e in ALL_EXOPHORS]
    self.coreference: bool = coreference
    self.bridging: bool = bridging
    self.kc: bool = kc
    self.train_targets: List[str] = train_targets
    self.pas_targets: List[str] = pas_targets
    self.logger: Logger = logger or logging.getLogger(__file__)
    special_tokens = self.target_exophors + ['NULL'] + (['NA'] if coreference else [])
    self.special_to_index: Dict[str, int] = {
        token: max_seq_length - i - 1 for i, token in enumerate(reversed(special_tokens))
    }
    tokenizer_cls = MODEL2TOKENIZER[model_name]
    self.tokenizer = tokenizer_cls.from_pretrained(pretrained_path, do_lower_case=False,
                                                   tokenize_chinese_chars=False)
    self.max_seq_length: int = max_seq_length
    self.pretrained_path: Path = Path(pretrained_path)
    documents = list(self.reader.process_all_documents(backend="multiprocessing"))
    self.documents: Optional[List[Document]] = documents if not training else None
    if self.kc and not training:
        assert kc_joined_path is not None
        reader = KyotoReader(Path(kc_joined_path), extract_nes=False)
        self.joined_documents = list(reader.process_all_documents(backend="multiprocessing"))
    self.examples = self._load(documents, str(path))
def test_pas_relax(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060560')
    predicates: List[Predicate] = document.get_predicates()
    arguments = document.get_arguments(predicates[9], relax=True)
    sid1 = 'w201106-0000060560-1'
    sid2 = 'w201106-0000060560-2'
    sid3 = 'w201106-0000060560-3'
    assert predicates[9].core == 'ご協力'
    assert len([_ for args in arguments.values() for _ in args]) == 6
    args = sorted(arguments['ガ'], key=lambda a: a.dtid)
    arg = args[0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ドクター', 7, 7, sid1, 'inter', 'AND')
    arg = args[1]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ドクター', 2, 11, sid2, 'inter', 'AND')
    arg = args[2]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ドクター', 0, 16, sid3, 'intra', 'AND')
    arg = args[3]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('皆様', 1, 17, sid3, 'intra', '')
    args = sorted(arguments['ニ'], key=lambda a: str(a))
    arg = args[0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('コーナー', 5, 14, sid2, 'inter', '?')
    arg = args[1]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('著者', 5, 'exo', '')
def test_dep_type(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0002000028')
    predicate: Predicate = document.get_predicates()[4]
    assert predicate.core == '同じ'
    arg = document.get_arguments(predicate, relax=True)['ガ'][1]
    assert isinstance(arg, Argument)
    assert arg.core == 'フランス'
    assert arg.dtid == 10
    assert arg.dep_type == 'dep'
def process(
    input_path: Path,
    output_path: Path,
    corpus: str,
    do_reparse: bool,
    n_jobs: int,
    bertknp: Optional[str],
    knp: KNP,
    keep_dep: bool = False,
    split: bool = False,
    max_subword_length: Optional[int] = None,
    tokenizer: Optional[BertTokenizer] = None,
) -> int:
    with tempfile.TemporaryDirectory() as tmp_dir1, tempfile.TemporaryDirectory() as tmp_dir2:
        tmp_dir1, tmp_dir2 = Path(tmp_dir1), Path(tmp_dir2)
        if do_reparse:
            reparse(input_path, tmp_dir1, knp, bertknp=bertknp, n_jobs=n_jobs, keep_dep=keep_dep)
            input_path = tmp_dir1
        if split:
            # Because documents in the Kyoto Corpus are very long, split them into multiple documents
            # so that the tail sentence of each document has as many preceding sentences as possible.
            print('splitting corpus...')
            split_kc(input_path, tmp_dir2, max_subword_length, tokenizer)
            input_path = tmp_dir2
        output_path.mkdir(exist_ok=True)
        reader = KyotoReader(input_path, extract_nes=False, did_from_sid=(not split), n_jobs=n_jobs)
        for document in tqdm(reader.process_all_documents(), desc=corpus, total=len(reader)):
            with output_path.joinpath(document.doc_id + '.pkl').open(mode='wb') as f:
                cPickle.dump(document, f)
    return len(reader)
def test_coref_link3(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060877')
    for entity in document.entities.values():
        for mention in entity.mentions:
            assert entity.eid in mention.eids
        for mention in entity.mentions_unc:
            assert entity.eid in mention.eids_unc
    for mention in document.mentions.values():
        for eid in mention.eids:
            assert mention in document.entities[eid].mentions
        for eid in mention.eids_unc:
            assert mention in document.entities[eid].mentions_unc
def main():
    reader = KyotoReader(sys.argv[1])
    ret = RetValue()
    for doc in reader.process_all_documents():
        ret += coverage(doc)
    print('pred:')
    print(f'  precision: {ret.measure_pred.precision:.4f} '
          f'({ret.measure_pred.correct}/{ret.measure_pred.denom_pred})')
    print(f'  recall   : {ret.measure_pred.recall:.4f} '
          f'({ret.measure_pred.correct}/{ret.measure_pred.denom_gold})')
    print(f'  F        : {ret.measure_pred.f1:.4f}')
    print('noun:')
    print(f'  precision: {ret.measure_noun.precision:.4f} '
          f'({ret.measure_noun.correct}/{ret.measure_noun.denom_pred})')
    print(f'  recall   : {ret.measure_noun.recall:.4f} '
          f'({ret.measure_noun.correct}/{ret.measure_noun.denom_gold})')
    print(f'  F        : {ret.measure_noun.f1:.4f}')
def process_kc(input_path: Path,
               output_path: Path,
               max_subword_length: int,
               tokenizer: TokenizeHandlerMeta,
               split: bool = False,
               ) -> int:
    with tempfile.TemporaryDirectory() as tmp_dir:
        if split:
            tmp_dir = Path(tmp_dir)
            # Documents in the Kyoto Corpus are long, so split them into multiple
            # documents so that each one contains as much context as possible.
            print('splitting kc...')
            split_kc(input_path, tmp_dir, max_subword_length, tokenizer)
            input_path = tmp_dir
            print(list(input_path.iterdir()))
        output_path.mkdir(exist_ok=True)
        reader = KyotoReader(input_path, extract_nes=False, did_from_sid=False)
        for document in tqdm(reader.process_all_documents(backend="multiprocessing"),
                             desc='kc', total=len(reader)):
            with output_path.joinpath(document.doc_id + '.pkl').open(mode='wb') as f:
                cPickle.dump(document, f)
    return len(reader)
def reparse(
    input_dir: Path,
    output_dir: Path,
    knp: KNP,
    bertknp: Optional[str] = None,
    n_jobs: int = 0,
    keep_dep: bool = False,
) -> None:
    if bertknp is None:
        args_iter = ((path, output_dir, knp, keep_dep) for path in input_dir.glob('*.knp'))
        if n_jobs > 0:
            with Pool(n_jobs) as pool:
                pool.starmap(reparse_knp, args_iter)
        else:
            for args in args_iter:
                reparse_knp(*args)
        return

    assert keep_dep is False, 'If you use BERTKNP, you cannot keep dependency labels.'

    buff = ''
    for knp_file in input_dir.glob('*.knp'):
        with knp_file.open() as fin:
            for line in fin:
                if line.startswith('+') or line.startswith('*'):
                    buff += line[0] + '\n'
                else:
                    buff += line

    out = subprocess.run(
        [
            bertknp,
            '-p', str(Path(bertknp).parents[1] / '.venv/bin/python'),
            '-O', str(Path(__file__).parent.joinpath('bertknp_options.txt').resolve()),
            '-tab',
        ],
        input=buff,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        encoding='utf-8',
    )
    logger.warning(out.stderr)

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        tmp_dir.joinpath('tmp.knp').write_text(out.stdout)
        for did, knp_string in KyotoReader(tmp_dir / 'tmp.knp').did2knps.items():
            output_dir.joinpath(f'{did}.knp').write_text(knp_string)
def test_ne(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060877')
    nes = document.named_entities
    assert len(nes) == 2
    ne = nes[0]
    assert (ne.category, ne.name, ne.dmid_range) == ('ORGANIZATION', '柏市ひまわり園', range(5, 9))
    ne = nes[1]
    assert (ne.category, ne.name, ne.dmid_range) == ('DATE', '平成23年度', range(11, 14))

    document = fixture_kyoto_reader.process_document('w201106-0000074273')
    nes = document.named_entities
    assert len(nes) == 3
    ne = nes[0]
    assert (ne.category, ne.name, ne.dmid_range) == ('LOCATION', 'ダーマ神殿', range(15, 17))
    ne = nes[1]
    assert (ne.category, ne.name, ne.dmid_range) == ('ARTIFACT', '天の箱舟', range(24, 27))
    ne = nes[2]
    assert (ne.category, ne.name, ne.dmid_range) == ('LOCATION', 'ナザム村', range(39, 41))
def test_coref2(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060560')
    entities: Dict[int, Entity] = document.entities
    assert len(entities) == 15
    entity: Entity = entities[14]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 4
    assert (mentions[0].core, mentions[0].dtid, mentions[0].eids) == ('ドクター', 7, {4})
    assert (mentions[1].core, mentions[1].dtid, mentions[1].eids) == ('ドクター', 11, {14})
    assert (mentions[2].core, mentions[2].dtid, mentions[2].eids) == ('ドクター', 16, {14})
    assert (mentions[3].core, mentions[3].dtid, mentions[3].eids) == ('皆様', 17, {14})
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-dir', '-i', default=None, type=str,
                        help='path to input knp directory')
    parser.add_argument('--output-dir', '-o', default=None, type=str,
                        help='path to output knp directory')
    args = parser.parse_args()

    docs: Dict[str, str] = {}
    for path in Path(args.input_dir).glob('**/*.knp'):
        docs.update(KyotoReader.read_knp(path, did_from_sid=True))
    for did, knp_string in docs.items():
        out_path = Path(args.output_dir) / f'{did}.knp'
        with out_path.open(mode='w') as f:
            f.write(knp_string)
def show(args: argparse.Namespace):
    reader = KyotoReader(args.path, target_cases=args.cases)
    for document in reader.process_all_documents():
        document.draw_tree()
def test_pas(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060050')
    predicates: List[Predicate] = document.get_predicates()
    assert len(predicates) == 12
    sid1 = 'w201106-0000060050-1'
    sid2 = 'w201106-0000060050-2'
    sid3 = 'w201106-0000060050-3'

    arguments = document.get_arguments(predicates[0])
    assert predicates[0].core == 'トス'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:人', 0, 'exo', '')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('コイン', 0, 0, sid1, 'dep', '')

    arguments = document.get_arguments(predicates[1])
    assert predicates[1].core == '行う'
    assert len([_ for args in arguments.values() for _ in args]) == 4
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:人', 2, 'exo', '')
    arg = arguments['ガ'][1]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('読者', 3, 'exo', '?')
    arg = arguments['ガ'][2]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('著者', 4, 'exo', '?')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('トス', 1, 1, sid1, 'overt', '')

    arguments = document.get_arguments(predicates[2])
    assert predicates[2].core == '表'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ノ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('コイン', 0, 0, sid1, 'inter', '')

    arguments = document.get_arguments(predicates[3])
    assert predicates[3].core == '出た'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['ガ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('表', 0, 4, sid2, 'overt', '')
    arg = arguments['外の関係'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('数', 2, 6, sid2, 'dep', '')

    arguments = document.get_arguments(predicates[4])
    assert predicates[4].core == '数'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ノ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('出た', 1, 5, sid2, 'dep', '')

    arguments = document.get_arguments(predicates[5])
    assert predicates[5].core == 'モンスター'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['修飾'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('フィールド上', 3, 7, sid2, 'dep', '')
    arg = arguments['修飾'][1]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('数', 2, 6, sid2, 'intra', 'AND')

    arguments = document.get_arguments(predicates[6])
    assert predicates[6].core == '破壊する'
    assert len([_ for args in arguments.values() for _ in args]) == 2
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:状況', 11, 'exo', '')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('モンスター', 4, 8, sid2, 'overt', '')

    arguments = document.get_arguments(predicates[7])
    assert predicates[7].core == '効果'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['トイウ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('破壊する', 5, 9, sid2, 'inter', '')

    arguments = document.get_arguments(predicates[8])
    assert predicates[8].core == '1度'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ニ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('ターン', 3, 13, sid3, 'overt', '')

    arguments = document.get_arguments(predicates[9])
    assert predicates[9].core == 'メイン'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ガ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('フェイズ', 7, 17, sid3, 'dep', '')

    arguments = document.get_arguments(predicates[10])
    assert predicates[10].core == 'フェイズ'
    assert len([_ for args in arguments.values() for _ in args]) == 1
    arg = arguments['ノ?'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('自分', 5, 15, sid3, 'overt', '')

    arguments = document.get_arguments(predicates[11])
    assert predicates[11].core == '使用する事ができる'
    assert len([_ for args in arguments.values() for _ in args]) == 5
    arg = arguments['ガ'][0]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('不特定:人', 17, 'exo', '')
    arg = arguments['ガ'][1]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('著者', 4, 'exo', '?')
    arg = arguments['ガ'][2]
    assert isinstance(arg, SpecialArgument)
    assert tuple(arg) == ('読者', 3, 'exo', '?')
    arg = arguments['ヲ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('効果', 1, 11, sid3, 'dep', '')
    arg = arguments['ニ'][0]
    assert isinstance(arg, Argument)
    assert tuple(arg) == ('フェイズ', 7, 17, sid3, 'overt', '')
def show(args: argparse.Namespace): """Show the specified document in a tree format.""" reader = KyotoReader(args.path, target_cases=args.cases) for document in reader.process_all_documents(): document.draw_tree()
def test_coref1(fixture_kyoto_reader: KyotoReader):
    document = fixture_kyoto_reader.process_document('w201106-0000060050')
    entities: Dict[int, Entity] = document.entities
    assert len(entities) == 19

    entity = entities[0]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[1]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('コイン', 0)
    assert mentions[0].eids == {1}

    entity = entities[2]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[3]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor == '読者'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}
    assert mentions[0].eids_unc == {3, 4, 15}

    entity = entities[4]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor == '著者'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}
    assert mentions[0].eids_unc == {3, 4, 15}

    entity = entities[5]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('トス', 1)
    assert mentions[0].eids == {5}

    entity = entities[6]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('表', 4)
    assert mentions[0].eids == {6}

    entity = entities[7]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('数', 6)
    assert mentions[0].eids == {7}

    entity = entities[8]
    assert (entity.taigen, entity.yougen) == (False, True)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('出た', 5)
    assert mentions[0].eids == {8}

    entity = entities[9]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('フィールド上', 7)
    assert mentions[0].eids == {9}

    entity = entities[10]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('モンスター', 8)
    assert mentions[0].eids == {10}

    entity = entities[11]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:状況'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[12]
    assert (entity.taigen, entity.yougen) == (False, True)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('破壊する', 9)
    assert mentions[0].eids == {12}

    entity = entities[13]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('ターン', 13)
    assert mentions[0].eids == {13}

    entity = entities[14]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}

    entity = entities[15]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('自分', 15)
    assert mentions[0].eids == {14}

    entity = entities[16]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('フェイズ', 17)
    assert mentions[0].eids == {16}

    entity = entities[17]
    assert (entity.taigen, entity.yougen) == (None, None)
    assert entity.exophor == '不特定:人'
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 0

    entity = entities[18]
    assert (entity.taigen, entity.yougen) == (True, False)
    assert entity.exophor is None
    mentions: List[Mention] = sorted(entity.all_mentions, key=lambda x: x.dtid)
    assert len(mentions) == 1
    assert (mentions[0].core, mentions[0].dtid) == ('効果', 11)
    assert mentions[0].eids == {18}
def list_(args: argparse.Namespace): """List document IDs which specified path contains.""" reader = KyotoReader(args.path) print('\n'.join(reader.doc_ids))
def fixture_documents_gold():
    reader = KyotoReader(here / 'data' / 'gold', target_cases=ALL_CASES, target_corefs=ALL_COREFS)
    yield reader.process_all_documents()
def list_(args: argparse.Namespace):
    reader = KyotoReader(args.path)
    print('\n'.join(reader.doc_ids))
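# A minimal sketch of how show() and list_() above could be wired into a
# sub-command CLI. The parser layout, option names, and entry point below are
# assumptions for illustration, not the repository's actual interface; it also
# assumes argparse is imported in this module.
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description='inspect KNP-annotated corpora')
    subparsers = parser.add_subparsers(dest='command', required=True)

    parser_show = subparsers.add_parser('show', help='draw each document as a dependency tree')
    parser_show.add_argument('path', type=str, help='path to a directory of .knp files')
    parser_show.add_argument('--cases', nargs='*', default=None, help='target case markers')
    parser_show.set_defaults(handler=show)

    parser_list = subparsers.add_parser('list', help='list document IDs')
    parser_list.add_argument('path', type=str, help='path to a directory of .knp files')
    parser_list.set_defaults(handler=list_)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    args.handler(args)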
class PASDataset(Dataset):
    def __init__(
        self,
        path: Union[str, Path],
        cases: List[str],
        exophors: List[str],
        coreference: bool,
        bridging: bool,
        max_seq_length: int,
        bert_path: Union[str, Path],
        training: bool,
        kc: bool,
        train_targets: List[str],
        pas_targets: List[str],
        n_jobs: int = -1,
        logger=None,
        gold_path: Optional[str] = None,
    ) -> None:
        self.path = Path(path)
        self.reader = KyotoReader(self.path, extract_nes=False, n_jobs=n_jobs)
        self.target_cases: List[str] = [c for c in cases if c in self.reader.target_cases and c != 'ノ']
        self.target_exophors: List[str] = [e for e in exophors if e in ALL_EXOPHORS]
        self.coreference: bool = coreference
        self.bridging: bool = bridging
        self.relations = self.target_cases + ['ノ'] * bridging + ['='] * coreference
        self.kc: bool = kc
        self.train_targets: List[str] = [t if t != 'case' else 'dep' for t in train_targets]  # backward compatibility
        self.pas_targets: List[str] = pas_targets
        self.logger: Logger = logger or logging.getLogger(__file__)
        special_tokens = self.target_exophors + ['NULL'] + (['NA'] if coreference else [])
        self.special_to_index: Dict[str, int] = {
            token: max_seq_length - i - 1 for i, token in enumerate(reversed(special_tokens))
        }
        self.tokenizer = BertTokenizer.from_pretrained(bert_path, do_lower_case=False,
                                                       tokenize_chinese_chars=False)
        self.max_seq_length: int = max_seq_length
        self.bert_path: Path = Path(bert_path)
        documents = list(self.reader.process_all_documents())
        if not training:
            self.documents: Optional[List[Document]] = documents
        if gold_path is not None:
            reader = KyotoReader(Path(gold_path), extract_nes=False, n_jobs=n_jobs)
            self.gold_documents = list(reader.process_all_documents())
        self.examples = self._load(documents, str(path))

    def _load(self, documents: List[Document], path: str) -> List[PasExample]:
        examples: List[PasExample] = []
        load_cache: bool = ('BPA_DISABLE_CACHE' not in os.environ
                            and 'BPA_OVERWRITE_CACHE' not in os.environ)
        save_cache: bool = ('BPA_DISABLE_CACHE' not in os.environ)
        bpa_cache_dir: Path = Path(os.environ.get('BPA_CACHE_DIR',
                                                  f'/tmp/{os.environ["USER"]}/bpa_cache'))
        for document in tqdm(documents, desc='processing documents'):
            hash_ = self._hash(document, path, self.relations, self.target_exophors, self.kc,
                               self.pas_targets, self.train_targets, str(self.bert_path))
            example_cache_path = bpa_cache_dir / hash_ / f'{document.doc_id}.pkl'
            if example_cache_path.exists() and load_cache:
                with example_cache_path.open('rb') as f:
                    example = cPickle.load(f)
            else:
                example = PasExample()
                example.load(document,
                             cases=self.target_cases,
                             exophors=self.target_exophors,
                             coreference=self.coreference,
                             bridging=self.bridging,
                             relations=self.relations,
                             kc=self.kc,
                             pas_targets=self.pas_targets,
                             tokenizer=self.tokenizer)
                if save_cache:
                    example_cache_path.parent.mkdir(exist_ok=True, parents=True)
                    with example_cache_path.open('wb') as f:
                        cPickle.dump(example, f)

            # skip documents that are too long
            if len(example.tokens) > self.max_seq_length - len(self.special_to_index):
                continue

            examples.append(example)

        if len(examples) == 0:
            self.logger.error(
                'No examples to process. '
                f'Make sure that documents exist in {self.path} and that they are not too long.')
        return examples

    @staticmethod
    def _hash(document, *args) -> str:
        attrs = ('cases', 'corefs', 'relax_cases', 'extract_nes', 'use_pas_tag')
        assert set(attrs) <= set(vars(document).keys())
        vars_document = {k: v for k, v in vars(document).items() if k in attrs}
        string = repr(sorted(vars_document)) + ''.join(repr(a) for a in args)
        return hashlib.md5(string.encode()).hexdigest()

    def _convert_example_to_feature(self, example: PasExample) -> InputFeatures:
        """Convert a `PasExample` into an `InputFeatures`."""
        vocab_size = self.tokenizer.vocab_size
        max_seq_length = self.max_seq_length
        num_special_tokens = len(self.special_to_index)
        num_relations = len(self.relations)

        tokens = example.tokens
        tok_to_orig_index = example.tok_to_orig_index
        orig_to_tok_index = example.orig_to_tok_index

        arguments_set: List[List[List[int]]] = []
        candidates_set: List[List[List[int]]] = []
        overts_set: List[List[List[int]]] = []
        deps: List[List[int]] = []

        # subword loop
        for token, orig_index in zip(tokens, tok_to_orig_index):
            if orig_index is None:
                deps.append([0] * max_seq_length)
            else:
                ddep = example.ddeps[orig_index]  # dtid of the head that orig_index depends on
                # set 1 for every token belonging to the base phrase that orig_index depends on
                deps.append([(0 if idx is None or ddep != example.dtids[idx] else 1)
                             for idx in tok_to_orig_index])
                deps[-1] += [0] * (max_seq_length - len(tok_to_orig_index))

            # subsequent subword or [CLS] token or [SEP] token
            if token.startswith("##") or orig_index is None:
                arguments_set.append([[] for _ in range(num_relations)])
                overts_set.append([[] for _ in range(num_relations)])
                candidates_set.append([[] for _ in range(num_relations)])
                continue

            arguments: List[List[int]] = [[] for _ in range(num_relations)]
            overts: List[List[int]] = [[] for _ in range(num_relations)]
            for i, (rel, arg_strings) in enumerate(example.arguments_set[orig_index].items()):
                for arg_string in arg_strings:
                    # arg_string: 著者, 8%C, 15%O, 2, NULL, ...
                    flag = None
                    if arg_string[-2:] in ('%C', '%N', '%O'):
                        flag = arg_string[-1]
                        arg_string = arg_string[:-2]
                    if arg_string in self.special_to_index:
                        tok_index = self.special_to_index[arg_string]
                    else:
                        tok_index = orig_to_tok_index[int(arg_string)]
                    if rel in self.target_cases:
                        if arg_string in self.target_exophors and 'zero' not in self.train_targets:
                            continue
                        if flag == 'C':
                            overts[i].append(tok_index)
                        if (flag == 'C' and 'overt' not in self.train_targets) or \
                                (flag == 'N' and 'dep' not in self.train_targets) or \
                                (flag == 'O' and 'zero' not in self.train_targets):
                            continue
                    arguments[i].append(tok_index)

            arguments_set.append(arguments)
            overts_set.append(overts)

            # add the special tokens as candidates even for function words such as particles
            candidates: List[List[int]] = []
            for rel in self.relations:
                if rel != '=':
                    cands = [orig_to_tok_index[dmid] for dmid in example.arg_candidates_set[orig_index]]
                    specials = self.target_exophors + ['NULL']
                else:
                    cands = [orig_to_tok_index[dmid] for dmid in example.ment_candidates_set[orig_index]]
                    specials = self.target_exophors + ['NA']
                cands += [self.special_to_index[special] for special in specials]
                candidates.append(cands)
            candidates_set.append(candidates)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_mask = [True] * len(input_ids)

        # Zero-pad up to the sequence length
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(False)
            arguments_set.append([[]] * num_relations)
            overts_set.append([[]] * num_relations)
            candidates_set.append([[]] * num_relations)
            deps.append([0] * max_seq_length)

        # special tokens
        for i in range(num_special_tokens):
            pos = max_seq_length - num_special_tokens + i
            input_ids[pos] = vocab_size + i
            input_mask[pos] = True

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(arguments_set) == max_seq_length
        assert len(overts_set) == max_seq_length
        assert len(candidates_set) == max_seq_length
        assert len(deps) == max_seq_length

        feature = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=[0] * max_seq_length,
            arguments_set=[[[int(x in args) for x in range(max_seq_length)] for args in arguments]
                           for arguments in arguments_set],
            overt_mask=[[[(x in overt) for x in range(max_seq_length)] for overt in overts]
                        for overts in overts_set],
            ng_token_mask=[[[(x in cands) for x in range(max_seq_length)] for cands in candidates]
                           for candidates in candidates_set],  # False -> mask, True -> keep
            deps=deps,
        )
        return feature

    def stat(self) -> dict:
        n_mentions = 0
        pa: Dict[str, Union[int, dict]] = defaultdict(int)
        bar: Dict[str, Union[int, dict]] = defaultdict(int)
        cr: Dict[str, Union[int, dict]] = defaultdict(int)
        n_args_bar = defaultdict(int)
        n_args_pa = defaultdict(lambda: defaultdict(int))
        for arguments in (x for example in self.examples for x in example.arguments_set):
            for case, args in arguments.items():
                if not args:
                    continue
                arg: str = args[0]
                if case == '=':
                    if arg == 'NA':
                        cr['na'] += 1
                        continue
                    n_mentions += 1
                    if arg in self.target_exophors:
                        cr['exo'] += 1
                    else:
                        cr['ana'] += 1
                else:
                    n_args = n_args_bar if case == 'ノ' else n_args_pa[case]
                    if arg == 'NULL':
                        n_args['null'] += 1
                        continue
                    n_args['all'] += 1
                    if arg in self.target_exophors:
                        n_args['exo'] += 1
                    elif '%C' in arg:
                        n_args['overt'] += 1
                    elif '%N' in arg:
                        n_args['dep'] += 1
                    elif '%O' in arg:
                        n_args['zero'] += 1

            arguments_: List[List[str]] = list(arguments.values())
            if self.coreference:
                if arguments_[-1]:
                    cr['mentions_all'] += 1
                    if [arg for arg in arguments_[-1] if arg != 'NA']:
                        cr['mentions_tagged'] += 1
                arguments_ = arguments_[:-1]
            if self.bridging:
                if arguments_[-1]:
                    bar['preds_all'] += 1
                    if [arg for arg in arguments_[-1] if arg != 'NULL']:
                        bar['preds_tagged'] += 1
                arguments_ = arguments_[:-1]
            if any(arguments_):
                pa['preds_all'] += 1
                if [arg for args in arguments_ for arg in args if arg != 'NULL']:
                    pa['preds_tagged'] += 1

        n_args_pa_all = defaultdict(int)
        for case, ans in n_args_pa.items():
            for anal, num in ans.items():
                n_args_pa_all[anal] += num
        n_args_pa['all'] = n_args_pa_all
        pa['args'] = n_args_pa
        bar['args'] = n_args_bar
        cr['mentions'] = n_mentions

        return {
            'examples': len(self.examples),
            'pas': pa,
            'bridging': bar,
            'coreference': cr,
            'sentences': sum(len(doc) for doc in self.gold_documents) if self.gold_documents else None,
            'bps': sum(len(doc.bp_list()) for doc in self.gold_documents) if self.gold_documents else None,
            'tokens': sum(len(example.tokens) - 2 for example in self.examples),
        }

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx) -> tuple:
        feature = self._convert_example_to_feature(self.examples[idx])
        input_ids = np.array(feature.input_ids)          # (seq)
        attention_mask = np.array(feature.input_mask)    # (seq)
        segment_ids = np.array(feature.segment_ids)      # (seq)
        arguments_ids = np.array(feature.arguments_set)  # (seq, case, seq)
        overt_mask = np.array(feature.overt_mask)        # (seq, case, seq)
        ng_token_mask = np.array(feature.ng_token_mask)  # (seq, case, seq)
        deps = np.array(feature.deps)                    # (seq, seq)
        task = np.array(TASK_ID['pa'])                   # ()
        return input_ids, attention_mask, segment_ids, ng_token_mask, arguments_ids, deps, task, overt_mask
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--prediction-dir', default=None, type=str,
                        help='path to directory where system output KWDLC files exist (default: None)')
    parser.add_argument('--gold-dir', default=None, type=str,
                        help='path to directory where gold KWDLC files exist (default: None)')
    parser.add_argument('--coreference', '--coref', '--cr', action='store_true', default=False,
                        help='perform coreference resolution')
    parser.add_argument('--bridging', '--brg', '--bar', action='store_true', default=False,
                        help='perform bridging anaphora resolution')
    parser.add_argument('--case-string', type=str, default='ガ,ヲ,ニ,ガ2',
                        help='case strings separated by ","')
    parser.add_argument('--exophors', '--exo', type=str, default='著者,読者,不特定:人,不特定:物',
                        help='exophor strings separated by ","')
    parser.add_argument('--read-prediction-from-pas-tag', action='store_true', default=False,
                        help='use <述語項構造:> tag instead of <rel > tag in prediction files')
    parser.add_argument('--pas-target', choices=['', 'pred', 'noun', 'all'], default='pred',
                        help='PAS analysis evaluation target (pred: verbal predicates, noun: nominal predicates)')
    parser.add_argument('--result-html', default=None, type=str,
                        help='path to html file to which the prediction result is exported (default: None)')
    parser.add_argument('--result-csv', default=None, type=str,
                        help='path to csv file to which the prediction result is exported (default: None)')
    args = parser.parse_args()

    reader_gold = KyotoReader(Path(args.gold_dir), extract_nes=False, use_pas_tag=False)
    reader_pred = KyotoReader(
        Path(args.prediction_dir),
        extract_nes=False,
        use_pas_tag=args.read_prediction_from_pas_tag,
    )
    documents_pred = reader_pred.process_all_documents()
    documents_gold = reader_gold.process_all_documents()

    assert set(args.case_string.split(',')) <= set(CASE2YOMI.keys())
    msg = '"ノ" found in case string. If you want to perform bridging anaphora resolution, ' \
          'specify "--bridging" option instead'
    assert 'ノ' not in args.case_string.split(','), msg
    scorer = Scorer(documents_pred, documents_gold,
                    target_cases=args.case_string.split(','),
                    target_exophors=args.exophors.split(','),
                    coreference=args.coreference,
                    bridging=args.bridging,
                    pas_target=args.pas_target)
    result = scorer.run()
    if args.result_html:
        scorer.write_html(Path(args.result_html))
    if args.result_csv:
        result.export_csv(args.result_csv)
    result.export_txt(sys.stdout)
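# A hedged invocation sketch for the evaluation entry point above; the script
# name and directory layout are assumptions for illustration only.
#
#   python evaluate.py --prediction-dir result/knp --gold-dir data/gold \
#       --case-string ガ,ヲ,ニ,ガ2 --coreference --bridging --result-csv result.csv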
def fixture_kyoto_reader():
    reader = KyotoReader(data_dir / 'knp', target_cases=ALL_CASES, target_corefs=ALL_COREFS)
    yield reader