Exemple #1
0
def ensemble_inference(valid_iterator,
                       models,
                       vocab_tgt: Vocabulary,
                       batch_size,
                       max_steps,
                       beam_size=5,
                       alpha=-1.0,
                       rank=0,
                       world_size=1,
                       using_numbering_iterator=True):
    """Decode every batch of ``valid_iterator`` with an ensemble of models.

    Args:
        valid_iterator: Object providing ``__len__`` and
            ``build_generator(batch_size=...)``. Each yielded batch is indexed
            as ``batch[0]`` (sequence numbers) and ``batch[1]`` (source
            sequences).
        models: Iterable of models; each one is switched to eval mode and the
            whole collection is handed to ``ensemble_beam_search``.
        vocab_tgt: Target vocabulary used to turn word ids back into text.
        batch_size: Batch size forwarded to ``build_generator``.
        max_steps: Forwarded to ``ensemble_beam_search``.
        beam_size: Beam width; also the number of output lists.
        alpha: Forwarded to ``ensemble_beam_search``.
        rank: Only rank 0 shows a progress bar.
        world_size: When greater than 1, results are gathered across shards;
            it also scales the progress-bar increments.
        using_numbering_iterator: When true, the collected sequence numbers
            are used to restore the original sentence order.

    Returns:
        A list of ``beam_size`` lists. The ``k``-th list holds, for every
        sentence, the text of the ``k``-th hypothesis returned by the beam
        search.
    """
    for nmt_model in models:
        nmt_model.eval()

    # One bucket per beam slot; bucket k collects the k-th hypothesis of
    # every sentence.
    beam_buckets = [[] for _ in range(beam_size)]

    # Only meaningful (and only read) when the iterator numbers its samples.
    numbers = [] if using_numbering_iterator else None

    progress = None
    if rank == 0:
        progress = tqdm(total=len(valid_iterator),
                        desc=' - (Infer)  ',
                        unit="sents")

    for batch in valid_iterator.build_generator(batch_size=batch_size):
        seq_numbers = batch[0]
        if using_numbering_iterator:
            numbers.extend(seq_numbers)

        seqs_x = batch[1]

        if progress is not None:
            # Every shard is assumed to advance by a same-sized batch, hence
            # the multiplication by world_size.
            progress.update(len(seqs_x) * world_size)

        src = prepare_data(seqs_x, seqs_y=None, cuda=Constants.USE_GPU)

        with torch.no_grad():
            hyp_ids = ensemble_beam_search(nmt_models=models,
                                           beam_size=beam_size,
                                           max_steps=max_steps,
                                           src_seqs=src,
                                           alpha=alpha)

        hyp_ids = hyp_ids.cpu().numpy().tolist()

        for sentence_beams in hyp_ids:
            for slot, ids in enumerate(sentence_beams):
                text = vocab_tgt.ids2sent(ids)
                # An empty hypothesis is replaced by the EOS token so that
                # every sentence still yields one entry per beam slot.
                beam_buckets[slot].append(
                    text if text != ""
                    else '%s' % vocab_tgt.id2token(vocab_tgt.eos))

    if progress is not None:
        progress.close()

    if world_size > 1:
        # Keep the order of these calls: numbers first, then one combine per
        # beam bucket.
        if using_numbering_iterator:
            numbers = dist.all_gather_py_with_shared_fs(numbers)

        beam_buckets = [combine_from_all_shards(bucket)
                        for bucket in beam_buckets]

    if using_numbering_iterator:
        # Undo any reordering done by the iterator.
        restore = np.argsort(numbers).tolist()
        beam_buckets = [[bucket[pos] for pos in restore]
                        for bucket in beam_buckets]

    return beam_buckets
Exemple #2
0
def interactive_FBS(FLAGS):
    """Decode a source file, then refine it with simulated interactive steps.

    The function first translates every source sentence with plain beam
    search using the forward model. For each batch it then runs
    ``FLAGS.imt_step`` rounds in which ``sample_constrains`` derives
    constraints from the current best hypothesis and the reference, and
    ``fixwords_beam_search`` re-decodes under those constraints.

    Attributes read from ``FLAGS``: ``try_times``, ``use_gpu``,
    ``config_path``, ``source_path``, ``ref_path``, ``batch_size``,
    ``fw_model_path``, ``bw_model_path``, ``imt_step``, ``beam_size``,
    ``max_steps``, ``alpha``, ``bidirection``, ``online_learning``,
    ``keep_n`` and ``saveto``.

    Files written:
        * ``<saveto>.<i>`` for ``i < keep_n``: the i-th beam of the plain
          beam-search translation.
        * ``<saveto>.cons<idx>`` for ``idx < imt_step``: the constraints
          sampled at each interactive step.
        * ``<saveto>.imt<i>`` for ``i < try_times``: constrained
          translations, sorted by BLEU against the reference (best first).

    Returns:
        None.
    """
    patience = FLAGS.try_times
    GlobalNames.USE_GPU = FLAGS.use_gpu
    config_path = os.path.abspath(FLAGS.config_path)

    # yaml.load() without an explicit Loader is rejected by PyYAML >= 6 and
    # can construct arbitrary objects on older releases; safe_load is enough
    # for a plain configuration mapping.
    with open(config_path.strip()) as f:
        configs = yaml.safe_load(f)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']

    timer = Timer()
    # ===================================================================================
    # Load data
    INFO('loading data...')
    timer.tic()

    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    valid_dataset = TextLineDataset(data_path=FLAGS.source_path,
                                    vocabulary=vocab_src)
    valid_iterator = DataIterator(dataset=valid_dataset,
                                  batch_size=FLAGS.batch_size,
                                  use_bucket=True,
                                  buffer_size=100000,
                                  numbering=True)

    # References as id sequences, indexed by line number of the ref file.
    valid_ref = []
    with open(FLAGS.ref_path) as f:
        for sent in f:
            valid_ref.append(vocab_tgt.sent2ids(sent))

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ===================================================================================
    # Build models and criterion
    INFO('Building model...')
    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)

    if GlobalNames.USE_GPU:
        critic = critic.cuda()

    timer.tic()
    fw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)

    bw_nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                               n_tgt_vocab=vocab_tgt.max_n_words,
                               **model_configs)
    fw_nmt_model.eval()
    bw_nmt_model.eval()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('Reloading model parameters...')
    timer.tic()

    fw_params = load_model_parameters(FLAGS.fw_model_path, map_location="cpu")
    bw_params = load_model_parameters(FLAGS.bw_model_path, map_location="cpu")

    fw_nmt_model.load_state_dict(fw_params)
    bw_nmt_model.load_state_dict(bw_params)

    if GlobalNames.USE_GPU:
        fw_nmt_model.cuda()
        bw_nmt_model.cuda()

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    INFO('begin...')
    timer.tic()

    # Plain beam-search results, in iteration order.
    result_numbers = []
    result = []
    n_words = 0

    # Constrained-decoding results, in iteration order.
    imt_numbers = []
    imt_result = []
    imt_n_words = 0
    imt_constrains = [[] for ii in range(FLAGS.imt_step)]

    # Loop-invariant: whether the constrained search is told to run
    # bidirectionally.
    bidirection = bool(FLAGS.bidirection)

    infer_progress_bar = tqdm(total=len(valid_iterator),
                              desc=' - (Infer)',
                              unit='sents')

    valid_iter = valid_iterator.build_generator()
    for batch in valid_iter:
        batch_result = []
        batch_numbers = []
        numbers, seqs_x = batch
        batch_size_t = len(seqs_x)

        x = prepare_data(seqs_x=seqs_x, cuda=GlobalNames.USE_GPU)

        with torch.no_grad():
            word_ids = beam_search(nmt_model=fw_nmt_model,
                                   beam_size=FLAGS.beam_size,
                                   max_steps=FLAGS.max_steps,
                                   src_seqs=x,
                                   alpha=FLAGS.alpha)

        word_ids = word_ids.cpu().numpy().tolist()

        for sent_t in word_ids:
            # Strip padding from every beam of this sentence.
            sent_t = [[wid for wid in line if wid != PAD] for line in sent_t]
            result.append(sent_t)
            batch_result.append(sent_t[0])

            n_words += len(sent_t[0])

        result_numbers += numbers
        imt_numbers += numbers
        batch_numbers += numbers
        batch_ref = [valid_ref[ii] for ii in batch_numbers]

        # The top beam of the unconstrained search seeds the first round.
        last_sents = copy.deepcopy(batch_result)
        constrains = [[[] for ii in range(patience)]
                      for jj in range(batch_size_t)]
        positions = [[[] for ii in range(patience)]
                     for jj in range(batch_size_t)]
        for idx in range(FLAGS.imt_step):
            cons, pos = sample_constrains(last_sents, batch_ref, patience)

            # Accumulate this round's constraints on top of earlier rounds.
            for ii in range(batch_size_t):
                for jj in range(patience):
                    constrains[ii][jj].append(cons[ii][jj])
                    positions[ii][jj].append(pos[ii][jj])

            # NOTE(review): one list is appended per *batch*, but the list is
            # later indexed with per-*sentence* indices when it is written
            # out. That only lines up when every batch holds exactly one
            # sentence - confirm the intended batch size.
            imt_constrains[idx].append([vocab_tgt.ids2sent(c) for c in cons])
            with torch.no_grad():
                # `positions` is replaced by the value returned from the
                # search and carried into the next round.
                constrained_word_ids, positions = fixwords_beam_search(
                    fw_nmt_model=fw_nmt_model,
                    bw_nmt_model=bw_nmt_model,
                    beam_size=FLAGS.beam_size,
                    max_steps=FLAGS.max_steps,
                    src_seqs=x,
                    alpha=FLAGS.alpha,
                    constrains=constrains,
                    positions=positions,
                    last_sentences=last_sents,
                    imt_step=idx + 1,
                    bidirection=bidirection)
            constrained_word_ids = constrained_word_ids.cpu().numpy().tolist()
            last_sents = []
            for i, sent_t in enumerate(constrained_word_ids):
                sent_t = [[wid for wid in line if wid != PAD]
                          for line in sent_t]
                if idx == FLAGS.imt_step - 1:
                    # Only the final round is kept for the output files.
                    imt_result.append(copy.deepcopy(sent_t))
                    imt_n_words += len(sent_t[0])

                # Cut every hypothesis at its first EOS.
                samples = []
                for trans in sent_t:
                    sample = []
                    for w in trans:
                        if w == vocab_tgt.EOS:
                            break
                        sample.append(w)
                    samples.append(sample)

                # Keep every beam_size-th hypothesis (indices 0, beam_size,
                # 2 * beam_size, ...).
                sent_t = []
                for ii in range(len(samples)):
                    if ii % FLAGS.beam_size == 0:
                        sent_t.append(samples[ii])
                BLEU = []
                for sample in sent_t:
                    bleu, _ = bleuScore(sample, batch_ref[i])
                    BLEU.append(bleu)

                # Sort the kept hypotheses by BLEU, best first; the best one
                # seeds the next round.
                order = np.argsort(BLEU).tolist()
                order = order[::-1]
                sent_t = [sent_t[ii] for ii in order]

                last_sents.append(sent_t[0])

            if FLAGS.online_learning and idx == FLAGS.imt_step - 1:
                seqs_y = []
                for sent in last_sents:
                    sent = [BOS] + sent
                    seqs_y.append(sent)
                # NOTE(review): building a tensor from nested lists requires
                # all sequences to have the same length and no padding is
                # applied here - confirm the batches satisfy this.
                fw_targets = torch.Tensor(seqs_y).long()
                # Follow the configured device instead of unconditionally
                # calling .cuda(), which fails on CPU-only runs.
                if GlobalNames.USE_GPU:
                    fw_targets = fw_targets.cuda()
                compute_forward(fw_nmt_model, critic, x, fw_targets)

                # Reversed targets for the backward model; the first and last
                # positions are overwritten with BOS and EOS.
                seqs_y = [sent[::-1] for sent in seqs_y]
                for seq in seqs_y:
                    seq[0] = BOS
                    seq[-1] = EOS
                bw_targets = torch.Tensor(seqs_y).long()
                if GlobalNames.USE_GPU:
                    bw_targets = bw_targets.cuda()
                compute_forward(bw_nmt_model, critic, x, bw_targets)

        infer_progress_bar.update(batch_size_t)

    infer_progress_bar.close()
    INFO('Done. Speed: {0:.2f} words/sec'.format(
        n_words / (timer.toc(return_seconds=True))))

    # ---- Write the plain beam-search translations ----
    translation = []
    for sent in result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(vocab_tgt.id2token(w))
            samples.append(vocab_tgt.tokenizer.detokenize(sample))
        translation.append(samples)

    # Restore the original sentence order.
    origin_order = np.argsort(result_numbers).tolist()
    translation = [translation[ii] for ii in origin_order]

    keep_n = FLAGS.beam_size if FLAGS.keep_n <= 0 else min(
        FLAGS.beam_size, FLAGS.keep_n)
    outputs = ['%s.%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    # Placeholder keeps all output files line-aligned.
                    handles[i].write('%s\n' % 'eos')

    # ---- Write the constraints and constrained translations ----
    imt_translation = []
    for sent in imt_result:
        samples = []
        for trans in sent:
            sample = []
            for w in trans:
                if w == vocab_tgt.EOS:
                    break
                sample.append(w)
            samples.append(sample)
        imt_translation.append(samples)

    origin_order = np.argsort(imt_numbers).tolist()
    imt_translation = [imt_translation[ii] for ii in origin_order]
    for idx in range(FLAGS.imt_step):
        imt_constrains[idx] = [
            ' '.join(imt_constrains[idx][ii]) + '\n' for ii in origin_order
        ]

        with open('%s.cons%d' % (FLAGS.saveto, idx), 'w') as f:
            f.writelines(imt_constrains[idx])

    bleu_translation = []
    for idx, sent in enumerate(imt_translation):
        # Same selection as during decoding: every beam_size-th hypothesis.
        samples = []
        for ii in range(len(sent)):
            if ii % FLAGS.beam_size == 0:
                samples.append(sent[ii])
        BLEU = []
        for sample in samples:
            # imt_translation is back in original order, so idx matches the
            # reference line number.
            bleu, _ = bleuScore(sample, valid_ref[idx])
            BLEU.append(bleu)

        order = np.argsort(BLEU).tolist()
        order = order[::-1]
        samples = [vocab_tgt.ids2sent(samples[ii]) for ii in order]
        bleu_translation.append(samples)

    # One output file per try (FLAGS.try_times), regardless of FLAGS.keep_n.
    keep_n = patience
    outputs = ['%s.imt%d' % (FLAGS.saveto, i) for i in range(keep_n)]

    with batch_open(outputs, 'w') as handles:
        for trans in bleu_translation:
            for i in range(keep_n):
                if i < len(trans):
                    handles[i].write('%s\n' % trans[i])
                else:
                    handles[i].write('%s\n' % 'eos')