def ensemble_translate(FLAGS):
    """Decode a source file with an ensemble of NMT models via beam search.

    Loads data/model configuration from ``FLAGS.config_path``, builds one
    model per checkpoint listed in ``FLAGS.model_path``, runs ensemble beam
    search over ``FLAGS.source_path`` and writes the top ``keep_n`` hypotheses
    per sentence to ``FLAGS.saveto``.0 .. ``FLAGS.saveto``.(keep_n - 1).

    Args:
        FLAGS: parsed command-line options; fields read here include
            use_gpu, config_path, source_path, batch_size, model_path (list
            of checkpoint paths), beam_size, max_steps, alpha, keep_n, saveto.
    """
    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        # FIX: yaml.load() without an explicit Loader is deprecated and unsafe
        # on untrusted input; FullLoader keeps full-YAML behavior safely.
        configs = yaml.load(f, Loader=yaml.FullLoader)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate source/target dictionaries.
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    # numbering=True makes the iterator yield (sentence_number, seqs_x) so the
    # bucketed (shuffled-by-length) output can be restored to corpus order.
    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation

    INFO('Building model...')
    timer.tic()

    nmt_models = []
    model_path = FLAGS.model_path

    for ii in range(len(model_path)):
        nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                                n_tgt_vocab=vocab_tgt.max_n_words,
                                **model_configs)
        nmt_model.eval()
        INFO('Done. Elapsed time {0}'.format(timer.toc()))

        INFO('Reloading model parameters...')
        timer.tic()

        params = load_model_parameters(model_path[ii], map_location="cpu")
        nmt_model.load_state_dict(params)

        if GlobalNames.USE_GPU:
            nmt_model.cuda()

        nmt_models.append(nmt_model)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')

    result_numbers = []
    result = []
    n_words = 0

    timer.tic()

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer) ',
                              unit="sents")

    valid_iter = valid_iterator.build_generator()

    for batch in valid_iter:
        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        # FIX: the sentence numbers were never collected, so the
        # np.argsort(result_numbers) below saw an empty list and the
        # re-ordered `translation` came out empty (nothing was written).
        result_numbers += numbers

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=nmt_models,
                                            beam_size=FLAGS.beam_size,
                                            max_steps=FLAGS.max_steps,
                                            src_seqs=x,
                                            alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result: strip padding from every beam hypothesis.
        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)

            # Count only the best (first) hypothesis for the speed report.
            n_words += len(sent_t[0])

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()

    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    # Convert token ids back to detokenized strings, truncating at EOS.
    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    # Resume the original corpus ordering (bucketing shuffled it).
    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    # keep_n <= 0 means "keep all beams".
    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(
        FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    # Pad missing beams so all output files stay line-aligned.
                    handles[i].write('%s\n' % 'eos')
def ensemble_inference(valid_iterator,
                       models,
                       vocab_tgt: Vocabulary,
                       batch_size,
                       max_steps,
                       beam_size=5,
                       alpha=-1.0,
                       rank=0,
                       world_size=1,
                       using_numbering_iterator=True):
    """Run ensemble beam-search decoding over ``valid_iterator``.

    Returns a list of ``beam_size`` lists of strings: entry ``k`` holds the
    k-th best translation for every source sentence. When
    ``using_numbering_iterator`` is True the sentences are restored to their
    original corpus order; when ``world_size`` > 1 the per-shard results are
    merged across distributed workers first.
    """
    for m in models:
        m.eval()

    hyps_per_beam = [[] for _ in range(beam_size)]

    # assert keep_n_beams <= beam_size
    if using_numbering_iterator:
        numbers = []

    # Only rank 0 draws a progress bar; other shards stay silent.
    progress = None
    if rank == 0:
        progress = tqdm(total=len(valid_iterator),
                        desc=' - (Infer) ',
                        unit="sents")

    for batch in valid_iterator.build_generator(batch_size=batch_size):
        seq_numbers, seqs_x = batch[0], batch[1]

        if using_numbering_iterator:
            numbers += seq_numbers

        if progress is not None:
            progress.update(len(seqs_x) * world_size)

        x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=models,
                                            beam_size=beam_size,
                                            max_steps=max_steps,
                                            src_seqs=x,
                                            alpha=alpha)

        # Bucket each hypothesis string under its beam rank.
        for sent_beams in word_ids.cpu().numpy().tolist():
            for beam_idx, ids in enumerate(sent_beams):
                text = vocab_tgt.ids2sent(ids)
                if text == "":
                    # Empty hypothesis: fall back to a lone EOS token.
                    text = '%s' % vocab_tgt.id2token(vocab_tgt.eos)
                hyps_per_beam[beam_idx].append(text)

    if progress is not None:
        progress.close()

    if world_size > 1:
        # Merge per-shard results from all distributed workers.
        if using_numbering_iterator:
            numbers = dist.all_gather_py_with_shared_fs(numbers)
        hyps_per_beam = [
            combine_from_all_shards(h) for h in hyps_per_beam
        ]

    if using_numbering_iterator:
        # Restore the original sentence order.
        order = np.argsort(numbers).tolist()
        hyps_per_beam = [[h[i] for i in order] for h in hyps_per_beam]

    return hyps_per_beam