def ensemble_translate(FLAGS):
    GlobalNames.USE_GPU = FLAGS.use_gpu

    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        # an explicit Loader is required by PyYAML >= 5.1
        configs = yaml.load(f, Loader=yaml.FullLoader)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()
    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build source and target vocabularies
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation
    INFO('Building model...')
    timer.tic()

    nmt_models = []

    model_path = FLAGS.model_path

    for ii in range(len(model_path)):

        nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                                n_tgt_vocab=vocab_tgt.max_n_words,
                                **model_configs)
        nmt_model.eval()
        INFO('Done. Elapsed time {0}'.format(timer.toc()))

        INFO('Reloading model parameters...')
        timer.tic()

        params = load_model_parameters(model_path[ii], map_location="cpu")

        nmt_model.load_state_dict(params)

        if GlobalNames.USE_GPU:
            nmt_model.cuda()

        nmt_models.append(nmt_model)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')
    result_numbers = []
    result = []
    n_words = 0

    timer.tic()

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer)  ',
                              unit="sents")

    valid_iter = valid_iterator.build_generator()
    for batch in valid_iter:

        numbers, seqs_x = batch

        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=nmt_models,
                                            beam_size=FLAGS.beam_size,
                                            max_steps=FLAGS.max_steps,
                                            src_seqs=x,
                                            alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result
        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)

            n_words += len(sent_t[0])

        # keep the sentence numbers so the original ordering can be restored below
        result_numbers += numbers

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()

    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    # restore the original ordering
    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(
        FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')
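
# --- Usage sketch (an assumption, not part of the original example) --------------
# ensemble_translate reads everything from a single FLAGS object. The sketch below
# assumes FLAGS is an argparse-style namespace; all paths and values are placeholders.
def _demo_ensemble_translate():
    from argparse import Namespace

    demo_flags = Namespace(
        use_gpu=False,                                         # sets GlobalNames.USE_GPU
        config_path="configs/nmt.yml",                         # hypothetical config file
        source_path="data/test.src",                           # hypothetical source text, one sentence per line
        batch_size=32,
        model_path=["ckpt/model1.best", "ckpt/model2.best"],   # one checkpoint per ensemble member
        beam_size=5,
        max_steps=150,
        alpha=0.6,                                             # length-penalty weight for beam search
        keep_n=1,                                              # keep only the best beam per sentence
        saveto="output/trans",                                 # written to output/trans.0, output/trans.1, ...
    )
    ensemble_translate(demo_flags)
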
def load_or_extract_near_vocab(config_path,
                               model_path,
                               save_to,
                               save_to_full,
                               init_perturb_rate=0,
                               batch_size=50,
                               top_reserve=12,
                               all_with_UNK=False,
                               reload=True,
                               emit_as_id=False):
    """based on the embedding parameter from Encoder, extract near vocabulary for all words
    return: dictionary of vocabulary of near vocabs; and a the saved file
    :param config_path: (string) victim configs (for training data and vocabulary)
    :param model_path: (string) victim model path for trained embeddings
    :param save_to: (string) directory to store distilled near-vocab
    :param save_to_full: (string) directory to store full near-vocab
    :param init_perturb_rate: (float) the weight-adjustment for perturb
    :param batch_size: (integer) extract near vocab by batched cosine/Euclidean-similarity
    :param top_reserve: (integer) at most reserve top-k near candidates
    :param all_with_UNK: during generation, add UNK to all tokens as a candidate
    :param reload: reload from the save_to_path if previous record exists
    :param emit_as_id: (boolean) the key in return will be token ids instead of token
    """
    # load configs
    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)
    data_configs = configs["data_configs"]
    model_configs = configs["model_configs"]
    # load vocabulary file
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])

    # load embedding from model
    emb = nn.Embedding(num_embeddings=src_vocab.max_n_words,
                       embedding_dim=model_configs["d_word_vec"],
                       padding_idx=PAD
                       )
    model_params = torch.load(model_path, map_location="cpu")
    emb.load_state_dict({"weight": model_params["model"]["encoder.embeddings.embeddings.weight"]},
                        strict=True)
    len_mat = torch.sum(emb.weight**2, dim=1)**0.5  # length of the embeddings

    if os.path.exists(save_to) and reload:
        print("load from %s:" % save_to)
        return load_perturb_weight(save_to, src_vocab, emit_as_id)
    else:
        print("collect near candidates for vocabulary")
        avg_dist = 0
        avg_std = []
        counter = 0
        word2p = OrderedDict()
        word2near_vocab = OrderedDict()
        # omit similar vocabulary file (batched)
        with open(save_to, "w") as similar_vocab, open(save_to_full, "w") as full_similar_vocab:
            # every batched vocabulary collect average E-dist
            for i in range((src_vocab.max_n_words//batch_size)+1):
                if i*batch_size == src_vocab.max_n_words:
                    break

                index = torch.tensor(range(i*batch_size,
                                           min(src_vocab.max_n_words, (i+1)*batch_size),
                                           1))
                # extract embedding data
                slice_emb = emb(index)
                collect_len = torch.mm(len_mat.narrow(0, i * batch_size, min(src_vocab.max_n_words, (i+1)*batch_size)-i*batch_size).unsqueeze(1),
                                       len_mat.unsqueeze(0))
                # take the top-k nearest vocab by cosine similarity, then filter by Euclidean distance within the averaged range
                similarity = torch.mm(slice_emb,
                                      emb.weight.t()).div(collect_len)
                # get value and index
                topk_index = similarity.topk(top_reserve, dim=1)[1]

                sliceemb = slice_emb.unsqueeze(dim=1).repeat(1, top_reserve, 1)  # [batch_size, top_reserve, dim]
                E_dist = ((emb(topk_index)-sliceemb)**2).sum(dim=-1)**0.5
                # print("avg Euclidean distance:", E_dist)
                avg_dist += E_dist.mean()
                avg_std += [E_dist.std(dim=1).mean()]
                counter += 1
            avg_dist = avg_dist.item() / counter
            # print(avg_dist)  # tensor object
            # print(avg_std)

            # output near candidates to file and return dictionary
            for i in range((src_vocab.max_n_words//batch_size)+1):
                if i*batch_size == src_vocab.max_n_words:
                    break
                index = torch.tensor(range(i*batch_size,
                                           min(src_vocab.max_n_words, (i+1)*batch_size),
                                           1))
                # extract embedding data
                slice_emb = emb(index)
                collect_len = torch.mm(len_mat.narrow(0, i * batch_size, min(src_vocab.max_n_words, (i+1)*batch_size)-i*batch_size).unsqueeze(1),
                                       len_mat.unsqueeze(0))
                # filter top k nearest vocab with cosine-similarity
                similarity = torch.mm(slice_emb,
                                      emb.weight.t()).div(collect_len)
                topk_val, topk_indices = similarity.topk(top_reserve, dim=1)
                # calculate E-dist
                sliceemb = slice_emb.unsqueeze(dim=1).repeat(1, top_reserve, 1)  # [batch_size, top_reserve, dim]
                E_dist = ((emb(topk_indices)-sliceemb)**2).sum(dim=-1)**0.5

                topk_val = E_dist.cpu().detach().numpy()
                topk_indices = topk_indices.cpu().detach().numpy()
                for j in range(topk_val.shape[0]):
                    bingo = 0.
                    src_word_id = j + i*batch_size

                    src_word = src_vocab.id2token(src_word_id)
                    near_vocab = []

                    similar_vocab.write(src_word + "\t")
                    full_similar_vocab.write(src_word + "\t")

                    # there is no candidates for reserved tokens
                    if src_word_id in [PAD, EOS, BOS, UNK]:
                        near_cand_id = src_word_id
                        near_cand = src_vocab.id2token(near_cand_id)

                        full_similar_vocab.write(near_cand + "\t")
                        similar_vocab.write(near_cand + "\t")
                        bingo = 1
                        if emit_as_id:
                            near_vocab += [near_cand_id]
                        else:
                            near_vocab += [near_cand]
                    else:
                        # extract near candidates according to cos-dist within averaged E-dist
                        for k in range(1, topk_val.shape[1]):
                            near_cand_id = topk_indices[j][k]
                            near_cand = src_vocab.id2token(near_cand_id)
                            full_similar_vocab.write(near_cand + "\t")
                            if topk_val[j][k] < avg_dist and (near_cand_id not in [PAD, EOS, BOS]):
                                bingo += 1
                                similar_vocab.write(near_cand + "\t")
                                if emit_as_id:
                                    near_vocab += [near_cand_id]
                                else:
                                    near_vocab += [near_cand]
                        # additionally add UNK as candidates
                        if bingo == 0 or all_with_UNK:
                            last_cand_ids = [UNK]
                            for final_reserve_id in last_cand_ids:
                                last_cand = src_vocab.id2token(final_reserve_id)
                                similar_vocab.write(last_cand + "\t")
                                if emit_as_id:
                                    near_vocab += [final_reserve_id]
                                else:
                                    near_vocab += [last_cand]

                    probability = bingo/(len(src_word)*top_reserve)
                    if init_perturb_rate != 0:
                        probability *= init_perturb_rate
                    similar_vocab.write("\t"+str(probability)+"\n")
                    full_similar_vocab.write("\t"+str(probability)+"\n")
                    if emit_as_id:
                        word2near_vocab[src_word_id] = near_vocab
                        word2p[src_word_id] = probability
                    else:
                        word2near_vocab[src_word] = near_vocab
                        word2p[src_word] = probability
        return word2p, word2near_vocab
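
# --- Usage sketch (an assumption, not part of the original example) --------------
# load_or_extract_near_vocab returns two aligned dictionaries: word2p maps each
# source token (or id, with emit_as_id=True) to its perturbation probability, and
# word2near_vocab maps it to the list of near candidates. Paths are placeholders.
def _demo_near_vocab():
    w2p, w2vocab = load_or_extract_near_vocab(
        config_path="configs/nmt.yml",        # hypothetical victim config
        model_path="ckpt/model.best",         # hypothetical victim checkpoint
        save_to="near_vocab.distilled.txt",   # candidates within the averaged Euclidean distance
        save_to_full="near_vocab.full.txt",   # all top-k candidates before distance filtering
        top_reserve=12,
        reload=True,                          # reuse save_to if it already exists
    )
    # e.g. inspect the candidates collected for one token
    some_token = next(iter(w2vocab))
    print(some_token, w2p[some_token], w2vocab[some_token])
    return w2p, w2vocab
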
def initial_random_perturb(config_path,
                           inputs,
                           w2p, w2vocab,
                           mode="len_based",
                           key_type="token",
                           show_bleu=False):
    """
    batched random perturb, perturb is based on random probability from the collected candidates
    meant to test initial attack rate.
    :param config_path: victim configs
    :param inputs: raw batched input (list) sequences in [batch_size, seq_len]
    :param w2p: indicates how likely a word is perturbed
    :param w2vocab: near candidates
    :param mode: based on word2near_vocab, how to distribute likelihood among candidates
    :param key_type: inputs are given by raw sequences of tokens or tokenized labels
    :param show_bleu: whether to show bleu of perturbed seqs (compare to original seqs)
    :return: list of perturbed inputs and list of perturbed flags
    """
    np.random.seed(int(time.time()))
    assert mode in ["uniform", "len_based"], "mode must be 'uniform' or 'len_based'."
    assert key_type in ["token", "label"], "key_type must be 'token' or 'label'."
    # load configs
    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)
    data_configs = configs["data_configs"]

    # load vocabulary file and tokenize
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])
    perturbed_results = []
    flags = []
    for sent in inputs:
        if np.random.uniform() < 0.5:  # perturb the sentence
            perturbed_sent = []
            if key_type == "token":
                tokenized_sent = src_vocab.tokenizer.tokenize(sent)
                for word in tokenized_sent:
                    if np.random.uniform() < w2p[word]:
                        # need to perturb on lexical level
                        if mode == "uniform":
                            # uniform choose from candidates:
                            perturbed_sent += [w2vocab[word][np.random.choice(len(w2vocab[word]),
                                                                              1)[0]]]
                        elif mode == "len_based":
                            # weighted choose from candidates:
                            weights = [1./(1+abs(len(word)-len(c))) for c in w2vocab[word]]
                            norm_weights = [c/sum(weights) for c in weights]
                            perturbed_sent += [w2vocab[word][np.random.choice(len(w2vocab[word]),
                                                                              1,
                                                                              p=norm_weights
                                                                              )[0]]]
                    else:
                        perturbed_sent += [word]
                # print(perturbed_sent)  # yield same form of sequences of tokens
                perturbed_sent = src_vocab.tokenizer.detokenize(perturbed_sent)
            elif key_type == "label":  # tokenized labels
                for word_index in sent:
                    word = src_vocab.id2token(word_index)
                    if np.random.uniform() < w2p[word]:
                        if mode == "uniform":
                            # uniform choose from candidates:
                            perturbed_label = src_vocab.token2id(w2vocab[word][np.random.choice(
                                len(w2vocab[word]), 1
                            )[0]])
                            perturbed_sent += [perturbed_label]
                        elif mode == "len_based":
                            # weighted choose from candidates:
                            weights = [1. / (1 + abs(len(word) - len(c))) for c in w2vocab[word]]
                            norm_weights = [c / sum(weights) for c in weights]
                            perturbed_label = src_vocab.token2id(w2vocab[word][np.random.choice(len(w2vocab[word]),
                                                                                                1,
                                                                                                p=norm_weights
                                                                                                )[0]])
                            perturbed_sent += [perturbed_label]
                    else:
                        perturbed_sent += [word_index]
            perturbed_results += [perturbed_sent]
            flags += [1]
            # out.write(perturbed_sent + "\n")
        else:
            perturbed_results += [sent]
            flags += [0]
    return perturbed_results, flags
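
# --- Usage sketch with a worked weighting example (an assumption, not part of the
# original example). In "len_based" mode the candidate weights are
# 1 / (1 + |len(word) - len(candidate)|), then normalised: e.g. for word "cat"
# with candidates ["car", "cart"] the raw weights are [1.0, 0.5] and the sampling
# probabilities become [2/3, 1/3].
def _demo_initial_random_perturb(w2p, w2vocab):
    inputs = ["the cat sat on the mat",        # placeholder raw sentences
              "a quick brown fox"]
    perturbed, flags = initial_random_perturb(
        config_path="configs/nmt.yml",         # hypothetical victim config
        inputs=inputs,
        w2p=w2p,
        w2vocab=w2vocab,
        mode="len_based",
        key_type="token",                      # the inputs above are raw token sequences
    )
    for original, new, flag in zip(inputs, perturbed, flags):
        print(flag, original, "->", new)
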
def ensemble_inference(valid_iterator,
                       models,
                       vocab_tgt: Vocabulary,
                       batch_size,
                       max_steps,
                       beam_size=5,
                       alpha=-1.0,
                       rank=0,
                       world_size=1,
                       using_numbering_iterator=True):
    for model in models:
        model.eval()

    trans_in_all_beams = [[] for _ in range(beam_size)]

    # assert keep_n_beams <= beam_size

    if using_numbering_iterator:
        numbers = []

    if rank == 0:
        infer_progress_bar = tqdm(total=len(valid_iterator),
                                  desc=' - (Infer)  ',
                                  unit="sents")
    else:
        infer_progress_bar = None

    valid_iter = valid_iterator.build_generator(batch_size=batch_size)

    for batch in valid_iter:

        seq_numbers = batch[0]

        if using_numbering_iterator:
            numbers += seq_numbers

        seqs_x = batch[1]

        if infer_progress_bar is not None:
            infer_progress_bar.update(len(seqs_x) * world_size)

        x = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU)

        with torch.no_grad():
            word_ids = ensemble_beam_search(nmt_models=models,
                                            beam_size=beam_size,
                                            max_steps=max_steps,
                                            src_seqs=x,
                                            alpha=alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result
        for sent_t in word_ids:
            for ii, sent_ in enumerate(sent_t):
                sent_ = vocab_tgt.ids2sent(sent_)
                if sent_ == "":
                    sent_ = '%s' % vocab_tgt.id2token(vocab_tgt.eos)
                trans_in_all_beams[ii].append(sent_)

    if infer_progress_bar is not None:
        infer_progress_bar.close()

    if world_size > 1:
        if using_numbering_iterator:
            numbers = dist.all_gather_py_with_shared_fs(numbers)

        trans_in_all_beams = [
            combine_from_all_shards(trans) for trans in trans_in_all_beams
        ]

    if using_numbering_iterator:
        origin_order = np.argsort(numbers).tolist()
        trans_in_all_beams = [[trans[ii] for ii in origin_order]
                              for trans in trans_in_all_beams]

    return trans_in_all_beams
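
# --- Usage sketch (an assumption, not part of the original example) --------------
# ensemble_inference returns one list of detokenized translations per beam,
# already restored to the original sentence order when numbering is used.
def _demo_ensemble_inference(valid_iterator, models, vocab_tgt):
    trans_in_all_beams = ensemble_inference(
        valid_iterator=valid_iterator,
        models=models,                 # list of already-built NMT models
        vocab_tgt=vocab_tgt,
        batch_size=32,
        max_steps=150,
        beam_size=5,
        alpha=0.6,
    )
    best_beam = trans_in_all_beams[0]  # translations from the top-scoring beam
    return best_beam
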
def translate(FLAGS):
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:

        if hvd is None or distributed is None:
            ERROR("Distributed training is disable. Please check the installation of Horovod.")

        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()

        if GlobalNames.USE_GPU:
            torch.cuda.set_device(hvd.local_rank())
    else:
        world_size = 1
        rank = 0

    if rank != 0:
        close_logging()

    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()
    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Build source and target vocabularies
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)

    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True,
                                  world_size=world_size,
                                  rank=rank
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    nmt_model.eval()
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    params = load_model_parameters(FLAGS.model_path, map_location="cpu")

    nmt_model.load_state_dict(params, strict=False)

    if GlobalNames.USE_GPU:
        nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')

    result_numbers = []
    result = []
    n_words = 0

    timer.tic()

    if rank == 0:
        infer_progress_bar = tqdm(total=len(valid_iterator),
                                  desc=' - (Infer)  ',
                                  unit="sents")
    else:
        infer_progress_bar = None

    valid_iter = valid_iterator.build_generator()

    for batch in valid_iter:

        numbers, seqs_x = batch

        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = beam_search(nmt_model=nmt_model, beam_size=FLAGS.beam_size, max_steps=FLAGS.max_steps,
                                   src_seqs=x, alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        # Append result
        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)

            n_words += len(sent_t[0])

        result_numbers += numbers

        if rank == 0:
            infer_progress_bar.update(batch_size_t * world_size)

    if rank == 0:
        infer_progress_bar.close()

    if FLAGS.multi_gpu:
        n_words = sum(distributed.all_gather(n_words))

    INFO('Done. Speed: {0:.2f} words/sec'.format(n_words / (timer.toc(return_seconds=True))))

    if FLAGS.multi_gpu:

        result_gathered = distributed.all_gather_with_shared_fs(result)

        result = []

        for lines in itertools.zip_longest(*result_gathered, fillvalue=None):
            for line in lines:
                if line is not None:
                    result.append(line)

        result_numbers_gathered = distributed.all_gather_with_shared_fs(result_numbers)

        result_numbers = []

        for numbers in itertools.zip_longest(*result_numbers_gathered, fillvalue=None):
            for num in numbers:
                if num is not None:
                    result_numbers.append(num)

    if rank == 0:
        translation = []
        for sent in result:
            samples = []
            for trans in sent:
                sample = []
                for w in trans:
                    if w == vocab_tgt.EOS:
                        break
                    sample.append(vocab_tgt.id2token(w))
                samples.append(vocab_tgt.tokenizer.detokenize(sample))
            translation.append(samples)

        # restore the original ordering
        origin_order = np.argsort(result_numbers).tolist()
        translation = [translation[ii] for ii in origin_order]

        keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(FLAGS.beam_size, FLAGS.keep_n)
        outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

        with batch_open(outputs, 'w') as handles:
            for trans in translation:
                for i in range(keep_n):
                    if i < len(trans):
                        handles[i].write('%s\n' % trans[i])
                    else:
                        handles[i].write('%s\n' % 'eos')
def interactive_FBS(FLAGS):
    patience = FLAGS.try_times
    GlobalNames.USE_GPU = FLAGS.use_gpu
    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        configs = yaml.load(f, Loader=yaml.FullLoader)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()
    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)
    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    valid_ref = []
    with open(FLAGS.ref_path) as f:
        for sent in f:
            valid_ref.append(vocab_tgt.sent2ids(sent))

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================================================================== #
    # Build Model & Sampler & Validation
    INFO('Building model...')
    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        critic = critic.cuda()

    timer.tic()
    fw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)

    #bw_nmt_model = None
    bw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    fw_nmt_model.eval()
    bw_nmt_model.eval()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    fw_params = load_model_parameters(FLAGS.fw_model_path, map_location="cpu")
    bw_params = load_model_parameters(FLAGS.bw_model_path, map_location="cpu")

    fw_nmt_model.load_state_dict(fw_params)
    bw_nmt_model.load_state_dict(bw_params)

    if GlobalNames.USE_GPU:
        fw_nmt_model.cuda()
        bw_nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Begin...')
    timer.tic()

    result_numbers = []
    result = []
    n_words = 0

    imt_numbers = []
    imt_result = []
    imt_n_words = 0
    imt_constrains = [[] for ii in range(FLAGS.imt_step)]

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer)',
                              unit='sents')

    valid_iter = valid_iterator.build_generator()
    for batch in valid_iter:
        batch_result = []
        batch_numbers = []
        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = beam_search(nmt_model=fw_nmt_model,
                                   beam_size=FLAGS.beam_size,
                                   max_steps=FLAGS.max_steps,
                                   src_seqs=x,
                                   alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        for sent_t in word_ids:
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            batch_result.append(sent_t[0])

            n_words += len(sent_t[0])

        result_numbers += numbers
        imt_numbers += numbers
        batch_numbers += numbers
        batch_ref = [valid_ref[ii] for ii in batch_numbers]

        last_sents = copy.deepcopy(batch_result)
        constrains = [[[] for ii in range(patience)]
                      for jj in range(batch_size_t)]
        positions = [[[] for ii in range(patience)]
                     for jj in range(batch_size_t)]
        for idx in range(FLAGS.imt_step):
            cons, pos = sample_constrains(last_sents, batch_ref, patience)

            for ii in range(batch_size_t):
                for jj in range(patience):
                    constrains[ii][jj].append(cons[ii][jj])
                    positions[ii][jj].append(pos[ii][jj])

            #print(positions)
            imt_constrains[idx].append([vocab_tgt.ids2sent(c) for c in cons])
            bidirection = bool(FLAGS.bidirection)
            with torch.no_grad():
                constrained_word_ids, positions = fixwords_beam_search(
                    fw_nmt_model=fw_nmt_model,
                    bw_nmt_model=bw_nmt_model,
                    beam_size=FLAGS.beam_size,
                    max_steps=FLAGS.max_steps,
                    src_seqs=x,
                    alpha=FLAGS.alpha,
                    constrains=constrains,
                    positions=positions,
                    last_sentences=last_sents,
                    imt_step=idx + 1,
                    bidirection=bidirection)
            constrained_word_ids = constrained_word_ids.cpu().numpy().tolist()
            last_sents = []
            for i, sent_t in enumerate(constrained_word_ids):
                sent_t = [[wid for wid in line if wid != PAD]
                          for line in sent_t]
                if idx == FLAGS.imt_step - 1:
                    imt_result.append(copy.deepcopy(sent_t))
                    imt_n_words += len(sent_t[0])
                samples = []
                for trans in sent_t:
                    sample = []
                    for w in trans:
                        if w == vocab_tgt.EOS:
                            break
                        sample.append(w)
                    samples.append(sample)

                sent_t = []
                for ii in range(len(samples)):
                    if ii % FLAGS.beam_size == 0:
                        sent_t.append(samples[ii])
                BLEU = []
                for sample in sent_t:
                    bleu, _ = bleuScore(sample, batch_ref[i])
                    BLEU.append(bleu)

                # print("BLEU: ", BLEU)
                order = np.argsort(BLEU).tolist()
                order = order[::-1]
                # print("order: ", order)
                sent_t = [sent_t[ii] for ii in order]

                last_sents.append(sent_t[0])

            if FLAGS.online_learning and idx == FLAGS.imt_step - 1:
                seqs_y = []
                for sent in last_sents:
                    sent = [BOS] + sent
                    seqs_y.append(sent)
                compute_forward(fw_nmt_model, critic, x,
                                torch.Tensor(seqs_y).long().cuda())
                seqs_y = [sent[::-1] for sent in seqs_y]
                for ii in range(len(seqs_y)):
                    seqs_y[ii][0] = BOS
                    seqs_y[ii][-1] = EOS
                compute_forward(bw_nmt_model, critic, x,
                                torch.Tensor(seqs_y).long().cuda())

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()
    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(
        FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')

    imt_translation = []
    for sent in imt_result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(w)
            samples.append(sample)
        imt_translation.append(samples)

    origin_order = np.argsort(imt_numbers).tolist()
    imt_translation = [imt_translation[ii] for ii in origin_order]
    for idx in range(FLAGS.imt_step):
        imt_constrains[idx] = [
            ' '.join(imt_constrains[idx][ii]) + '\n' for ii in origin_order
        ]

        with open('%s.cons%d' % (FLAGS.saveto, idx), 'w') as f:
            f.writelines(imt_constrains[idx])

    bleu_translation = []
    for idx, sent in enumerate(imt_translation):
        samples = []
        for ii in range(len(sent)):
            if ii % FLAGS.beam_size == 0:
                samples.append(sent[ii])
        BLEU = []
        for sample in samples:
            bleu, _ = bleuScore(sample, valid_ref[idx])
            BLEU.append(bleu)

        #print("BLEU: ", BLEU)
        order = np.argsort(BLEU).tolist()
        order = order[::-1]
        #print("order: ", order)
        samples = [vocab_tgt.ids2sent(samples[ii]) for ii in order]
        bleu_translation.append(samples)

    #keep_n = FLAGS.beam_size*patience if FLAGS.keep_n <= 0 else min(FLAGS.beam_size*patience, FLAGS.keep_n)
    keep_n = patience
    outputs = ['%s.imt%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in bleu_translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')
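
# --- Usage sketch (an assumption, not part of the original example) --------------
# interactive_FBS is driven entirely by FLAGS; the fields below mirror the ones
# read above, with placeholder values.
def _demo_interactive_FBS():
    from argparse import Namespace

    demo_flags = Namespace(
        use_gpu=False,
        config_path="configs/nmt.yml",        # hypothetical config file
        source_path="data/test.src",
        ref_path="data/test.ref",             # references used to sample constraints
        fw_model_path="ckpt/fw_model.best",   # forward (left-to-right) checkpoint
        bw_model_path="ckpt/bw_model.best",   # backward (right-to-left) checkpoint
        batch_size=16,
        beam_size=5,
        max_steps=150,
        alpha=0.6,
        keep_n=1,
        try_times=3,                          # "patience": constraint samples per sentence
        imt_step=2,                           # rounds of constrained re-decoding
        bidirection=True,
        online_learning=False,
        saveto="output/imt",
    )
    interactive_FBS(demo_flags)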