def prepare_turkcorpus_lower():
    dataset = 'turkcorpus_lower'
    with create_directory_or_skip(get_dataset_dir(dataset)):
        url = 'https://github.com/cocoxu/simplification.git'
        output_dir = Path(tempfile.mkdtemp())
        git_clone(url, output_dir)
        print(output_dir)
        print('Processing...')
        # Only rename files and put them in local directory architecture
        turkcorpus_lower_dir = output_dir / 'data/turkcorpus'
        print(turkcorpus_lower_dir)
        for (old_phase, new_phase) in [('test', 'test'), ('tune', 'valid')]:
            for (old_language_name, new_language_name) in [('norm', 'complex'), ('simp', 'simple')]:
                old_path = turkcorpus_lower_dir / f'{old_phase}.8turkers.tok.{old_language_name}'
                new_path = get_data_filepath('turkcorpus_lower', new_phase, new_language_name)
                shutil.copyfile(old_path, new_path)
                add_newline_at_end_of_file(new_path)
                shutil.move(replace_lrb_rrb_file(new_path), new_path)
            for i in range(8):
                old_path = turkcorpus_lower_dir / f'{old_phase}.8turkers.tok.turk.{i}'
                new_path = get_data_filepath('turkcorpus_lower', new_phase, 'simple.turk', i=i)
                shutil.copyfile(old_path, new_path)
                add_newline_at_end_of_file(new_path)
                shutil.move(replace_lrb_rrb_file(new_path), new_path)
        print('Done.')
    return dataset
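# The small file helpers used above (add_newline_at_end_of_file, replace_lrb_rrb_file)
# are not part of this excerpt. Minimal sketches of what they could look like follow;
# the actual helpers in the codebase may differ (e.g. in temp-file handling or encoding).
def add_newline_at_end_of_file(filepath):  # illustrative sketch only
    # Append a trailing newline if the file does not already end with one
    with open(filepath, 'rb+') as f:
        f.seek(0, 2)  # jump to end of file
        if f.tell() == 0:
            return  # empty file, nothing to do
        f.seek(-1, 2)
        if f.read(1) != b'\n':
            f.write(b'\n')


def replace_lrb_rrb_file(filepath):  # illustrative sketch only
    # Rewrite '-lrb-' / '-rrb-' placeholders back to parentheses and return the path
    # of a temporary output file (callers above move it back over the original).
    output_path = tempfile.mkstemp()[1]
    with open(filepath, encoding='utf-8') as fin, open(output_path, 'w', encoding='utf-8') as fout:
        for line in fin:
            fout.write(line.replace('-lrb-', '(').replace('-rrb-', ')'))
    return output_path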
def check_dataset(dataset):
    # Sanity check with evaluation dataset
    assert not has_lines_in_common(get_data_filepath(dataset, 'train', 'complex'),
                                   get_data_filepath('turkcorpus', 'valid', 'complex'))
    assert not has_lines_in_common(get_data_filepath(dataset, 'train', 'complex'),
                                   get_data_filepath('turkcorpus', 'test', 'complex'))
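# has_lines_in_common is not shown in this excerpt. A straightforward sketch, assuming it
# only needs to report whether the two files share at least one identical line:
def has_lines_in_common(filepath1, filepath2):  # illustrative sketch only
    with open(filepath1, encoding='utf-8') as f1, open(filepath2, encoding='utf-8') as f2:
        return len(set(f1.readlines()) & set(f2.readlines())) > 0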
def create_preprocessed_dataset(dataset, preprocessors, n_jobs=1):
    for preprocessor in preprocessors:
        # Fit preprocessor on input dataset
        preprocessor.fit(get_data_filepath(dataset, 'train', 'complex'),
                         get_data_filepath(dataset, 'train', 'simple'))
        dataset = create_preprocessed_dataset_one_preprocessor(dataset, preprocessor, n_jobs)
    return dataset
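# Example of how the preparation and preprocessing steps could be chained for training data.
# The preprocessor class names and constructor arguments below are assumptions (ACCESS exposes
# several preprocessors, including a SentencePiece one whose __init__ is shown further down);
# adapt them to the classes actually available in access.preprocessors.
def example_prepare_wikilarge_preprocessed():  # illustrative usage sketch
    from access.preprocessors import LengthRatioPreprocessor, SentencePiecePreprocessor  # assumed names
    dataset = prepare_wikilarge()
    prepare_turkcorpus()  # needed so check_dataset can compare against the evaluation files
    check_dataset(dataset)
    preprocessors = [
        LengthRatioPreprocessor(target_ratio=0.95),  # assumed constructor signature
        SentencePiecePreprocessor(vocab_size=10000),
    ]
    return create_preprocessed_dataset(dataset, preprocessors, n_jobs=1)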
def prepare_turkcorpus():
    dataset = 'turkcorpus'
    with create_directory_or_skip(get_dataset_dir(dataset)):
        # Import here to avoid circular imports
        from access.feature_extraction import get_levenshtein_similarity
        prepare_turkcorpus_lower()
        url = 'https://github.com/cocoxu/simplification.git'
        output_dir = Path(tempfile.mkdtemp())
        git_clone(url, output_dir)
        print('Processing...')
        # Only rename files and put them in local directory architecture
        turkcorpus_truecased_dir = output_dir / 'data/turkcorpus/truecased'
        for (old_phase, new_phase) in [('test', 'test'), ('tune', 'valid')]:
            # (1) read the .tsv for which each line is tab separated:
            #     `idx, complex_sentence, *turk_sentences = line.split('\t')`
            # (2) replace lrb and rrb, tokenize
            # (3) Turk sentences are shuffled for each sample so need to realign them with turkcorpus lower
            tsv_filepath = turkcorpus_truecased_dir / f'{old_phase}.8turkers.organized.tsv'
            output_complex_filepath = get_data_filepath(dataset, new_phase, 'complex')
            output_ref_filepaths = [get_data_filepath(dataset, new_phase, 'simple.turk', i) for i in range(8)]
            # These files will be used to reorder the shuffled ref sentences
            ordered_ref_filepaths = [
                get_data_filepath('turkcorpus_lower', new_phase, 'simple.turk', i) for i in range(8)
            ]
            with write_lines_in_parallel([output_complex_filepath] + output_ref_filepaths) as files:
                input_filepaths = [tsv_filepath] + ordered_ref_filepaths
                for tsv_line, *ordered_ref_sentences in yield_lines_in_parallel(input_filepaths):
                    sample_id, complex_sentence, *shuffled_ref_sentences = [
                        word_tokenize(normalize_quotes(replace_lrb_rrb(s))) for s in tsv_line.split('\t')
                    ]
                    reordered_sentences = []
                    for ordered_ref_sentence in ordered_ref_sentences:
                        # Find the position of the ref_sentence in the shuffled sentences
                        similarities = [
                            get_levenshtein_similarity(ordered_ref_sentence.replace(' ', ''),
                                                       shuffled_ref_sentence.lower().replace(' ', ''))
                            for shuffled_ref_sentence in shuffled_ref_sentences
                        ]
                        idx = np.argmax(similarities)
                        # A few sentences have differing punctuation marks
                        assert similarities[idx] > 0.98, \
                            f'{ordered_ref_sentence} != {shuffled_ref_sentences[idx].lower()} {similarities[idx]:.2f}'
                        reordered_sentences.append(shuffled_ref_sentences.pop(idx))
                    assert len(shuffled_ref_sentences) == 0
                    assert len(reordered_sentences) == 8
                    files.write([complex_sentence] + reordered_sentences)
    return dataset
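# get_levenshtein_similarity is imported from access.feature_extraction above and is not part
# of this excerpt. The realignment only needs a normalized similarity in [0, 1] that is close
# to 1 for near-identical strings. A stand-in using only the standard library (difflib is not
# true Levenshtein, but behaves similarly for this near-duplicate matching):
def sketch_levenshtein_similarity(a, b):  # illustrative stand-in, real implementation may differ
    from difflib import SequenceMatcher
    # ratio() is 1.0 for identical strings and decreases as more edits are needed
    return SequenceMatcher(None, a, b).ratio()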
def __init__(self, vocab_size=10000, input_filepaths=None):
    self.vocab_size = vocab_size
    self.sentencepiece_model_path = VARIOUS_DIR / f'sentencepiece_model/sentencepiece_model_{self.vocab_size}.model'
    self.input_filepaths = input_filepaths
    if self.input_filepaths is None:
        self.input_filepaths = [
            get_data_filepath('wikilarge', 'train', 'complex'),
            get_data_filepath('wikilarge', 'train', 'simple')
        ]
    self.learn_sentencepiece()
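# learn_sentencepiece() is called above but not shown. A minimal sketch of the training step
# with the sentencepiece package; the model_prefix handling is an assumption, and the real
# method may also skip training when the model file already exists. Recent sentencepiece
# versions accept keyword arguments here; older ones take a single flags string.
def sketch_learn_sentencepiece(input_filepaths, model_path, vocab_size):  # illustrative sketch
    import sentencepiece as spm
    model_path.parent.mkdir(parents=True, exist_ok=True)
    spm.SentencePieceTrainer.Train(
        input=','.join(str(path) for path in input_filepaths),
        model_prefix=str(model_path.with_suffix('')),  # trainer appends .model / .vocab
        vocab_size=vocab_size,
    )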
def prepare_wikilarge():
    dataset = 'wikilarge'
    with create_directory_or_skip(get_dataset_dir(dataset)):
        url = 'https://github.com/louismartin/dress-data/raw/master/data-simplification.tar.bz2'
        # Download and extract the archive; keep the cache path of the extracted files
        extracted_path = download_and_extract(url)[0]
        # Only rename files and put them in local directory architecture
        # Loop over the train/valid/test splits
        for phase in PHASES:
            # Rename 'src' files to 'complex' and 'dst' files to 'simple'
            for (old_language_name, new_language_name) in [('src', 'complex'), ('dst', 'simple')]:
                # Find the file matching *.ori.{phase}.{old_language_name} in the extracted directory,
                # e.g. wiki.full.aner.ori.test.dst
                old_path_glob = os.path.join(extracted_path, dataset, f'*.ori.{phase}.{old_language_name}')
                globs = glob(old_path_glob)
                assert len(globs) == 1
                old_path = globs[0]
                # Build the new path and copy the file there
                new_path = get_data_filepath(dataset, phase, new_language_name)
                shutil.copyfile(old_path, new_path)
                # Replace the original -lrb-/-rrb- tokens
                shutil.move(replace_lrb_rrb_file(new_path), new_path)
                # Make sure the file ends with a newline
                add_newline_at_end_of_file(new_path)
    return dataset
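# For reference, each glob above is expected to match exactly one file per (phase, language)
# pair in the extracted archive. A small illustrative check (archive layout assumed from the
# example filename in the comment above):
def sketch_list_wikilarge_source_files(extracted_path):  # illustrative sketch only
    for phase in PHASES:
        for old_language_name in ('src', 'dst'):
            pattern = os.path.join(extracted_path, 'wikilarge', f'*.ori.{phase}.{old_language_name}')
            # Each pattern should print a single match, e.g. wiki.full.aner.ori.test.dst
            print(pattern, '->', glob(pattern))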
def get_prediction_on_turkcorpus(simplifier, phase):
    source_filepath = get_data_filepath('turkcorpus', phase, 'complex')
    pred_filepath = get_temp_filepath()
    print(pred_filepath)
    with mute():
        simplifier(source_filepath, pred_filepath)
    return pred_filepath
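# Typical use of get_prediction_on_turkcorpus: generate predictions for the valid split and
# score them against the 8 Turk references with easse, mirroring what sari_validate does below.
# `my_simplifier` is a placeholder for any callable with the (source_filepath, pred_filepath)
# signature expected above.
def example_evaluate_simplifier(my_simplifier):  # illustrative usage sketch
    from easse.report import get_all_scores
    from access.utils.helpers import read_lines
    pred_filepath = get_prediction_on_turkcorpus(my_simplifier, 'valid')
    orig_filepath = get_data_filepath('turkcorpus', 'valid', 'complex')
    ref_filepaths = [get_data_filepath('turkcorpus', 'valid', 'simple.turk', i) for i in range(8)]
    return get_all_scores(read_lines(orig_filepath), read_lines(pred_filepath),
                          [read_lines(ref_filepath) for ref_filepath in ref_filepaths])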
def sari_validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr,
                  subsets: List[str]) -> List[Optional[float]]:
    from pathlib import Path
    from access.resources.paths import get_data_filepath
    from access.utils.helpers import read_lines
    from access.preprocessors import load_preprocessors, ComposedPreprocessor
    from easse.report import get_all_scores
    from fairseq.data import encoders
    from fairseq_cli.interactive import buffered_read, make_batches
    from fairseq_cli.generate import get_symbols_to_strip_from_output
    from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
    import tempfile

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)
    # TODO: Choose parameters for the preprocessors ?
    # Read the preprocessors back from their pickle file (currently unused)
    # preprocessors = load_preprocessors(Path(cfg.task.data).parent)
    # composed_preprocessor = ComposedPreprocessor(preprocessors)
    # Get the path of turkcorpus.valid.complex
    complex_filepath = get_data_filepath('turkcorpus', 'valid', 'complex')
    # make temp dir
    # encoded_complex_filepath = tempfile.mkstemp()[1]
    # encoded_pred_filepath = tempfile.mkstemp()[1]
    pred_filepath = tempfile.mkstemp()[1]
    # use preprocessors to encode complex file
    # composed_preprocessor.encode_file(complex_filepath, encoded_complex_filepath)
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        trainer.get_model().max_positions(),
    )
    parser = options.get_generation_parser(interactive=True)
    # TODO: Take args from fairseq_generate
    gen_args = options.parse_args_and_arch(parser, input_args=['/dummy_data', '--beam', '2'])
    # Initialize generator
    generator = task.build_generator([trainer.model], gen_args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    with open(pred_filepath, 'w') as f:
        start_id = 0
        for inputs in buffered_read(complex_filepath, buffer_size=9999):
            results = []
            for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()
                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(generator, [trainer.model], sample,
                                                   constraints=constraints)
                list_constraints = [[] for _ in range(bsz)]
                if cfg.generation.constraints:
                    list_constraints = [unpack_constraints(c) for c in constraints]
                for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                    constraints = list_constraints[i]
                    results.append((
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                        },
                    ))

            # sort output to match input order
            for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                    for constraint in info["constraints"]:
                        pass  # constraints are not printed during validation

                # Process top predictions
                for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )
                    detok_hypo_str = decode_fn(hypo_str)  # detokenized hypothesis
                    f.write(f'{detok_hypo_str}\n')
                    if cfg.generation.print_alignment:
                        alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])

            # update running id_ counter
            start_id += len(inputs)

    # composed_preprocessor.decode_file(encoded_pred_filepath, pred_filepath)
    ref_filepaths = [get_data_filepath('turkcorpus', 'valid', 'simple.turk', i) for i in range(8)]
    scores = get_all_scores(read_lines(complex_filepath), read_lines(pred_filepath),
                            [read_lines(ref_filepath) for ref_filepath in ref_filepaths])
    print(f'num_updates={trainer.get_num_updates()}')
    print(f'ts_scores={scores}')
    sari = scores['SARI']
    if not hasattr(trainer, 'best_sari'):
        trainer.best_sari = 0
    if not hasattr(trainer, 'n_validations_since_best'):
        trainer.n_validations_since_best = 0
    if sari > trainer.best_sari:
        trainer.best_sari = sari
        trainer.n_validations_since_best = 0
    else:
        trainer.n_validations_since_best += 1
        print(f'SARI did not improve for {trainer.n_validations_since_best} validations')
        # Does not work because scheduler will set it to previous value everytime
        # trainer.optimizer.set_lr(0.75 * trainer.optimizer.get_lr())
        if trainer.n_validations_since_best >= cfg.validations_before_sari_early_stopping:
            print(f'Early stopping because SARI did not improve for '
                  f'{trainer.n_validations_since_best} validations')
            trainer.early_stopping = True

    def is_abort(epoch_itr, best_sari):
        if (epoch_itr.epoch >= 2 and best_sari < 19):
            return True
        if (epoch_itr.epoch >= 5 and best_sari < 22):
            return True
        if (epoch_itr.epoch >= 10 and best_sari < 25):
            return True
        return False

    # if is_abort(epoch_itr, best_sari):
    #     print(f'Early stopping because best SARI is too low ({best_sari:.2f}) after {epoch_itr.epoch} epochs.')
    #     # Remove the checkpoint directory as we got nothing interesting
    #     shutil.rmtree(args.save_dir)
    #     # TODO: Abort
    return [-sari]
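# If only the SARI value is needed (rather than the full easse report used above), easse also
# exposes a corpus-level SARI function. A hedged standalone sketch; the keyword argument names
# follow recent easse versions and may differ in older ones.
def sketch_compute_validation_sari(pred_filepath):  # illustrative sketch only
    from easse.sari import corpus_sari
    from access.utils.helpers import read_lines
    from access.resources.paths import get_data_filepath
    orig_sents = read_lines(get_data_filepath('turkcorpus', 'valid', 'complex'))
    sys_sents = read_lines(pred_filepath)
    refs_sents = [read_lines(get_data_filepath('turkcorpus', 'valid', 'simple.turk', i)) for i in range(8)]
    return corpus_sari(orig_sents=orig_sents, sys_sents=sys_sents, refs_sents=refs_sents)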