def main(args): # we should not do this! ''' if args.max_tokens is None: args.max_tokens = 6000 ''' utils.xpprint(args) if not torch.cuda.is_available(): raise NotImplementedError('Training on CPU is not supported') torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(args) utils.xprintln('setup task done!') # Load dataset splits load_dataset_splits(args, task, ['train']) valid_dataset = args.valid_subset.split(',') load_dataset_splits(args, task, valid_dataset, shuffle=False) utils.xprintln('load dataset done!') if args.task.startswith('extractive_summarization'): if distributed_utils.is_master(args): from sum_eval import MultiProcSumEval sum_eval_pool = MultiProcSumEval(args.ncpu_eval) sum_valid_pool_params = dict( article_file=args.raw_valid + '.article', summary_file=args.raw_valid + '.summary', entity_map_file=None, length=-1, eval_type='predict', topk=args.topk_sent_eval, rerank=False, with_m=False, cmd='-a -c 95 -m -n 4 -w 1.2', trigram_block=args.trigram_block, ) sum_test_pool_params = dict( article_file=args.raw_test + '.article', summary_file=args.raw_test + '.summary', entity_map_file=None, length=-1, eval_type='predict', topk=args.topk_sent_eval, rerank=False, with_m=False, cmd='-a -c 95 -m -n 4 -w 1.2', trigram_block=args.trigram_block, ) sum_pool_params = dict(valid=sum_valid_pool_params, test=sum_test_pool_params) def make_params(default_dict, result_file, out_rouge_file, rerank=False, with_m=False): para_dict = dict(default_dict) para_dict['result_file'] = result_file para_dict['out_rouge_file'] = out_rouge_file para_dict['rerank'] = rerank para_dict['with_m'] = with_m return para_dict # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) print('| num. 
model params: {}'.format( sum(p.numel() for p in model.parameters()))) # print(model) import sys sys.stdout.flush() # if summarization try to load pretrained model # if args.task.startswith('extractive_summarization') or args.task == 'pretrain_document_modeling': # # assume this is a single GPU program if args.init_from_pretrained_doc_model: task.load_pretrained_model(model, args.pretrained_doc_model_path) sys.stdout.flush() # Build trainer trainer = Trainer(args, task, model, criterion) print('| training on {} GPUs'.format(args.distributed_world_size)) print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, args.max_sentences, )) # Initialize dataloader max_positions = trainer.get_model().max_positions() epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=False) # Load the latest checkpoint if one is available # load_checkpoint(args, trainer, epoch_itr) # make sure training from a different checkpoint will use different random seed cur_dataset = task.dataset('train') if hasattr(cur_dataset, 'rng'): print('epoch ', epoch_itr.epoch) cur_dataset.rng = numpy.random.RandomState(args.seed + epoch_itr.epoch) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf max_update = args.max_update or math.inf lr = trainer.get_lr() train_meter = StopwatchMeter() train_meter.start() valid_losses = [None] valid_subsets = args.valid_subset.split(',') for alpha in range(10, 9, -1): # train for one epoch # train(args, trainer, task, epoch_itr) epoch_itr.next_epoch_itr() if epoch_itr.epoch % args.validate_interval == 0: if args.task.startswith('extractive_summarization'): if distributed_utils.is_master(args): validate_metric(args, trainer, task, epoch_itr, valid_subsets)
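# Illustrative sketch (not part of the original training script): how the
# `make_params` helper above is typically used to specialize the shared
# ROUGE-evaluation settings for one validation checkpoint. The file names
# below are hypothetical placeholders; `sum_valid_pool_params` mirrors the
# dict built in main().
def _example_make_params_usage():
    sum_valid_pool_params = dict(
        article_file='valid.article',
        summary_file='valid.summary',
        entity_map_file=None,
        length=-1,
        eval_type='predict',
        topk=3,
        rerank=False,
        with_m=False,
        cmd='-a -c 95 -m -n 4 -w 1.2',
        trigram_block=True,
    )

    def make_params(default_dict, result_file, out_rouge_file, rerank=False, with_m=False):
        para_dict = dict(default_dict)          # copy so the shared defaults stay untouched
        para_dict['result_file'] = result_file
        para_dict['out_rouge_file'] = out_rouge_file
        para_dict['rerank'] = rerank
        para_dict['with_m'] = with_m
        return para_dict

    # One parameter dict per checkpoint, e.g. predictions written after epoch 3:
    params_epoch3 = make_params(sum_valid_pool_params, 'valid_epoch3.pred', 'valid_epoch3.rouge')
    return params_epoch3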
def main(args): assert args.path is not None, '--path required for generation!' assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' assert args.replace_unk is None or args.raw_text, \ '--replace-unk requires a raw text dataset (--raw-text)' # if args.save_path is not None: # if check_file_exists(args): # return import_user_module(args) if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 # print(args) utils.xpprint(args) use_cuda = torch.cuda.is_available() and not args.cpu # Load dataset splits task = tasks.setup_task(args) task.load_dataset(args.gen_subset) print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) # Set dictionaries try: src_dict = getattr(task, 'source_dictionary', None) except NotImplementedError: src_dict = None tgt_dict = task.target_dictionary # Load ensemble print('| loading model(s) from {}'.format(args.path)) models, _model_args = utils.load_ensemble_for_inference( args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), ) # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, need_attn=args.print_alignment, ) if args.fp16: model.half() if use_cuda: model.cuda() # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(args.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), # *[model.max_positions() for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) # Initialize generator gen_timer = StopwatchMeter() generator = task.build_generator(args) # Generate and compute BLEU score if args.sacrebleu: scorer = bleu.SacrebleuScorer() else: scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) num_sentences = 0 has_target = True if args.isRoberta: from pytorch_transformers import RobertaTokenizer tokenizer = RobertaTokenizer.from_pretrained('roberta-base') else: tokenizer = None with progress_bar.build_progress_bar(args, itr) as t: wps_meter = TimeMeter() for sample in t: sample = utils.move_to_cuda(sample) if use_cuda else sample if 'net_input' not in sample: continue prefix_tokens = None if args.prefix_size > 0: prefix_tokens = sample['target'][:, :args.prefix_size] gen_timer.start() hypos = task.inference_step(generator, models, sample, prefix_tokens) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) for i, sample_id in enumerate(sample['id'].tolist()): has_target = sample['target'] is not None # Remove padding src_tokens = utils.strip_pad( sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) target_tokens = None if has_target: target_tokens = utils.strip_pad( sample['target'][i, :], tgt_dict.pad()).int().cpu() # Either retrieve the original sentences or regenerate them from tokens. 
if align_dict is not None: src_str = task.dataset( args.gen_subset).src.get_original_text(sample_id) target_str = task.dataset( args.gen_subset).tgt.get_original_text(sample_id) else: if src_dict is not None: src_str = src_dict.string(src_tokens, args.remove_bpe) else: src_str = "" if has_target: target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) if not args.quiet: if src_dict is not None: if not args.isRoberta: print('S-{}\t{}'.format(sample_id, src_str)) else: src_text = ''.join(src_str.strip().split()) src_out = tokenizer.convert_tokens_to_string( src_text) print('S-{}\t{}'.format(sample_id, src_out)) if has_target: if not args.isRoberta: print('T-{}\t{}'.format(sample_id, target_str)) else: tgt_text = ''.join(target_str.strip().split()) tgt_out = tokenizer.convert_tokens_to_string( tgt_text) print('T-{}\t{}'.format(sample_id, tgt_out)) # Process top predictions for i, hypo in enumerate( hypos[i][:min(len(hypos), args.nbest)]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, ) if not args.quiet: if not args.isRoberta: print('H-{}\t{}\t{}'.format( sample_id, hypo['score'], hypo_str)) else: hypo_text = ''.join(hypo_str.strip().split()) hypo_out = tokenizer.convert_tokens_to_string( hypo_text) print('H-{}\t{}\t{}'.format( sample_id, hypo['score'], hypo_out)) print('P-{}\t{}'.format( sample_id, ' '.join( map( lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist(), )))) if args.print_alignment: print('A-{}\t{}'.format( sample_id, ' '.join( map(lambda x: str(utils.item(x)), alignment)))) # Score only the top hypothesis if has_target and i == 0: if align_dict is not None or args.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True) if hasattr(scorer, 'add_string'): scorer.add_string(target_str, hypo_str) else: scorer.add(target_tokens, hypo_tokens) wps_meter.update(num_generated_tokens) t.log({'wps': round(wps_meter.avg)}) num_sentences += sample['nsentences'] print( '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)' .format(num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if has_target: print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) return scorer
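# Illustrative sketch (not part of the original generation script): the
# detokenization applied in the `args.isRoberta` branches above. A RoBERTa BPE
# token string produced by `tgt_dict.string(...)` is whitespace-stripped and
# decoded back to plain text with the same pytorch_transformers RobertaTokenizer
# that main() loads. The sample string is a hypothetical example.
def _example_roberta_detokenize(bpe_token_str='Hello Ġworld'):
    from pytorch_transformers import RobertaTokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    # Drop the spaces inserted between BPE tokens; the 'Ġ' markers encode the real spaces.
    text = ''.join(bpe_token_str.strip().split())
    return tokenizer.convert_tokens_to_string(text)  # -> 'Hello world'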
def main(args): from fairseq import utils utils.xpprint(args) os.makedirs(args.destdir, exist_ok=True) target = not args.only_source def build_dictionary(filenames): d = dictionary.Dictionary() for filename in filenames: Tokenizer.add_file_to_dictionary(filename, d, tokenize_line) return d def build_dictionary_label(filenames): d = flexible_dictionary.FlexibleDictionary([('PAD', '<pad>')]) for filename in filenames: Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, append_eos=False) return d def train_path(lang): return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '') def file_name(prefix, lang): fname = prefix if lang is not None: fname += f'.{lang}' return fname def dest_path(prefix, lang): return os.path.join(args.destdir, file_name(prefix, lang)) def dict_path(lang): return dest_path('dict', lang) + '.txt' def dataset_dest_path(output_prefix, lang, extension): base = f'{args.destdir}/{output_prefix}' lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else '' return f'{base}{lang_part}.{extension}' assert args.srcdict is not None, 'where is the Bert Dict!' if args.srcdict: src_dict = gpt2_dictionary.GPT2Dictionary.load(args.srcdict) src_dict.save(dict_path(args.source_lang)) print('load bert dict from {} | size {}'.format(args.srcdict, len(src_dict))) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary([train_path(args.source_lang)]) if target: if args.tgtdict: tgt_dict = flexible_dictionary.FlexibleDictionary.load(args.tgtdict) print('load label dict from {} | size {}'.format(args.tgtdict, len(tgt_dict))) else: assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" tgt_dict = build_dictionary_label([train_path(args.target_lang)]) print('build target dict from {} done'.format(train_path(args.target_lang))) src_dict.save(dict_path(args.source_lang)) if target: if not args.joined_dictionary: tgt_dict.finalize( threshold=args.thresholdtgt, nwords=args.nwordstgt, padding_factor=1, ) tgt_dict.save(dict_path(args.target_lang)) def make_binary_dataset(input_prefix, output_prefix, lang, append_eos=False): if lang == args.target_lang: dict = flexible_dictionary.FlexibleDictionary.load(dict_path(lang)) else: # dict = bert_dictionary.BertDictionary.load(dict_path(lang)) dict = gpt2_dictionary.GPT2Dictionary.load(dict_path(lang)) print('| [{}] Dictionary: {} types | {} types (for real)'.format(lang, len(dict) - 1, len(dict))) ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin')) def consumer(tensor): ds.add_item(tensor) input_file = '{}{}'.format(input_prefix, ('.' 
+ lang) if lang is not None else '') if lang == args.target_lang: res = Tokenizer.binarize(input_file, dict, consumer, append_eos=append_eos) print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format( lang, input_file, res['nseq'], res['ntok'], 100 * res['nunk'] / res['ntok'], dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>')) else: # read article # from pytorch_pretrained_bert.tokenization import BertTokenizer # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) from pytorch_transformers import RobertaTokenizer tokenizer = RobertaTokenizer.from_pretrained('roberta-base') def penn_token2orig_token(sent): # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB- ''' penn2orig = {"``":'"', "''": '"', "-LRB-": '(', "-RRB-": ')', "-LSB-":'[', "-RSB-":']', "-LCB-":'{', "-RCB-":'}'} ''' penn2orig = {"-LRB-": '(', "-RRB-": ')', "-LSB-": '[', "-RSB-": ']', "-LCB-": '{', "-RCB-": '}', "-lrb-": '(', "-rrb-": ')', "-lsb-": '[', "-rsb-": ']', "-lcb-": '{', "-rcb-": '}',} words = sent.strip().split() words = [wd if not wd in penn2orig else penn2orig[wd] for wd in words] return ' '.join(words) num_token, num_unk_token = 0, 0 num_seq = 0 skip_line = 0 for line in open(input_file, encoding='utf8'): sents = line.strip().split('<S_SEP>') sents = sents[0:args.max_num_sentences] sents = [' '.join(sent.strip().split()[0:args.max_num_words]) for sent in sents] # print(sents) sents = [tokenizer.tokenize(penn_token2orig_token(sent)) for sent in sents] article_wids = [] for i, sent in enumerate(sents): # sometimes there are too many tokens MAXLEN = 500 if len(sent) > MAXLEN: # sent = sent[0:MAXLEN] print(' '.join(sent)) skip_line += 1 print(skip_line) continue if i != 0: article_wids.append( dict.sep_index ) wids = tokenizer.convert_tokens_to_ids(sent) # wids_vocab = [dict.index(word) for word in sent] # assert wids == wids_vocab, 'word indices should be the same!' article_wids.extend(wids) for wid in wids: if wid == dict.unk_index: num_unk_token += 1 num_token += 1 num_seq += 1 tensor = torch.IntTensor(article_wids) # print( dict.string_complete(tensor) ) ds.add_item(tensor) print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format( lang, input_file, num_seq, num_token, 100 * num_unk_token / num_token, dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>')) ds.finalize(dataset_dest_path(output_prefix, lang, 'idx')) def make_dataset(input_prefix, output_prefix, lang): if args.output_format == 'binary': make_binary_dataset(input_prefix, output_prefix, lang) elif args.output_format == 'raw': # Copy original text file to destination folder output_text_file = dest_path( output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang), lang, ) shutil.copyfile(file_name(input_prefix, lang), output_text_file) def make_all(lang): if args.trainpref: make_dataset(args.trainpref, 'train', lang) if args.validpref: for k, validpref in enumerate(args.validpref.split(',')): outprefix = 'valid{}'.format(k) if k > 0 else 'valid' make_dataset(validpref, outprefix, lang) if args.testpref: for k, testpref in enumerate(args.testpref.split(',')): outprefix = 'test{}'.format(k) if k > 0 else 'test' make_dataset(testpref, outprefix, lang) make_all(args.source_lang) if target: make_all(args.target_lang) print('| Wrote preprocessed data to {}'.format(args.destdir))
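# Illustrative sketch (not part of the original preprocessing script): the core
# of make_binary_dataset() above for the article side, in miniature. Sentences
# are separated by '<S_SEP>' in the raw file, each sentence is BPE-tokenized
# with the RoBERTa tokenizer, and a dictionary-level separator id is inserted
# between sentences. `sep_id` stands in for `dict.sep_index`, and the input
# line is hypothetical; the PTB-bracket mapping and over-length skipping of the
# real code are omitted.
def _example_binarize_article_line(line='the cat sat . <S_SEP> it purred .', sep_id=2,
                                    max_num_sentences=30, max_num_words=100):
    from pytorch_transformers import RobertaTokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    sents = line.strip().split('<S_SEP>')
    sents = sents[:max_num_sentences]
    sents = [' '.join(sent.strip().split()[:max_num_words]) for sent in sents]
    sents = [tokenizer.tokenize(sent) for sent in sents]

    article_wids = []
    for i, sent in enumerate(sents):
        if i != 0:
            article_wids.append(sep_id)           # sentence-boundary marker
        article_wids.extend(tokenizer.convert_tokens_to_ids(sent))
    return article_wids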
def main(args): from fairseq import utils utils.xpprint(args) import_user_module(args) print(args) os.makedirs(args.destdir, exist_ok=True) target = not args.only_source task = tasks.get_task(args.task) def train_path(lang): return "{}{}".format(args.trainpref, ("." + lang) if lang else "") def file_name(prefix, lang): fname = prefix if lang is not None: fname += ".{lang}".format(lang=lang) return fname def dest_path(prefix, lang): return os.path.join(args.destdir, file_name(prefix, lang)) def dict_path(lang): return dest_path("dict", lang) + ".txt" def build_dictionary(filenames, src=False, tgt=False): assert src ^ tgt return task.build_dictionary( filenames, workers=args.workers, threshold=args.thresholdsrc if src else args.thresholdtgt, nwords=args.nwordssrc if src else args.nwordstgt, padding_factor=args.padding_factor, ) if not args.srcdict and os.path.exists(dict_path(args.source_lang)): raise FileExistsError(dict_path(args.source_lang)) if target and not args.tgtdict and os.path.exists( dict_path(args.target_lang)): raise FileExistsError(dict_path(args.target_lang)) if args.joined_dictionary: assert not args.srcdict or not args.tgtdict, \ "cannot use both --srcdict and --tgtdict with --joined-dictionary" if args.srcdict: src_dict = task.load_dictionary(args.srcdict) elif args.tgtdict: src_dict = task.load_dictionary(args.tgtdict) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary( { train_path(lang) for lang in [args.source_lang, args.target_lang] }, src=True) tgt_dict = src_dict else: if args.srcdict: src_dict = xlnet_dictionary.XLNetDictionary.load(args.srcdict) print('load xlnet dict from {} | size {}'.format( args.srcdict, len(src_dict))) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary([train_path(args.source_lang)], src=True) if target: if args.tgtdict: tgt_dict = xlnet_dictionary.XLNetDictionary.load(args.tgtdict) else: assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True) else: tgt_dict = None src_dict.save(dict_path(args.source_lang)) if target and tgt_dict is not None: tgt_dict.save(dict_path(args.target_lang)) def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers): print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1)) print('input_prefix', input_prefix) print(dict_path(lang)) dict = xlnet_dictionary.XLNetDictionary.load(dict_path(lang)) input_file = "{}{}".format(input_prefix, ("." 
+ lang) if lang is not None else "") from pytorch_transformers import XLNetConfig, XLNetTokenizer import torch tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased') def penn_token2orig_token(sent): # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB- penn2orig = { "``": '"', "''": '"', "-LRB-": '(', "-RRB-": ')', "-LSB-": '[', "-RSB-": ']', "-LCB-": '{', "-RCB-": '}' } words = sent.strip().split() words = [ wd if not wd in penn2orig else penn2orig[wd] for wd in words ] return ' '.join(words) num_token, num_unk_token = 0, 0 num_seq = 0 ds = indexed_dataset.IndexedDatasetBuilder( dataset_dest_file(args, output_prefix, lang, "bin")) for line in open(input_file, encoding='utf8'): sents = line.strip().split('<S_SEP>') sents = [ tokenizer.tokenize(penn_token2orig_token(sent)) for sent in sents ] article_wids = [] for i, sent in enumerate(sents): if i != 0: article_wids.append(dict.sep_index) wids = tokenizer.convert_tokens_to_ids(sent) # wids_vocab = [dict.index(word) for word in sent] # assert wids == wids_vocab, 'word indices should be the same!' article_wids.extend(wids) for wid in wids: if wid == dict.unk_index: num_unk_token += 1 num_token += 1 num_seq += 1 tensor = torch.IntTensor(article_wids) # print( dict.string_complete(tensor) ) ds.add_item(tensor) ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format( lang, input_file, num_seq, num_token, 100 * num_unk_token / num_token, dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>')) # # n_seq_tok = [0, 0] # replaced = Counter() # # def merge_result(worker_result): # replaced.update(worker_result["replaced"]) # n_seq_tok[0] += worker_result["nseq"] # n_seq_tok[1] += worker_result["ntok"] # # input_file = "{}{}".format( # input_prefix, ("." 
+ lang) if lang is not None else "" # ) # offsets = Binarizer.find_offsets(input_file, num_workers) # pool = None # if num_workers > 1: # pool = Pool(processes=num_workers - 1) # for worker_id in range(1, num_workers): # prefix = "{}{}".format(output_prefix, worker_id) # pool.apply_async( # binarize, # ( # args, # input_file, # vocab, # prefix, # lang, # offsets[worker_id], # offsets[worker_id + 1] # ), # callback=merge_result # ) # pool.close() # # ds = indexed_dataset.IndexedDatasetBuilder( # dataset_dest_file(args, output_prefix, lang, "bin") # ) # merge_result( # Binarizer.binarize( # input_file, vocab, lambda t: ds.add_item(t), # offset=0, end=offsets[1] # ) # ) # if num_workers > 1: # pool.join() # for worker_id in range(1, num_workers): # prefix = "{}{}".format(output_prefix, worker_id) # temp_file_path = dataset_dest_prefix(args, prefix, lang) # ds.merge_file_(temp_file_path) # os.remove(indexed_dataset.data_file_path(temp_file_path)) # os.remove(indexed_dataset.index_file_path(temp_file_path)) # # ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) # # print( # "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format( # lang, # input_file, # n_seq_tok[0], # n_seq_tok[1], # 100 * sum(replaced.values()) / n_seq_tok[1], # vocab.unk_word, # ) # ) def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): if args.output_format == "binary": make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers) elif args.output_format == "raw": # Copy original text file to destination folder output_text_file = dest_path( output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), lang, ) shutil.copyfile(file_name(input_prefix, lang), output_text_file) def make_all(lang, vocab): if args.trainpref: print(args.trainpref, lang) make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers) if args.validpref: for k, validpref in enumerate(args.validpref.split(",")): outprefix = "valid{}".format(k) if k > 0 else "valid" make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers) if args.testpref: for k, testpref in enumerate(args.testpref.split(",")): outprefix = "test{}".format(k) if k > 0 else "test" make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers) make_all(args.source_lang, src_dict) if target: make_all(args.target_lang, tgt_dict) print("| Wrote preprocessed data to {}".format(args.destdir)) if args.alignfile: assert args.trainpref, "--trainpref must be set if --alignfile is specified" src_file_name = train_path(args.source_lang) tgt_file_name = train_path(args.target_lang) freq_map = {} with open(args.alignfile, "r", encoding='utf-8') as align_file: with open(src_file_name, "r", encoding='utf-8') as src_file: with open(tgt_file_name, "r", encoding='utf-8') as tgt_file: for a, s, t in zip_longest(align_file, src_file, tgt_file): si = src_dict.encode_line(s, add_if_not_exist=False) ti = tgt_dict.encode_line(t, add_if_not_exist=False) ai = list(map(lambda x: tuple(x.split("-")), a.split())) for sai, tai in ai: srcidx = si[int(sai)] tgtidx = ti[int(tai)] if srcidx != src_dict.unk( ) and tgtidx != tgt_dict.unk(): assert srcidx != src_dict.pad() assert srcidx != src_dict.eos() assert tgtidx != tgt_dict.pad() assert tgtidx != tgt_dict.eos() if srcidx not in freq_map: freq_map[srcidx] = {} if tgtidx not in freq_map[srcidx]: freq_map[srcidx][tgtidx] = 1 else: freq_map[srcidx][tgtidx] += 1 align_dict = {} for srcidx in freq_map.keys(): align_dict[srcidx] = max(freq_map[srcidx], 
key=freq_map[srcidx].get)

        with open(
                os.path.join(
                    args.destdir,
                    "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
                ),
                "w",
                encoding='utf-8',
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
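# Illustrative sketch (not part of the original script): how the --alignfile
# block above reduces the source-to-target co-occurrence counts in `freq_map`
# to a single best-aligned target index per source index. The toy counts are
# hypothetical.
def _example_alignment_argmax():
    freq_map = {
        10: {7: 3, 9: 1},   # source id 10 aligned to target id 7 three times, to 9 once
        11: {4: 2},
    }
    align_dict = {}
    for srcidx in freq_map.keys():
        align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
    return align_dict  # -> {10: 7, 11: 4}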
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Print args
    utils.xpprint(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, shuffle=False, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    if hasattr(args, 'init_from_pretrained_doc_model') and args.init_from_pretrained_doc_model:
        import os
        if not os.path.exists(os.path.join(args.save_dir, "checkpoint_last.pt")):
            # No checkpoint yet: warm-start from the pretrained document model and
            # reset optimizer, dataloader, and meter state.
            args.restore_file = args.pretrained_doc_model_path
            args.reset_optimizer, args.reset_dataloader, args.reset_meters = True, True, True
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    sys.stdout.flush()

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    if isinstance(lr, list):
        lr = min(lr)
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (
        lr > args.min_lr
        and (epoch_itr.epoch < max_epoch
             or (epoch_itr.epoch == max_epoch and epoch_itr._next_epoch_itr is not None))
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        if isinstance(lr, list):
            lr = min(lr)

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
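# Illustrative sketch (not part of the original script): the warm-start logic in
# main() above, isolated. If no checkpoint_last.pt exists in the save directory,
# the pretrained document model becomes the restore file and optimizer /
# dataloader / meter state is reset; otherwise normal checkpoint resumption
# takes over. The argparse.Namespace and paths are hypothetical stand-ins for
# the parsed command-line args.
def _example_pretrained_warm_start(save_dir='/tmp/checkpoints',
                                   pretrained_path='/tmp/pretrained_doc_model.pt'):
    import argparse
    import os

    args = argparse.Namespace(
        save_dir=save_dir,
        restore_file='checkpoint_last.pt',
        pretrained_doc_model_path=pretrained_path,
        init_from_pretrained_doc_model=True,
        reset_optimizer=False,
        reset_dataloader=False,
        reset_meters=False,
    )
    if args.init_from_pretrained_doc_model and \
            not os.path.exists(os.path.join(args.save_dir, 'checkpoint_last.pt')):
        args.restore_file = args.pretrained_doc_model_path
        args.reset_optimizer, args.reset_dataloader, args.reset_meters = True, True, True
    return args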
def main(args, init_distributed=False): import_user_module(args) if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 6000 # print(args) utils.xpprint(args) if torch.cuda.is_available() and not args.cpu: torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(args) # Load dataset splits load_dataset_splits(task, ['train', 'valid']) # Initialize distributed training (after data loading) if init_distributed: import socket args.distributed_rank = distributed_utils.distributed_init(args) print('| initialized host {} as rank {}'.format( socket.gethostname(), args.distributed_rank)) args.init_distributed = init_distributed # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) print(model) print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) print('| num. model params: {} (num. trained: {})'.format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), )) import sys sys.stdout.flush() # Make a dummy batch to (i) warm the caching allocator and (ii) as a # placeholder DistributedDataParallel when there's an uneven number of # batches per worker. max_positions = utils.resolve_max_positions( task.max_positions(), model.max_positions(), ) dummy_batch = task.dataset('train').get_dummy_batch( args.max_tokens, max_positions, batch_size=args.max_sentences) oom_batch = task.dataset('train').get_dummy_batch(1, max_positions) # Build trainer trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) print('| training on {} GPUs'.format(args.distributed_world_size)) print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, args.max_sentences, )) # Initialize dataloader epoch_itr = task.get_batch_iterator( dataset=task.dataset(args.train_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions, ignore_invalid_inputs=True, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ) # print(trainer.get_model().decoder.layers[11].output.LayerNorm.weight.data) # Load the latest checkpoint if one is available if not load_checkpoint(args, trainer, epoch_itr): if args.task == 'abstractive_summarization_bert' or args.task == 'abstractive_summarization_roberta': if args.init_from_pretrained_model and args.pretrained_model_path: task.load_pretrained_model(model, args.pretrained_model_path) elif hasattr( args, 'roberta_decoder') and args.roberta_decoder and hasattr( args, 'roberta_decoder_initialization' ) and args.roberta_decoder_initialization: model.initilize_roberta_decoder() trainer.dummy_train_step([dummy_batch]) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf max_update = args.max_update or math.inf lr = trainer.get_lr() if args.sep_optim: dec_lr = trainer.get_dec_lr() train_meter = StopwatchMeter() train_meter.start() valid_losses = [None] valid_subsets = args.valid_subset.split(',') while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates( ) < max_update: # train for one epoch train(args, trainer, task, epoch_itr) if epoch_itr.epoch % args.validate_interval == 0: valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) # only use first validation loss to update the learning rate lr = 
trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        if args.sep_optim:
            dec_lr = trainer.dec_lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
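# Illustrative sketch (not part of the original script): the stopping condition
# of the training loop above, isolated as a predicate. Training continues while
# the learning rate is above --min-lr and neither --max-epoch nor --max-update
# has been reached (math.inf when unset). The default values here are
# placeholders, not the script's defaults.
def _example_keep_training(lr, epoch, num_updates, min_lr=1e-9,
                           max_epoch=None, max_update=None):
    import math
    max_epoch = max_epoch or math.inf
    max_update = max_update or math.inf
    return lr > min_lr and epoch < max_epoch and num_updates < max_update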
def main(args): from fairseq import utils utils.xpprint(args) import_user_module(args) print(args) os.makedirs(args.destdir, exist_ok=True) target = not args.only_source task = tasks.get_task(args.task) def train_path(lang): return "{}{}".format(args.trainpref, ("." + lang) if lang else "") def file_name(prefix, lang): fname = prefix if lang is not None: fname += ".{lang}".format(lang=lang) return fname def dest_path(prefix, lang): return os.path.join(args.destdir, file_name(prefix, lang)) def dict_path(lang): return dest_path("dict", lang) + ".txt" def build_dictionary(filenames, src=False, tgt=False): assert src ^ tgt return task.build_dictionary( filenames, workers=args.workers, threshold=args.thresholdsrc if src else args.thresholdtgt, nwords=args.nwordssrc if src else args.nwordstgt, padding_factor=args.padding_factor, ) if not args.srcdict and os.path.exists(dict_path(args.source_lang)): raise FileExistsError(dict_path(args.source_lang)) if target and not args.tgtdict and os.path.exists( dict_path(args.target_lang)): raise FileExistsError(dict_path(args.target_lang)) if args.joined_dictionary: assert not args.srcdict or not args.tgtdict, \ "cannot use both --srcdict and --tgtdict with --joined-dictionary" if args.srcdict: src_dict = task.load_dictionary(args.srcdict) elif args.tgtdict: src_dict = task.load_dictionary(args.tgtdict) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary( { train_path(lang) for lang in [args.source_lang, args.target_lang] }, src=True) tgt_dict = src_dict else: if args.srcdict: src_dict = roberta_dictionary.RobertaDictionary.load_json( args.srcdict) # src_dict.save('roberta-vocab/roberta-base-vocab.txt') print('load bert dict from {} | size {}'.format( args.srcdict, len(src_dict))) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary([train_path(args.source_lang)], src=True) if target: if args.tgtdict: tgt_dict = roberta_dictionary.RobertaDictionary.load_json( args.tgtdict) else: assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True) else: tgt_dict = None src_dict.save(dict_path(args.source_lang)) if target and tgt_dict is not None: tgt_dict.save(dict_path(args.target_lang)) def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers): print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1)) print('input_prefix', input_prefix) print(dict_path(lang)) dict = roberta_dictionary.RobertaDictionary.load(dict_path(lang)) input_file = "{}{}".format(input_prefix, ("." 
+ lang) if lang is not None else "") from pytorch_transformers import RobertaTokenizer import torch tokenizer = RobertaTokenizer.from_pretrained('roberta-base') def penn_token2orig_token(sent): # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB- penn2orig = { "``": '"', "''": '"', "-LRB-": '(', "-RRB-": ')', "-LSB-": '[', "-RSB-": ']', "-LCB-": '{', "-RCB-": '}' } words = sent.strip().split() words = [ wd if not wd in penn2orig else penn2orig[wd] for wd in words ] return ' '.join(words) num_token, num_unk_token = 0, 0 num_seq = 0 ds = indexed_dataset.IndexedDatasetBuilder( dataset_dest_file(args, output_prefix, lang, "bin")) output_ds = indexed_dataset.IndexedDatasetBuilder( dataset_dest_file(args, output_prefix, 'article_next', "bin")) truncated_number = 512 output_length = 256 CLS_TOKEN = '<s>' SEP_TOKEN = '</s>' for line in open(input_file, encoding='utf8'): sents = line.strip().split('<S_SEP>') sents = [ tokenizer.tokenize(penn_token2orig_token(sent)) for sent in sents ] article_toks = [] for i, sent in enumerate(sents): if i != 0: article_toks.append(SEP_TOKEN) article_toks.extend(sent) article_segments = [] output_segments = [] tmp_seg = [] for i, tok in enumerate(article_toks): if len(tmp_seg) == 0: tmp_seg.append(CLS_TOKEN) tmp_seg.append(tok) if tok == SEP_TOKEN: tmp_seg.append(tok) if len(tmp_seg) >= truncated_number: tmp_seg = tmp_seg[:truncated_number] if tmp_seg[-1] != SEP_TOKEN: tmp_seg[-1] = SEP_TOKEN tmp_output = article_toks[ i + 1:min(i + 1 + output_length, len(article_toks))] if len(tmp_output) < 0.3 * output_length: break article_segments.append( tokenizer.convert_tokens_to_ids(tmp_seg)) output_segments.append( tokenizer.convert_tokens_to_ids(tmp_output)) tmp_seg = [] assert len(article_segments) == len(output_segments) for i in range(len(article_segments)): assert len(article_segments[i]) <= truncated_number assert len(output_segments[i]) <= output_length and len( output_segments[i]) >= 0.3 * output_length tensor = torch.IntTensor(article_segments[i]) ds.add_item(tensor) output_tensor = torch.IntTensor(output_segments[i]) output_ds.add_item(output_tensor) ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) output_ds.finalize( dataset_dest_file(args, output_prefix, 'article_next', "idx")) print('done!') # print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format( # lang, input_file, num_seq, num_token, # 100 * num_unk_token / num_token, dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>')) def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): if args.output_format == "binary": make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers) elif args.output_format == "raw": # Copy original text file to destination folder output_text_file = dest_path( output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), lang, ) shutil.copyfile(file_name(input_prefix, lang), output_text_file) def make_all(lang, vocab): if args.trainpref: print(args.trainpref, lang) make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers) if args.validpref: for k, validpref in enumerate(args.validpref.split(",")): outprefix = "valid{}".format(k) if k > 0 else "valid" make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers) # if args.testpref: # for k, testpref in enumerate(args.testpref.split(",")): # outprefix = "test{}".format(k) if k > 0 else "test" # make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers) make_all(args.source_lang, src_dict) # if target: # 
make_all(args.target_lang, tgt_dict) print("| Wrote preprocessed data to {}".format(args.destdir)) if args.alignfile: assert args.trainpref, "--trainpref must be set if --alignfile is specified" src_file_name = train_path(args.source_lang) tgt_file_name = train_path(args.target_lang) freq_map = {} with open(args.alignfile, "r", encoding='utf-8') as align_file: with open(src_file_name, "r", encoding='utf-8') as src_file: with open(tgt_file_name, "r", encoding='utf-8') as tgt_file: for a, s, t in zip_longest(align_file, src_file, tgt_file): si = src_dict.encode_line(s, add_if_not_exist=False) ti = tgt_dict.encode_line(t, add_if_not_exist=False) ai = list(map(lambda x: tuple(x.split("-")), a.split())) for sai, tai in ai: srcidx = si[int(sai)] tgtidx = ti[int(tai)] if srcidx != src_dict.unk( ) and tgtidx != tgt_dict.unk(): assert srcidx != src_dict.pad() assert srcidx != src_dict.eos() assert tgtidx != tgt_dict.pad() assert tgtidx != tgt_dict.eos() if srcidx not in freq_map: freq_map[srcidx] = {} if tgtidx not in freq_map[srcidx]: freq_map[srcidx][tgtidx] = 1 else: freq_map[srcidx][tgtidx] += 1 align_dict = {} for srcidx in freq_map.keys(): align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) with open(os.path.join( args.destdir, "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), ), "w", encoding='utf-8') as f: for k, v in align_dict.items(): print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
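# Illustrative sketch (not part of the original script): the segment pairing
# performed by make_binary_dataset() above, simplified. A long article token-id
# sequence is cut into input windows of at most `truncated_number` (512) ids,
# and each window is paired with the next `output_length` (256) ids as its
# "article_next" continuation; pairs whose continuation is shorter than
# 0.3 * output_length are dropped. The CLS/SEP bookkeeping of the real code is
# omitted, and `toks` is a hypothetical list of token ids.
def _example_segment_pairs(toks, truncated_number=512, output_length=256):
    pairs = []
    for start in range(0, len(toks), truncated_number):
        seg = toks[start:start + truncated_number]
        nxt = toks[start + len(seg):start + len(seg) + output_length]
        if len(nxt) < 0.3 * output_length:
            break
        pairs.append((seg, nxt))
    return pairs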
def main(args): from fairseq import utils utils.xpprint(args) import_user_module(args) print(args) os.makedirs(args.destdir, exist_ok=True) target = not args.only_source task = tasks.get_task(args.task) def train_path(lang): return "{}{}".format(args.trainpref, ("." + lang) if lang else "") def file_name(prefix, lang): fname = prefix if lang is not None: fname += ".{lang}".format(lang=lang) return fname def dest_path(prefix, lang): return os.path.join(args.destdir, file_name(prefix, lang)) def dict_path(lang): return dest_path("dict", lang) + ".txt" def build_dictionary(filenames, src=False, tgt=False): assert src ^ tgt return task.build_dictionary( filenames, workers=args.workers, threshold=args.thresholdsrc if src else args.thresholdtgt, nwords=args.nwordssrc if src else args.nwordstgt, padding_factor=args.padding_factor, ) if not args.srcdict and os.path.exists(dict_path(args.source_lang)): raise FileExistsError(dict_path(args.source_lang)) if target and not args.tgtdict and os.path.exists( dict_path(args.target_lang)): raise FileExistsError(dict_path(args.target_lang)) if args.joined_dictionary: assert not args.srcdict or not args.tgtdict, \ "cannot use both --srcdict and --tgtdict with --joined-dictionary" if args.srcdict: src_dict = bert_dictionary.BertDictionary.load(args.srcdict) elif args.tgtdict: src_dict = bert_dictionary.BertDictionary.load(args.srcdict) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary( { train_path(lang) for lang in [args.source_lang, args.target_lang] }, src=True) tgt_dict = src_dict else: if args.srcdict: src_dict = bert_dictionary.BertDictionary.load(args.srcdict) print('load bert dict from {} | size {}'.format( args.srcdict, len(src_dict))) else: assert args.trainpref, "--trainpref must be set if --srcdict is not specified" src_dict = build_dictionary([train_path(args.source_lang)], src=True) if target: if args.tgtdict: tgt_dict = bert_dictionary.BertDictionary.load(args.tgtdict) else: assert args.trainpref, "--trainpref must be set if --tgtdict is not specified" tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True) else: tgt_dict = None src_dict.save(dict_path(args.source_lang)) if target and tgt_dict is not None: tgt_dict.save(dict_path(args.target_lang)) def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers): print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1)) print('input_prefix', input_prefix) print(dict_path(lang)) dict = bert_dictionary.BertDictionary.load(dict_path(lang)) input_file = "{}{}".format(input_prefix, ("." 
+ lang) if lang is not None else "") from pytorch_transformers import BertTokenizer import torch tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') def penn_token2orig_token(sent): # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB- penn2orig = { "``": '"', "''": '"', "-LRB-": '(', "-RRB-": ')', "-LSB-": '[', "-RSB-": ']', "-LCB-": '{', "-RCB-": '}' } words = sent.strip().split() words = [ wd if not wd in penn2orig else penn2orig[wd] for wd in words ] return ' '.join(words) num_token, num_unk_token = 0, 0 num_seq = 0 ds = indexed_dataset.IndexedDatasetBuilder( dataset_dest_file(args, output_prefix, lang, "bin")) output_ds = indexed_dataset.IndexedDatasetBuilder( dataset_dest_file(args, output_prefix, 'article_next', "bin")) article_input = 511 article_next = 256 BERT_CLS_ID = tokenizer.convert_tokens_to_ids([BERT_CLS])[0] BERT_SEP_ID = tokenizer.convert_tokens_to_ids([BERT_SEP])[0] for line in open(input_file, encoding='utf8'): sents = line.strip().split('<S_SEP>') sents = [ tokenizer.tokenize(penn_token2orig_token(sent)) for sent in sents ] article_wids = [] for i, sent in enumerate(sents): if i != 0: article_wids.append(dict.sep_index) if len(sent) > article_input: wids = [] temp_sent = [ sent[x:x + article_input] for x in range(0, len(sent), article_input) ] for se in temp_sent: se_ids = tokenizer.convert_tokens_to_ids(se) wids.extend(se_ids) else: wids = tokenizer.convert_tokens_to_ids(sent) # wids_vocab = [dict.index(word) for word in sent] # assert wids == wids_vocab, 'word indices should be the same!' article_wids.extend(wids) for wid in wids: if wid == dict.unk_index: num_unk_token += 1 num_token += 1 article_segments = [ article_wids[x:x + article_input] for x in range(0, len(article_wids), article_input) ] cur_position = 0 for i in range(len(article_segments)): article_seq = article_segments[i] cur_position += len(article_seq) output_seg = article_wids[ cur_position:min(len(article_wids), cur_position + article_next)] if len(output_seg) < 0.3 * article_next: continue num_seq += 1 if len(article_seq) > article_input: print('lang: %s, token len: %d, truncated len: %d' % (lang, len(article_seq), article_input)) if lang == 'article': if article_seq[-1] != BERT_SEP_ID: if article_seq[-2] != BERT_SEP_ID: article_seq[-1] = BERT_SEP_ID article_seq = [BERT_CLS_ID] + article_seq if len(output_seg) > article_next: print( 'lang: article_next, token len: %d, truncated len: %d' % (len(output_seg), article_next)) tensor = torch.IntTensor(article_seq) ds.add_item(tensor) output_tensor = torch.IntTensor(output_seg) output_ds.add_item(output_tensor) ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) output_ds.finalize( dataset_dest_file(args, output_prefix, 'article_next', "idx")) print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format( lang, input_file, num_seq, num_token, 100 * num_unk_token / num_token, dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>')) def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): if args.output_format == "binary": make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers) elif args.output_format == "raw": # Copy original text file to destination folder output_text_file = dest_path( output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), lang, ) shutil.copyfile(file_name(input_prefix, lang), output_text_file) def make_all(lang, vocab): if args.trainpref: print(args.trainpref, lang) make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers) if args.validpref: for 
k, validpref in enumerate(args.validpref.split(",")): outprefix = "valid{}".format(k) if k > 0 else "valid" make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers) if args.testpref: for k, testpref in enumerate(args.testpref.split(",")): outprefix = "test{}".format(k) if k > 0 else "test" make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers) make_all(args.source_lang, src_dict) # if target: # make_all(args.target_lang, tgt_dict) print("| Wrote preprocessed data to {}".format(args.destdir)) if args.alignfile: assert args.trainpref, "--trainpref must be set if --alignfile is specified" src_file_name = train_path(args.source_lang) tgt_file_name = train_path(args.target_lang) freq_map = {} with open(args.alignfile, "r", encoding='utf-8') as align_file: with open(src_file_name, "r", encoding='utf-8') as src_file: with open(tgt_file_name, "r", encoding='utf-8') as tgt_file: for a, s, t in zip_longest(align_file, src_file, tgt_file): si = src_dict.encode_line(s, add_if_not_exist=False) ti = tgt_dict.encode_line(t, add_if_not_exist=False) ai = list(map(lambda x: tuple(x.split("-")), a.split())) for sai, tai in ai: srcidx = si[int(sai)] tgtidx = ti[int(tai)] if srcidx != src_dict.unk( ) and tgtidx != tgt_dict.unk(): assert srcidx != src_dict.pad() assert srcidx != src_dict.eos() assert tgtidx != tgt_dict.pad() assert tgtidx != tgt_dict.eos() if srcidx not in freq_map: freq_map[srcidx] = {} if tgtidx not in freq_map[srcidx]: freq_map[srcidx][tgtidx] = 1 else: freq_map[srcidx][tgtidx] += 1 align_dict = {} for srcidx in freq_map.keys(): align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) with open(os.path.join( args.destdir, "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), ), "w", encoding='utf-8') as f: for k, v in align_dict.items(): print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
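# Illustrative sketch (not part of the original script): how the BERT article
# segments above are wrapped, simplified. Each window holds at most
# `article_input` (511) word-piece ids; '[CLS]' is prepended (giving 512 ids)
# and the window is forced to close on '[SEP]' unless it already does. The id
# values come from the same bert-base-uncased tokenizer loaded in
# make_binary_dataset(); `article_seq` is a hypothetical id list.
def _example_wrap_bert_segment(article_seq, article_input=511):
    from pytorch_transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    cls_id = tokenizer.convert_tokens_to_ids(['[CLS]'])[0]
    sep_id = tokenizer.convert_tokens_to_ids(['[SEP]'])[0]

    seg = list(article_seq)[:article_input]
    if seg and seg[-1] != sep_id and (len(seg) < 2 or seg[-2] != sep_id):
        seg[-1] = sep_id        # force the window to end on a sentence separator
    return [cls_id] + seg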