# Standard / third-party imports used throughout this module. Repo-internal
# helpers (Vocabulary, TextLineDataset, DataIterator, Timer, INFO, ERROR,
# GlobalNames, Constants, build_model, beam_search, ensemble_beam_search,
# fixwords_beam_search, prepare_data, load_model_parameters, batch_open,
# PAD/BOS/EOS/UNK, ...) are assumed to come from the surrounding package.
import copy
import itertools
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm


def ensemble_translate(FLAGS):
    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # PyYAML >= 5.1 requires an explicit Loader; safe_load suffices for configs.
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate source and target vocabularies
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation

    INFO('Building model...')
    timer.tic()

    nmt_models = []
    model_path = FLAGS.model_path  # a list: one checkpoint per ensemble member

    for ii in range(len(model_path)):
        nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                                n_tgt_vocab=vocab_tgt.max_n_words,
                                **model_configs)
        nmt_model.eval()
        INFO('Done. Elapsed time {0}'.format(timer.toc()))

        INFO('Reloading model parameters...')
        timer.tic()

        params = load_model_parameters(model_path[ii], map_location="cpu")
        nmt_model.load_state_dict(params)

        if GlobalNames.USE_GPU:
            nmt_model.cuda()

        nmt_models.append(nmt_model)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')

    result_numbers = []
    result = []
    n_words = 0

    timer.tic()

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer) ',
                              unit="sents")

    valid_iter = valid_iterator.build_generator()

    for batch in valid_iter:
        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=nmt_models,
                                            beam_size=FLAGS.beam_size,
                                            max_steps=FLAGS.max_steps,
                                            src_seqs=x,
                                            alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result
        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            n_words += len(sent_t[0])

        # Bug fix: collect the sentence numbering so the original order can be
        # restored below (result_numbers was never populated before).
        result_numbers += numbers

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()

    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    # restore the original input ordering
    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')
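# Usage sketch (illustrative, not part of the original CLI). ensemble_translate()
# only reads attributes off FLAGS, so any namespace exposing the fields below
# works; every path and hyper-parameter value here is a placeholder.
def _example_ensemble_translate():
    from argparse import Namespace

    example_flags = Namespace(
        use_gpu=False,
        config_path="./configs/nmt.yaml",    # victim config (placeholder path)
        source_path="./data/test.src",       # one source sentence per line
        batch_size=32,
        model_path=["./ckpt/model1.ckpt",    # note: a *list* of checkpoints,
                    "./ckpt/model2.ckpt"],   # one per ensemble member
        beam_size=5,
        max_steps=150,
        alpha=-1.0,
        keep_n=1,                            # beams to write; <=0 keeps all beam_size
        saveto="./output/trans",             # outputs become trans.0, trans.1, ...
    )
    ensemble_translate(example_flags)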
def load_or_extract_near_vocab(config_path, model_path, save_to, save_to_full,
                               init_perturb_rate=0, batch_size=50, top_reserve=12,
                               all_with_UNK=False, reload=True, emit_as_id=False):
    """Based on the encoder's embedding parameters, extract a near vocabulary for every word.

    Returns a dictionary of perturbation probabilities and a dictionary of near
    candidates; both are also written to `save_to` / `save_to_full`.

    :param config_path: (string) victim configs (for training data and vocabulary)
    :param model_path: (string) victim model path holding the trained embeddings
    :param save_to: (string) path to store the distilled near-vocab
    :param save_to_full: (string) path to store the full near-vocab
    :param init_perturb_rate: (float) weight adjustment for the perturbation probability
    :param batch_size: (integer) extract near vocab with batched cosine/Euclidean similarity
    :param top_reserve: (integer) reserve at most the top-k near candidates
    :param all_with_UNK: (boolean) add UNK as a candidate for every token
    :param reload: (boolean) reload from save_to if a previous record exists
    :param emit_as_id: (boolean) keys of the returned dictionaries are token ids instead of tokens
    """
    # load configs
    with open(config_path.strip()) as f:
        configs = yaml.safe_load(f)
    data_configs = configs["data_configs"]
    model_configs = configs["model_configs"]

    # load vocabulary file
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])

    # load the embedding matrix from the victim model
    emb = nn.Embedding(num_embeddings=src_vocab.max_n_words,
                       embedding_dim=model_configs["d_word_vec"],
                       padding_idx=PAD)
    model_params = torch.load(model_path, map_location="cpu")
    emb.load_state_dict({"weight": model_params["model"]["encoder.embeddings.embeddings.weight"]},
                        strict=True)
    len_mat = torch.sum(emb.weight ** 2, dim=1) ** 0.5  # L2 norm of each embedding

    if os.path.exists(save_to) and reload:
        print("load from %s:" % save_to)
        return load_perturb_weight(save_to, src_vocab, emit_as_id)
    else:
        print("collect near candidates for vocabulary")
        avg_dist = 0
        avg_std = []
        counter = 0
        word2p = OrderedDict()
        word2near_vocab = OrderedDict()

        # emit similar-vocabulary files (batched)
        with open(save_to, "w") as similar_vocab, open(save_to_full, "w") as full_similar_vocab:
            # first pass: collect the average Euclidean distance over batched vocabulary
            for i in range((src_vocab.max_n_words // batch_size) + 1):
                if i * batch_size == src_vocab.max_n_words:
                    break
                index = torch.tensor(range(i * batch_size,
                                           min(src_vocab.max_n_words, (i + 1) * batch_size),
                                           1))
                # extract embedding data
                slice_emb = emb(index)
                collect_len = torch.mm(
                    len_mat.narrow(0, i * batch_size,
                                   min(src_vocab.max_n_words, (i + 1) * batch_size) - i * batch_size).unsqueeze(1),
                    len_mat.unsqueeze(0))
                # take the top-k nearest vocab by cosine similarity, then
                # measure the Euclidean distance to those candidates
                similarity = torch.mm(slice_emb, emb.weight.t()).div(collect_len)
                topk_index = similarity.topk(top_reserve, dim=1)[1]
                sliceemb = slice_emb.unsqueeze(dim=1).repeat(1, top_reserve, 1)  # [batch_size, top_reserve, dim]
                E_dist = ((emb(topk_index) - sliceemb) ** 2).sum(dim=-1) ** 0.5
                avg_dist += E_dist.mean()
                avg_std += [E_dist.std(dim=1).mean()]
                counter += 1
            avg_dist = avg_dist.item() / counter

            # second pass: emit near candidates to file and build the returned dictionaries
            for i in range((src_vocab.max_n_words // batch_size) + 1):
                if i * batch_size == src_vocab.max_n_words:
                    break
                index = torch.tensor(range(i * batch_size,
                                           min(src_vocab.max_n_words, (i + 1) * batch_size),
                                           1))
                # extract embedding data
                slice_emb = emb(index)
                collect_len = torch.mm(
                    len_mat.narrow(0, i * batch_size,
                                   min(src_vocab.max_n_words, (i + 1) * batch_size) - i * batch_size).unsqueeze(1),
                    len_mat.unsqueeze(0))
                # top-k nearest vocab by cosine similarity
                similarity = torch.mm(slice_emb, emb.weight.t()).div(collect_len)
                topk_val, topk_indices = similarity.topk(top_reserve, dim=1)

                # Euclidean distance to the top-k candidates
                sliceemb = slice_emb.unsqueeze(dim=1).repeat(1, top_reserve, 1)  # [batch_size, top_reserve, dim]
                E_dist = ((emb(topk_indices) - sliceemb) ** 2).sum(dim=-1) ** 0.5
                topk_val = E_dist.cpu().detach().numpy()
                topk_indices = topk_indices.cpu().detach().numpy()

                for j in range(topk_val.shape[0]):
                    bingo = 0.
                    src_word_id = j + i * batch_size
                    src_word = src_vocab.id2token(src_word_id)
                    near_vocab = []
                    similar_vocab.write(src_word + "\t")
                    full_similar_vocab.write(src_word + "\t")
                    if src_word_id in [PAD, EOS, BOS, UNK]:
                        # reserved tokens have no candidate other than themselves
                        near_cand_id = src_word_id
                        near_cand = src_vocab.id2token(near_cand_id)
                        full_similar_vocab.write(near_cand + "\t")
                        similar_vocab.write(near_cand + "\t")
                        bingo = 1
                        if emit_as_id:
                            near_vocab += [near_cand_id]
                        else:
                            near_vocab += [near_cand]
                    else:
                        # keep candidates ranked by cosine similarity whose
                        # Euclidean distance falls below the corpus average
                        for k in range(1, topk_val.shape[1]):
                            near_cand_id = topk_indices[j][k]
                            near_cand = src_vocab.id2token(near_cand_id)
                            full_similar_vocab.write(near_cand + "\t")
                            if topk_val[j][k] < avg_dist and near_cand_id not in [PAD, EOS, BOS]:
                                bingo += 1
                                similar_vocab.write(near_cand + "\t")
                                if emit_as_id:
                                    near_vocab += [near_cand_id]
                                else:
                                    near_vocab += [near_cand]
                        # additionally add UNK as a candidate
                        if bingo == 0 or all_with_UNK:
                            last_cand_ids = [UNK]
                            for final_reserve_id in last_cand_ids:
                                last_cand = src_vocab.id2token(final_reserve_id)
                                similar_vocab.write(last_cand + "\t")
                                if emit_as_id:
                                    near_vocab += [final_reserve_id]
                                else:
                                    near_vocab += [last_cand]

                    probability = bingo / (len(src_word) * top_reserve)
                    if init_perturb_rate != 0:
                        probability *= init_perturb_rate
                    similar_vocab.write("\t" + str(probability) + "\n")
                    full_similar_vocab.write("\t" + str(probability) + "\n")
                    if emit_as_id:
                        word2near_vocab[src_word_id] = near_vocab
                        word2p[src_word_id] = probability
                    else:
                        word2near_vocab[src_word] = near_vocab
                        word2p[src_word] = probability

        return word2p, word2near_vocab
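# Illustrative sketch (not in the original repo): building the near-vocab cache
# and inspecting the two returned dictionaries. Paths are placeholders; the
# cache files are reloaded on the next call when reload=True.
def _example_near_vocab_extraction():
    w2p, w2vocab = load_or_extract_near_vocab(
        config_path="./configs/nmt.yaml",           # victim config (placeholder)
        model_path="./ckpt/victim.ckpt",            # trained victim checkpoint (placeholder)
        save_to="./cache/near_vocab.txt",           # distilled candidates
        save_to_full="./cache/near_vocab_full.txt",
        top_reserve=12,
        emit_as_id=False)                           # keys are tokens, not ids
    # w2p maps a source token to its perturbation probability; w2vocab maps the
    # same token to its near candidates (cosine top-k, filtered by the
    # vocabulary-average Euclidean distance).
    for token in list(w2vocab)[:3]:
        print(token, w2p[token], w2vocab[token])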
def initial_random_perturb(config_path, inputs, w2p, w2vocab,
                           mode="len_based", key_type="token", show_bleu=False):
    """Batched random perturbation: each sentence is perturbed with random
    probability using the collected near candidates. Meant to test the initial
    attack rate.

    :param config_path: victim configs
    :param inputs: raw batched input sequences (list) in [batch_size, seq_len]
    :param w2p: how likely each word is to be perturbed
    :param w2vocab: near candidates per word
    :param mode: how to distribute probability mass over the candidates in w2vocab
    :param key_type: whether inputs are raw token sequences or tokenized label ids
    :param show_bleu: whether to report BLEU of perturbed sequences against the originals
    :return: list of perturbed inputs and list of perturbed flags
    """
    np.random.seed(int(time.time()))
    assert mode in ["uniform", "len_based"], "mode must be 'uniform' or 'len_based'."
    assert key_type in ["token", "label"], "inputs key_type must be 'token' or 'label'."

    # load configs
    with open(config_path.strip()) as f:
        configs = yaml.safe_load(f)
    data_configs = configs["data_configs"]

    # load vocabulary file and tokenizer
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])

    perturbed_results = []
    flags = []
    for sent in inputs:
        if np.random.uniform() < 0.5:
            # perturb the sentence
            perturbed_sent = []
            if key_type == "token":
                tokenized_sent = src_vocab.tokenizer.tokenize(sent)
                for word in tokenized_sent:
                    if np.random.uniform() < w2p[word]:
                        # perturb on the lexical level
                        if mode == "uniform":
                            # choose uniformly among candidates
                            perturbed_sent += [w2vocab[word][np.random.choice(len(w2vocab[word]), 1)[0]]]
                        elif mode == "len_based":
                            # weight candidates by closeness in length
                            weights = [1. / (1 + abs(len(word) - len(c))) for c in w2vocab[word]]
                            norm_weights = [c / sum(weights) for c in weights]
                            perturbed_sent += [w2vocab[word][np.random.choice(len(w2vocab[word]), 1,
                                                                              p=norm_weights)[0]]]
                    else:
                        perturbed_sent += [word]
                # yield the same form: a detokenized sequence of tokens
                perturbed_sent = src_vocab.tokenizer.detokenize(perturbed_sent)
            elif key_type == "label":
                # inputs are tokenized label ids
                for word_index in sent:
                    word = src_vocab.id2token(word_index)
                    if np.random.uniform() < w2p[word]:
                        if mode == "uniform":
                            # choose uniformly among candidates
                            perturbed_label = src_vocab.token2id(
                                w2vocab[word][np.random.choice(len(w2vocab[word]), 1)[0]])
                            perturbed_sent += [perturbed_label]
                        elif mode == "len_based":
                            # weight candidates by closeness in length
                            weights = [1. / (1 + abs(len(word) - len(c))) for c in w2vocab[word]]
                            norm_weights = [c / sum(weights) for c in weights]
                            perturbed_label = src_vocab.token2id(
                                w2vocab[word][np.random.choice(len(w2vocab[word]), 1,
                                                               p=norm_weights)[0]])
                            perturbed_sent += [perturbed_label]
                    else:
                        perturbed_sent += [word_index]
            perturbed_results += [perturbed_sent]
            flags += [1]
        else:
            perturbed_results += [sent]
            flags += [0]

    return perturbed_results, flags
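# Illustrative only: the candidate weighting behind mode="len_based" above.
# Candidates whose length is close to the original word's receive more mass;
# the words below are made up.
def _example_len_based_weights():
    word = "house"
    candidates = ["home", "houses", "hut"]  # hypothetical near candidates
    weights = [1. / (1 + abs(len(word) - len(c))) for c in candidates]
    norm_weights = [w / sum(weights) for w in weights]
    print(norm_weights)  # [0.375, 0.375, 0.25]: "hut" differs most in length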
def ensemble_inference(valid_iterator,
                       models,
                       vocab_tgt: Vocabulary,
                       batch_size,
                       max_steps,
                       beam_size=5,
                       alpha=-1.0,
                       rank=0,
                       world_size=1,
                       using_numbering_iterator=True):
    for model in models:
        model.eval()

    trans_in_all_beams = [[] for _ in range(beam_size)]

    if using_numbering_iterator:
        numbers = []

    if rank == 0:
        infer_progress_bar = tqdm(total=len(valid_iterator),
                                  desc=' - (Infer) ',
                                  unit="sents")
    else:
        infer_progress_bar = None

    valid_iter = valid_iterator.build_generator(batch_size=batch_size)

    for batch in valid_iter:
        seq_numbers = batch[0]
        if using_numbering_iterator:
            numbers += seq_numbers

        seqs_x = batch[1]

        if infer_progress_bar is not None:
            infer_progress_bar.update(len(seqs_x) * world_size)

        x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=models,
                                            beam_size=beam_size,
                                            max_steps=max_steps,
                                            src_seqs=x,
                                            alpha=alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result
        for sent_t in word_ids:
            for ii, sent_ in enumerate(sent_t):
                sent_ = vocab_tgt.ids2sent(sent_)
                if sent_ == "":
                    sent_ = '%s' % vocab_tgt.id2token(vocab_tgt.eos)
                trans_in_all_beams[ii].append(sent_)

    if infer_progress_bar is not None:
        infer_progress_bar.close()

    if world_size > 1:
        if using_numbering_iterator:
            numbers = dist.all_gather_py_with_shared_fs(numbers)
        trans_in_all_beams = [combine_from_all_shards(trans) for trans in trans_in_all_beams]

    if using_numbering_iterator:
        # restore the original input ordering
        origin_order = np.argsort(numbers).tolist()
        trans_in_all_beams = [[trans[ii] for ii in origin_order] for trans in trans_in_all_beams]

    return trans_in_all_beams
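# Usage sketch (illustrative): a single-process call (rank=0, world_size=1).
# The iterator, models and target vocabulary are assumed to be built exactly
# as in ensemble_translate() above; batch size and decoding lengths are placeholders.
def _example_ensemble_inference(valid_iterator, nmt_models, vocab_tgt):
    trans_in_all_beams = ensemble_inference(valid_iterator=valid_iterator,
                                            models=nmt_models,
                                            vocab_tgt=vocab_tgt,
                                            batch_size=32,
                                            max_steps=150,
                                            beam_size=5)
    # trans_in_all_beams[k][n] is the k-th beam hypothesis for the n-th input
    # sentence, detokenized and restored to the input order.
    return trans_in_all_beams[0]  # best-beam translations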
def translate(FLAGS):
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:
        if hvd is None or distributed is None:
            ERROR("Distributed translation is disabled. Please check the installation of Horovod.")
        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        if GlobalNames.USE_GPU:
            torch.cuda.set_device(hvd.local_rank())
    else:
        world_size = 1
        rank = 0

    if rank != 0:
        close_logging()

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate source and target vocabularies
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True,
                                  world_size=world_size,
                                  rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation

    INFO('Building model...')
    timer.tic()

    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            **model_configs)
    nmt_model.eval()
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    params = load_model_parameters(FLAGS.model_path, map_location="cpu")
    nmt_model.load_state_dict(params, strict=False)

    if GlobalNames.USE_GPU:
        nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')

    result_numbers = []
    result = []
    n_words = 0

    timer.tic()

    if rank == 0:
        infer_progress_bar = tqdm(total=len(valid_iterator),
                                  desc=' - (Infer) ',
                                  unit="sents")
    else:
        infer_progress_bar = None

    valid_iter = valid_iterator.build_generator()

    for batch in valid_iter:
        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = beam_search(nmt_model=nmt_model,
                                   beam_size=FLAGS.beam_size,
                                   max_steps=FLAGS.max_steps,
                                   src_seqs=x,
                                   alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result
        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            n_words += len(sent_t[0])

        result_numbers += numbers

        if rank == 0:
            infer_progress_bar.update(batch_size_t * world_size)

    if rank == 0:
        infer_progress_bar.close()

    if FLAGS.multi_gpu:
        n_words = sum(distributed.all_gather(n_words))

    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    if FLAGS.multi_gpu:
        # interleave the shards gathered from all workers back into flat lists
        result_gathered = distributed.all_gather_with_shared_fs(result)
        result = []
        for lines in itertools.zip_longest(*result_gathered, fillvalue=None):
            for line in lines:
                if line is not None:
                    result.append(line)

        result_numbers_gathered = distributed.all_gather_with_shared_fs(result_numbers)
        result_numbers = []
        for numbers in itertools.zip_longest(*result_numbers_gathered, fillvalue=None):
            for num in numbers:
                if num is not None:
                    result_numbers.append(num)

    if rank == 0:
        translation = []
        for sent in result:
            samples = []
            for trans in sent:
                sample = []
                for w in trans:
                    if w == vocab_tgt.EOS:
                        break
                    sample.append(vocab_tgt.id2token(w))
                samples.append(vocab_tgt.tokenizer.detokenize(sample))
            translation.append(samples)

        # restore the original input ordering
        origin_order = np.argsort(result_numbers).tolist()
        translation = [translation[ii] for ii in origin_order]

        keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(FLAGS.beam_size, FLAGS.keep_n)
        outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

        with batch_open(outputs, 'w') as handles:
            for trans in translation:
                for i in range(keep_n):
                    if i < len(trans):
                        handles[i].write('%s\n' % trans[i])
                    else:
                        handles[i].write('%s\n' % 'eos')
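# Launch sketch (assumption: translate() is wired to an argparse entry point
# whose flag names mirror the FLAGS attributes read above; the script name and
# all values are placeholders). Horovod's standard launcher starts one process
# per GPU; translate() then shards the iterator by (rank, world_size) and
# regathers the results through the shared filesystem:
#
#   horovodrun -np 4 python translate.py --multi_gpu --use_gpu \
#       --config_path configs/nmt.yaml --model_path ckpt/victim.ckpt \
#       --source_path data/test.src --batch_size 32 --beam_size 5 \
#       --max_steps 150 --alpha 0.6 --keep_n 1 --saveto output/trans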
def interactive_FBS(FLAGS):
    patience = FLAGS.try_times
    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    valid_ref = []
    with open(FLAGS.ref_path) as f:
        for sent in f:
            valid_ref.append(vocab_tgt.sent2ids(sent))

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation

    INFO('Building model...')

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])
    INFO(critic)
    # Move the critic to GPU
    if GlobalNames.USE_GPU:
        critic = critic.cuda()

    timer.tic()
    fw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    bw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    fw_nmt_model.eval()
    bw_nmt_model.eval()
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    fw_params = load_model_parameters(FLAGS.fw_model_path, map_location="cpu")
    bw_params = load_model_parameters(FLAGS.bw_model_path, map_location="cpu")

    fw_nmt_model.load_state_dict(fw_params)
    bw_nmt_model.load_state_dict(bw_params)

    if GlobalNames.USE_GPU:
        fw_nmt_model.cuda()
        bw_nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')
    timer.tic()

    result_numbers = []
    result = []
    n_words = 0

    imt_numbers = []
    imt_result = []
    imt_n_words = 0
    imt_constrains = [[] for ii in range(FLAGS.imt_step)]

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer)',
                              unit='sents')

    valid_iter = valid_iterator.build_generator()

    for batch in valid_iter:
        batch_result = []
        batch_numbers = []

        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        # unconstrained forward pass
        with torch.no_grad():
            word_ids = beam_search(nmt_model=fw_nmt_model,
                                   beam_size=FLAGS.beam_size,
                                   max_steps=FLAGS.max_steps,
                                   src_seqs=x,
                                   alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            batch_result.append(sent_t[0])
            n_words += len(sent_t[0])

        result_numbers += numbers
        imt_numbers += numbers
        batch_numbers += numbers

        batch_ref = [valid_ref[ii] for ii in batch_numbers]
        last_sents = copy.deepcopy(batch_result)

        constrains = [[[] for ii in range(patience)] for jj in range(batch_size_t)]
        positions = [[[] for ii in range(patience)] for jj in range(batch_size_t)]

        # iterative constrained decoding: sample constraints from the
        # reference, then rerun the fix-words beam search
        for idx in range(FLAGS.imt_step):
            cons, pos = sample_constrains(last_sents, batch_ref, patience)
            for ii in range(batch_size_t):
                for jj in range(patience):
                    constrains[ii][jj].append(cons[ii][jj])
                    positions[ii][jj].append(pos[ii][jj])

            imt_constrains[idx].append([vocab_tgt.ids2sent(c) for c in cons])

            bidirection = bool(FLAGS.bidirection)

            with torch.no_grad():
                constrained_word_ids, positions = fixwords_beam_search(
                    fw_nmt_model=fw_nmt_model,
                    bw_nmt_model=bw_nmt_model,
                    beam_size=FLAGS.beam_size,
                    max_steps=FLAGS.max_steps,
                    src_seqs=x,
                    alpha=FLAGS.alpha,
                    constrains=constrains,
                    positions=positions,
                    last_sentences=last_sents,
                    imt_step=idx + 1,
                    bidirection=bidirection)

            constrained_word_ids = constrained_word_ids.cpu().numpy().tolist()

            last_sents = []
            for i, sent_t in enumerate(constrained_word_ids):
                sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
                if idx == FLAGS.imt_step - 1:
                    imt_result.append(copy.deepcopy(sent_t))
                    imt_n_words += len(sent_t[0])

                # cut each hypothesis at EOS
                samples = []
                for trans in sent_t:
                    sample = []
                    for w in trans:
                        if w == vocab_tgt.EOS:
                            break
                        sample.append(w)
                    samples.append(sample)

                # keep the first hypothesis of each beam group
                sent_t = []
                for ii in range(len(samples)):
                    if ii % FLAGS.beam_size == 0:
                        sent_t.append(samples[ii])

                # rerank by sentence-level BLEU against the reference
                BLEU = []
                for sample in sent_t:
                    bleu, _ = bleuScore(sample, batch_ref[i])
                    BLEU.append(bleu)
                order = np.argsort(BLEU).tolist()
                order = order[::-1]
                sent_t = [sent_t[ii] for ii in order]
                last_sents.append(sent_t[0])

            if FLAGS.online_learning and idx == FLAGS.imt_step - 1:
                # online update: forward model on the BOS-prefixed targets ...
                seqs_y = []
                for sent in last_sents:
                    sent = [BOS] + sent
                    seqs_y.append(sent)
                compute_forward(fw_nmt_model, critic, x, torch.Tensor(seqs_y).long().cuda())
                # ... and backward model on the reversed targets with
                # re-pinned boundary tokens
                seqs_y = [sent[::-1] for sent in seqs_y]
                for ii in range(len(seqs_y)):
                    seqs_y[ii][0] = BOS
                    seqs_y[ii][-1] = EOS
                compute_forward(bw_nmt_model, critic, x, torch.Tensor(seqs_y).long().cuda())

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()

    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    # restore the original input ordering
    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')

    imt_translation = []
    for sent in imt_result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(w)
            samples.append(sample)
        imt_translation.append(samples)

    origin_order = np.argsort(imt_numbers).tolist()
    imt_translation = [imt_translation[ii] for ii in origin_order]

    for idx in range(FLAGS.imt_step):
        imt_constrains[idx] = [' '.join(imt_constrains[idx][ii]) + '\n' for ii in origin_order]
        with open('%s.cons%d' % (FLAGS.saveto, idx), 'w') as f:
            f.writelines(imt_constrains[idx])

    bleu_translation = []
    for idx, sent in enumerate(imt_translation):
        # keep the first hypothesis of each beam group, then rerank by BLEU
        samples = []
        for ii in range(len(sent)):
            if ii % FLAGS.beam_size == 0:
                samples.append(sent[ii])
        BLEU = []
        for sample in samples:
            bleu, _ = bleuScore(sample, valid_ref[idx])
            BLEU.append(bleu)
        order = np.argsort(BLEU).tolist()
        order = order[::-1]
        samples = [vocab_tgt.ids2sent(samples[ii]) for ii in order]
        bleu_translation.append(samples)

    keep_n = patience
    outputs = ['%s.imt%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in bleu_translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')
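# Illustrative only: how the online-learning branch of interactive_FBS()
# prepares targets for the backward (right-to-left) model. The forward target
# [BOS, w1, w2, w3] is reversed and its boundary tokens re-pinned, so the last
# word of the reversal is overwritten by BOS and the trailing BOS becomes EOS.
# Token ids below are made up; the real BOS/EOS ids come from the vocabulary.
def _example_reverse_target_for_bw_model():
    bos_id, eos_id = 2, 3          # hypothetical ids
    seq_y = [bos_id, 11, 12, 13]   # forward-model target: BOS + word ids
    rev = seq_y[::-1]              # -> [13, 12, 11, 2]
    rev[0] = bos_id                # re-pin sentence start (overwrites 13)
    rev[-1] = eos_id               # re-pin sentence end
    print(rev)                     # [2, 12, 11, 3]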