def ensemble_inference(valid_iterator,
                       models,
                       vocab_tgt: Vocabulary,
                       batch_size,
                       max_steps,
                       beam_size=5,
                       alpha=-1.0,
                       rank=0,
                       world_size=1,
                       using_numbering_iterator=True):
    """Translate ``valid_iterator`` with an ensemble of models and return one
    list of sentences per beam, restored to the original input order."""
    for model in models:
        model.eval()

    trans_in_all_beams = [[] for _ in range(beam_size)]

    # assert keep_n_beams <= beam_size

    if using_numbering_iterator:
        numbers = []

    # Only rank 0 owns the progress bar in distributed runs.
    if rank == 0:
        infer_progress_bar = tqdm(total=len(valid_iterator),
                                  desc=' - (Infer) ',
                                  unit="sents")
    else:
        infer_progress_bar = None

    valid_iter = valid_iterator.build_generator(batch_size=batch_size)

    for batch in valid_iter:
        seq_numbers = batch[0]

        if using_numbering_iterator:
            numbers += seq_numbers

        seqs_x = batch[1]

        if infer_progress_bar is not None:
            infer_progress_bar.update(len(seqs_x) * world_size)

        x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=models,
                                            beam_size=beam_size,
                                            max_steps=max_steps,
                                            src_seqs=x,
                                            alpha=alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append one translation per beam; an empty hypothesis degenerates
        # to a bare EOS token.
        for sent_t in word_ids:
            for ii, sent_ in enumerate(sent_t):
                sent_ = vocab_tgt.ids2sent(sent_)
                if sent_ == "":
                    sent_ = '%s' % vocab_tgt.id2token(vocab_tgt.eos)
                trans_in_all_beams[ii].append(sent_)

    if infer_progress_bar is not None:
        infer_progress_bar.close()

    # Gather shards from all workers before restoring the original order.
    if world_size > 1:
        if using_numbering_iterator:
            numbers = dist.all_gather_py_with_shared_fs(numbers)
        trans_in_all_beams = [combine_from_all_shards(trans)
                              for trans in trans_in_all_beams]

    if using_numbering_iterator:
        origin_order = np.argsort(numbers).tolist()
        trans_in_all_beams = [[trans[ii] for ii in origin_order]
                              for trans in trans_in_all_beams]

    return trans_in_all_beams
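# --------------------------------------------------------------------------- #
# A minimal, hypothetical usage sketch for ensemble_inference(). It mirrors
# the model-loading pattern used in interactive_FBS() below; the config path,
# checkpoint list, and decoding hyper-parameters are placeholders, not values
# shipped with this repository.
def _ensemble_inference_demo(config_path, checkpoint_paths, source_path):
    with open(config_path) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    dataset = TextLineDataset(data_path=source_path, vocabulary=vocab_src)
    iterator = DataIterator(dataset=dataset,
                            batch_size=32,
                            use_bucket=True,
                            buffer_size=100000,
                            numbering=True)

    # One model instance per checkpoint; all share the same configs.
    models = []
    for path in checkpoint_paths:
        model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            **model_configs)
        model.load_state_dict(load_model_parameters(path, map_location="cpu"))
        if Constants.USE_GPU:
            model = model.cuda()
        models.append(model)

    return ensemble_inference(valid_iterator=iterator,
                              models=models,
                              vocab_tgt=vocab_tgt,
                              batch_size=32,
                              max_steps=150,
                              beam_size=5,
                              alpha=0.6)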
def interactive_FBS(FLAGS):
    patience = FLAGS.try_times

    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        # Pass an explicit Loader; bare yaml.load() is deprecated in PyYAML >= 5.1.
        configs = yaml.load(f, Loader=yaml.FullLoader)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()

    # ================================================================================ #
    # Load data

    INFO('Loading data...')
    timer.tic()

    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    valid_ref = []
    with open(FLAGS.ref_path) as f:
        for sent in f:
            valid_ref.append(vocab_tgt.sent2ids(sent))

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================ #
    # Build model & sampler & validation

    INFO('Building model...')

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])
    INFO(critic)

    # Move the criterion to GPU.
    if GlobalNames.USE_GPU:
        critic = critic.cuda()

    timer.tic()
    fw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    bw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    fw_nmt_model.eval()
    bw_nmt_model.eval()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    fw_params = load_model_parameters(FLAGS.fw_model_path, map_location="cpu")
    bw_params = load_model_parameters(FLAGS.bw_model_path, map_location="cpu")

    fw_nmt_model.load_state_dict(fw_params)
    bw_nmt_model.load_state_dict(bw_params)

    if GlobalNames.USE_GPU:
        fw_nmt_model.cuda()
        bw_nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))
    INFO('Begin...')
    timer.tic()

    result_numbers = []
    result = []
    n_words = 0

    imt_numbers = []
    imt_result = []
    imt_n_words = 0
    imt_constrains = [[] for ii in range(FLAGS.imt_step)]

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer)',
                              unit='sents')

    valid_iter = valid_iterator.build_generator()

    for batch in valid_iter:
        batch_result = []
        batch_numbers = []

        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        # Plain forward beam search produces the initial hypotheses.
        with torch.no_grad():
            word_ids = beam_search(nmt_model=fw_nmt_model,
                                   beam_size=FLAGS.beam_size,
                                   max_steps=FLAGS.max_steps,
                                   src_seqs=x,
                                   alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            batch_result.append(sent_t[0])
            n_words += len(sent_t[0])

        result_numbers += numbers
        imt_numbers += numbers
        batch_numbers += numbers

        batch_ref = [valid_ref[ii] for ii in batch_numbers]

        last_sents = copy.deepcopy(batch_result)
        constrains = [[[] for ii in range(patience)] for jj in range(batch_size_t)]
        positions = [[[] for ii in range(patience)] for jj in range(batch_size_t)]

        # Simulated interactive rounds: sample constraint words from the
        # reference, then re-decode with those words fixed in place.
        for idx in range(FLAGS.imt_step):
            cons, pos = sample_constrains(last_sents, batch_ref, patience)

            for ii in range(batch_size_t):
                for jj in range(patience):
                    constrains[ii][jj].append(cons[ii][jj])
                    positions[ii][jj].append(pos[ii][jj])

            imt_constrains[idx].append([vocab_tgt.ids2sent(c) for c in cons])

            bidirection = bool(FLAGS.bidirection)

            with torch.no_grad():
                constrained_word_ids, positions = fixwords_beam_search(
                    fw_nmt_model=fw_nmt_model,
                    bw_nmt_model=bw_nmt_model,
                    beam_size=FLAGS.beam_size,
                    max_steps=FLAGS.max_steps,
                    src_seqs=x,
                    alpha=FLAGS.alpha,
                    constrains=constrains,
                    positions=positions,
                    last_sentences=last_sents,
                    imt_step=idx + 1,
                    bidirection=bidirection)

            constrained_word_ids = constrained_word_ids.cpu().numpy().tolist()

            last_sents = []
            for i, sent_t in enumerate(constrained_word_ids):
                sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]

                if idx == FLAGS.imt_step - 1:
                    imt_result.append(copy.deepcopy(sent_t))
                    imt_n_words += len(sent_t[0])

                # Truncate each hypothesis at EOS.
                samples = []
                for trans in sent_t:
                    sample = []
                    for w in trans:
                        if w == vocab_tgt.EOS:
                            break
                        sample.append(w)
                    samples.append(sample)

                # Keep the top hypothesis of each beam group.
                sent_t = []
                for ii in range(len(samples)):
                    if ii % FLAGS.beam_size == 0:
                        sent_t.append(samples[ii])

                # Re-rank the survivors by sentence-level BLEU, best first,
                # and feed the winner into the next round.
                BLEU = []
                for sample in sent_t:
                    bleu, _ = bleuScore(sample, batch_ref[i])
                    BLEU.append(bleu)

                order = np.argsort(BLEU).tolist()
                order = order[::-1]

                sent_t = [sent_t[ii] for ii in order]
                last_sents.append(sent_t[0])

            # Optionally fine-tune both models on the final round's outputs.
            if FLAGS.online_learning and idx == FLAGS.imt_step - 1:
                seqs_y = []
                for sent in last_sents:
                    seqs_y.append([BOS] + sent)
                compute_forward(fw_nmt_model, critic, x,
                                torch.Tensor(seqs_y).long().cuda())

                # The backward model trains on reversed targets.
                seqs_y = [sent[::-1] for sent in seqs_y]
                for ii in range(len(seqs_y)):
                    seqs_y[ii][0] = BOS
                    seqs_y[ii][-1] = EOS
                compute_forward(bw_nmt_model, critic, x,
                                torch.Tensor(seqs_y).long().cuda())

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()

    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / timer.toc(return_seconds=True)))
    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    # Restore the original sentence order undone by bucketing.
    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')

    imt_translation = []
    for sent in imt_result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(w)
            samples.append(sample)
        imt_translation.append(samples)

    origin_order = np.argsort(imt_numbers).tolist()
    imt_translation = [imt_translation[ii] for ii in origin_order]

    for idx in range(FLAGS.imt_step):
        imt_constrains[idx] = [' '.join(imt_constrains[idx][ii]) + '\n'
                               for ii in origin_order]
        with open('%s.cons%d' % (FLAGS.saveto, idx), 'w') as f:
            f.writelines(imt_constrains[idx])

    bleu_translation = []
    for idx, sent in enumerate(imt_translation):
        samples = []
        for ii in range(len(sent)):
            if ii % FLAGS.beam_size == 0:
                samples.append(sent[ii])

        BLEU = []
        for sample in samples:
            bleu, _ = bleuScore(sample, valid_ref[idx])
            BLEU.append(bleu)

        order = np.argsort(BLEU).tolist()
        order = order[::-1]

        samples = [vocab_tgt.ids2sent(samples[ii]) for ii in order]
        bleu_translation.append(samples)

    # keep_n = FLAGS.beam_size * patience if FLAGS.keep_n <= 0 else min(FLAGS.beam_size * patience, FLAGS.keep_n)
    keep_n = patience
    outputs = ['%s.imt%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in bleu_translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')
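# --------------------------------------------------------------------------- #
# Hypothetical driver for interactive_FBS(). The flag names mirror the FLAGS
# attributes the function reads; the argparse wiring and default values are a
# sketch, not this repository's actual CLI.
def _build_fbs_arg_parser():
    import argparse

    parser = argparse.ArgumentParser(description="Interactive fix-word beam search")
    parser.add_argument("--config_path", required=True)
    parser.add_argument("--source_path", required=True)
    parser.add_argument("--ref_path", required=True)
    parser.add_argument("--fw_model_path", required=True)
    parser.add_argument("--bw_model_path", required=True)
    parser.add_argument("--saveto", required=True)
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument("--beam_size", type=int, default=5)
    parser.add_argument("--max_steps", type=int, default=150)
    parser.add_argument("--alpha", type=float, default=-1.0)
    parser.add_argument("--keep_n", type=int, default=-1)
    parser.add_argument("--imt_step", type=int, default=1)
    parser.add_argument("--try_times", type=int, default=1)
    parser.add_argument("--use_gpu", action="store_true")
    parser.add_argument("--bidirection", action="store_true")
    parser.add_argument("--online_learning", action="store_true")
    return parser


if __name__ == "__main__":
    interactive_FBS(_build_fbs_arg_parser().parse_args())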