def sample_from_posterior(model, sentences_x, vocab_src, vocab_tgt, config):
    """Greedy-decode translations conditioned on latent samples drawn from
    the approximate posterior q(z|x), and print them.

    Each source sentence is repeated `num_samples` times so that several
    posterior samples get decoded per sentence.
    """
    num_samples = 5
    # Repeat each sentence once per posterior sample.  Was a hard-coded 5,
    # which would silently diverge from num_samples if the constant changed.
    sentences_x = np.tile(sentences_x, num_samples)
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")

    x_in, _, x_mask, x_len = create_batch(sentences_x, vocab_src, device)
    x_mask = x_mask.unsqueeze(1)

    # One z per (repeated) sentence, sampled from q(z|x).
    qz = model.inference(x_in, x_mask, x_len)
    z = qz.sample()

    enc_output, enc_hidden = model.encode(x_in, x_len, z)
    dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

    # greedy=True: deterministic decoding given the sampled z.
    y_samples = ancestral_sample(model.decoder,
                                 model.emb_tgt,
                                 model.generate_tm,
                                 enc_output,
                                 dec_hidden,
                                 x_mask,
                                 vocab_tgt[SOS_TOKEN],
                                 vocab_tgt[EOS_TOKEN],
                                 vocab_tgt[PAD_TOKEN],
                                 config,
                                 greedy=True,
                                 z=z)

    y_samples = batch_to_sentences(y_samples, vocab_tgt)
    print("Sample translations from the approximate posterior")
    for idx, y in enumerate(y_samples, 1):
        print("{}: {}".format(idx, y))
def sample_from_latent(model, vocab_src, vocab_tgt, config):
    """Sample `num_samples` source sentences from the prior p(z) via the
    model's language model and print them.

    NOTE(review): vocab_tgt is used only for the SOS id fed to the *source*
    language model — presumably SOS shares an id across both vocabularies;
    confirm against the vocabulary construction.
    """
    num_samples = 5

    # Draw z ~ p(z) = N(prior_loc, prior_scale), one z per sample.
    prior = torch.distributions.Normal(loc=model.prior_loc,
                                       scale=model.prior_scale)
    z = prior.sample(sample_shape=[num_samples])

    # Initialise the language-model RNN state from z.
    hidden_lm = tile_rnn_hidden(model.lm_init_layer(z),
                                model.language_model.rnn)
    x_init = z.new([vocab_tgt[SOS_TOKEN] for _ in range(num_samples)]).long()
    x_embed = model.emb_src(x_init).unsqueeze(1)

    # Accumulates one token column per decoding step, starting with SOS.
    x_samples = [x_init.unsqueeze(-1)]

    # Ancestral sampling: feed each sampled word back in for max_len steps.
    for _ in range(config["max_len"]):
        pre_output, hidden_lm = model.language_model.forward_step(
            x_embed, hidden_lm, z)
        logits = model.generate_lm(pre_output)
        next_word_dist = torch.distributions.categorical.Categorical(
            logits=logits)
        x = next_word_dist.sample()
        x_embed = model.emb_src(x)
        x_samples.append(x)

    # Concatenate per-step columns — presumably (num_samples, max_len + 1);
    # confirm against forward_step's output shape.
    x_samples = torch.cat(x_samples, dim=-1)
    x_samples = batch_to_sentences(x_samples, vocab_src)

    print("Sampled source sentences from the latent space ")
    for idx, x in enumerate(x_samples, 1):
        print("{}: {}".format(idx, x))
def validate(model,
             dev_data,
             vocab_src,
             vocab_tgt,
             epoch,
             config,
             direction=None):
    """Translate the dev set with beam search and return the BLEU score.

    direction: None or "xy" decodes from sentences_x and scores against
    sentences_y; anything else decodes from sentences_y and scores against
    sentences_x (the model itself is assumed to match the requested
    direction).
    """
    model.eval()
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    with torch.no_grad():
        model_hypotheses = []
        references = []

        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=4)
        val_dl = BucketingParallelDataLoader(val_dl)
        for sentences_x, sentences_y in val_dl:
            # `is None`, not `== None` (PEP 8 identity comparison); the two
            # branches only differed in which side was batched, so select
            # the source side once.
            if direction is None or direction == "xy":
                source = sentences_x
            else:
                source = sentences_y
            x_in, _, x_mask, x_len = create_batch(source, vocab_src, device)
            x_mask = x_mask.unsqueeze(1)

            enc_output, enc_hidden = model.encode(x_in, x_len)
            dec_hidden = model.init_decoder(enc_output, enc_hidden)

            raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                         model.generate_tm, enc_output,
                                         dec_hidden, x_mask, vocab_tgt.size(),
                                         vocab_tgt[SOS_TOKEN],
                                         vocab_tgt[EOS_TOKEN],
                                         vocab_tgt[PAD_TOKEN], config)

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
            model_hypotheses += hypothesis.tolist()

            # References are the opposite side of whichever was decoded.
            if direction is None or direction == "xy":
                references += sentences_y.tolist()
            else:
                references += sentences_x.tolist()

        save_hypotheses(model_hypotheses, epoch, config)
        model_hypotheses, references = clean_sentences(model_hypotheses,
                                                       references, config)
        bleu = compute_bleu(model_hypotheses, references, epoch, config,
                            direction)
        return bleu
# 示例#4 (Example #4) — scrape-artifact separator; original trailing "0" was a score line
def mono_train_fn(model_xy, model_yx, y_in, y_noisy_in, y_len, y_mask, y_out,
                  vocab_src, config, step):
    """One monolingual (back-translation) training step for model_xy.

    The reverse model (model_yx) runs under no_grad — so it receives no
    gradient — to synthesise source sentences x from the monolingual target
    batch y; model_xy is then trained on the synthetic pair (x, y).
    Returns the loss computed by model_xy.loss.
    """
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    with torch.no_grad():
        # Sample z from the reverse model's approximate posterior q(z|y).
        qz_y = model_yx.inference(y_in, y_mask, y_len)
        z_y = qz_y.sample()

        enc_output, enc_final = model_yx.encode(y_in, y_len, z_y)
        dec_hidden = model_yx.init_decoder(enc_output, enc_final, z_y)

        # Decode synthetic x via beam search, greedy, or ancestral sampling,
        # selected by config["decoding_method"].
        if config["decoding_method"] == "beam_search":
            x_samples = beam_search(model_yx.decoder,
                                    model_yx.emb_tgt,
                                    model_yx.generate_tm,
                                    enc_output,
                                    dec_hidden,
                                    y_mask,
                                    vocab_src.size(),
                                    vocab_src[SOS_TOKEN],
                                    vocab_src[EOS_TOKEN],
                                    vocab_src[PAD_TOKEN],
                                    config,
                                    beam_width=config["decoding_beam_width"],
                                    z=z_y)
        else:
            # Any value other than "ancestral" (e.g. "greedy") decodes greedily.
            greedy = False if config["decoding_method"] == "ancestral" else True
            x_samples = ancestral_sample(model_yx.decoder,
                                         model_yx.emb_tgt,
                                         model_yx.generate_tm,
                                         enc_output,
                                         dec_hidden,
                                         y_mask,
                                         vocab_src[SOS_TOKEN],
                                         vocab_src[EOS_TOKEN],
                                         vocab_src[PAD_TOKEN],
                                         config,
                                         greedy=greedy,
                                         z=z_y)
        x_samples = batch_to_sentences(x_samples, vocab_src)
        # Re-batch the synthetic sentences, with word dropout applied to the
        # noisy decoder-input copy.
        x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
            x_samples, vocab_src, device, word_dropout=config["word_dropout"])
        x_mask = x_mask.unsqueeze(1)

    # Forward pass of the trained model: rsample keeps the gradient path
    # through z (reparameterisation).
    qz_x = model_xy.inference(x_in, x_mask, x_len)
    z_x = qz_x.rsample()
    tm_logits, lm_logits, _, _ = model_xy(x_noisy_in, x_len, x_mask,
                                          y_noisy_in, z_x)

    loss = model_xy.loss(tm_logits, lm_logits, None, None, y_out, x_out, qz_x,
                         step)
    return loss
def mono_train_fn(model_xy, model_yx, y_in, y_noisy_in, y_len, y_mask, y_out,
                  vocab_src, config, step):
    """One monolingual (back-translation) training step, cond-NMT variant.

    The reverse model (model_yx) runs under no_grad to synthesise source
    sentences x from the monolingual target batch y; model_xy is then
    trained on the synthetic pair (x, y).  Unlike the latent-variable
    variant, no z is involved.  Returns model_xy's loss.

    NOTE(review): y_noisy_in and step are unused in this variant — the
    signature presumably mirrors the latent-variable mono_train_fn; confirm
    the shared call site.
    """
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    with torch.no_grad():
        enc_output, enc_final = model_yx.encode(y_in, y_len)
        dec_hidden = model_yx.init_decoder(enc_output, enc_final)

        # Decode synthetic x via beam search, greedy, or ancestral sampling,
        # selected by config["decoding_method"].
        if config["decoding_method"] == "beam_search":
            x_samples = beam_search(model_yx.decoder,
                                    model_yx.emb_tgt,
                                    model_yx.generate_tm,
                                    enc_output,
                                    dec_hidden,
                                    y_mask,
                                    vocab_src.size(),
                                    vocab_src[SOS_TOKEN],
                                    vocab_src[EOS_TOKEN],
                                    vocab_src[PAD_TOKEN],
                                    config,
                                    beam_width=config["decoding_beam_width"])
        else:
            # Any value other than "ancestral" (e.g. "greedy") decodes greedily.
            greedy = False if config["decoding_method"] == "ancestral" else True
            x_samples = ancestral_sample(model_yx.decoder,
                                         model_yx.emb_tgt,
                                         model_yx.generate_tm,
                                         enc_output,
                                         dec_hidden,
                                         y_mask,
                                         vocab_src[SOS_TOKEN],
                                         vocab_src[EOS_TOKEN],
                                         vocab_src[PAD_TOKEN],
                                         config,
                                         greedy=greedy)
        x_samples = batch_to_sentences(x_samples, vocab_src)
        x_in, x_out, x_mask, x_len = create_batch(x_samples, vocab_src, device)
        x_mask = x_mask.unsqueeze(1)
    # Train the forward model on the synthetic pair.
    logits = model_xy(x_in, x_mask, x_len, y_in)
    loss = model_xy.loss(logits, y_out)
    return loss
def validate(model,
             dev_data,
             vocab_src,
             vocab_tgt,
             epoch,
             config,
             direction=None):
    """Translate the dev set with beam search and return the BLEU score,
    additionally reporting the mean KL(q(z|x) || p(z)) over the dev set.

    direction: None or "xy" decodes from sentences_x and scores against
    sentences_y; anything else decodes from sentences_y and scores against
    sentences_x.
    """
    model.eval()
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    with torch.no_grad():
        model_hypotheses = []
        references = []

        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=2)
        val_dl = BucketingParallelDataLoader(val_dl)
        val_kl = 0
        for sentences_x, sentences_y in val_dl:
            # `is None`, not `== None` (PEP 8 identity comparison); the two
            # branches only differed in which side was batched.
            if direction is None or direction == "xy":
                source = sentences_x
            else:
                source = sentences_y
            x_in, _, x_mask, x_len = create_batch(source, vocab_src, device)
            x_mask = x_mask.unsqueeze(1)

            # Deterministic decoding: condition on the posterior mean.
            qz = model.inference(x_in, x_mask, x_len)
            z = qz.mean

            # Accumulate KL(q(z|x) || p(z)), summed over latent dims then batch.
            pz = torch.distributions.Normal(loc=model.prior_loc,
                                            scale=model.prior_scale).expand(
                                                qz.mean.size())
            kl_loss = torch.distributions.kl.kl_divergence(qz, pz)
            kl_loss = kl_loss.sum(dim=1)
            val_kl += kl_loss.sum(dim=0)

            enc_output, enc_hidden = model.encode(x_in, x_len, z)
            dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

            raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                         model.generate_tm, enc_output,
                                         dec_hidden, x_mask, vocab_tgt.size(),
                                         vocab_tgt[SOS_TOKEN],
                                         vocab_tgt[EOS_TOKEN],
                                         vocab_tgt[PAD_TOKEN], config, z)

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
            model_hypotheses += hypothesis.tolist()

            # References are the opposite side of whichever was decoded.
            if direction is None or direction == "xy":
                references += sentences_y.tolist()
            else:
                references += sentences_x.tolist()

        # Average KL per dev sentence.
        val_kl /= len(dev_data)
        save_hypotheses(model_hypotheses, epoch, config, direction)
        model_hypotheses, references = clean_sentences(model_hypotheses,
                                                       references, config)
        bleu = compute_bleu(model_hypotheses,
                            references,
                            epoch,
                            config,
                            direction,
                            kl=val_kl)
        return bleu
def main():
    """Evaluate a saved cond-NMT checkpoint on the "comparable" dev set and
    print its sacreBLEU corpus score.
    """
    config = setup_config()
    config["dev_prefix"] = "comparable"
    vocab_src, vocab_tgt = load_vocabularies(config)
    _, dev_data, _ = load_data(config,
                               vocab_src=vocab_src,
                               vocab_tgt=vocab_tgt)

    model, _, validate_fn = create_model(vocab_src, vocab_tgt, config)
    model.to(torch.device(config["device"]))

    # NOTE(review): checkpoint path is hard-coded to run_7 — parameterise via
    # config if this script is reused.
    checkpoint_path = "{}/cond_nmt_de-en_run_7/checkpoints/cond_nmt_de-en_run_7".format(
        config["out_dir"])

    state = torch.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])

    model.eval()
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    with torch.no_grad():
        model_hypotheses = []
        references = []

        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=4)
        # val_dl = BucketingParallelDataLoader(val_dl)
        for sentences_x, sentences_y in tqdm(val_dl):

            # Sort sources by decreasing token count (presumably for packed
            # RNN efficiency — confirm) and remember the permutation so the
            # hypotheses can be restored to corpus order further below.
            sentences_x = np.array(sentences_x)
            seq_len = np.array([len(s.split()) for s in sentences_x])
            sort_keys = np.argsort(-seq_len)
            sentences_x = sentences_x[sort_keys]
            # References stay in the original (unsorted) order.
            sentences_y = np.array(sentences_y)

            x_in, _, x_mask, x_len = create_batch(sentences_x, vocab_src,
                                                  device)
            x_mask = x_mask.unsqueeze(1)

            if config["model_type"] == "aevnmt":
                # Deterministic decoding: condition on the posterior mean.
                qz = model.inference(x_in, x_mask, x_len)
                z = qz.mean

                enc_output, enc_hidden = model.encode(x_in, x_len, z)
                dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)
            else:
                enc_output, enc_hidden = model.encode(x_in, x_len)
                dec_hidden = model.decoder.initialize(enc_output, enc_hidden)

                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)

            # Undo the length-sort so hypotheses align with references.
            inverse_sort_keys = np.argsort(sort_keys)
            model_hypotheses += hypothesis[inverse_sort_keys].tolist()

            references += sentences_y.tolist()
        save_hypotheses(model_hypotheses, 0, config, None)
        model_hypotheses, references = clean_sentences(model_hypotheses,
                                                       references, config)
        bleu = sacrebleu.raw_corpus_bleu(model_hypotheses, [references]).score
        print(bleu)
# 示例#8 (Example #8) — scrape-artifact separator; original trailing "0" was a score line
def main():
    """Evaluate a saved cond-NMT checkpoint on the "comparable" dev set and
    save the beam-search hypotheses.
    """
    config = setup_config()
    config["dev_prefix"] = "comparable"
    vocab_src, vocab_tgt = load_vocabularies(config)
    _, dev_data, _ = load_data(config,
                               vocab_src=vocab_src,
                               vocab_tgt=vocab_tgt)

    # _, dev_data, vocab_src, vocab_tgt = load_dataset_joey(config)
    model, _, validate_fn = create_model(vocab_src, vocab_tgt, config)
    model.to(torch.device(config["device"]))

    # NOTE(review): checkpoint path is hard-coded to run_2 — parameterise via
    # config if this script is reused.
    checkpoint_path = "{}/cond_nmt_new_de-en_run_2/checkpoints/cond_nmt_new_de-en_run_2".format(
        config["out_dir"])
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])

    model.eval()
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    with torch.no_grad():
        model_hypotheses = []
        references = []

        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=4)
        val_dl = BucketingParallelDataLoader(val_dl)
        for sentences_x, sentences_y in tqdm(val_dl):
            x_in, _, x_mask, x_len = create_batch(sentences_x, vocab_src,
                                                  device)
            x_mask = x_mask.unsqueeze(1)

            if config["model_type"] == "aevnmt":
                # Deterministic decoding from the posterior mean.  x_len is
                # now passed to inference/encode — every other call site in
                # this file passes it, and the original dropped it here.
                qz = model.inference(x_in, x_mask, x_len)
                z = qz.mean

                enc_output, enc_hidden = model.encode(x_in, x_len, z)
                dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)
            else:
                enc_output, enc_hidden = model.encode(x_in, x_len)
                dec_hidden = model.decoder.initialize(enc_output, enc_hidden)

                # model.generate_tm matches the sibling evaluation code paths
                # (the original used model.generate only here).
                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
            model_hypotheses += hypothesis.tolist()

            references += sentences_y.tolist()

        save_hypotheses(model_hypotheses, 0, config, None)
# 示例#9 (Example #9) — scrape-artifact separator; original trailing "0" was a score line
def evaluate(model, dev_data, vocab_src, vocab_tgt, config, direction=None):
    """Translate the dev set with beam search and print the corpus BLEU.

    direction: None or "xy" decodes from sentences_x and scores against
    sentences_y; anything else decodes from sentences_y and scores against
    sentences_x.  Sentences are length-sorted for batching and restored to
    corpus order before scoring.

    Raises:
        ValueError: if config["model_type"] is neither "coaevnmt" nor
            "conmt" (previously this fell through both branches and crashed
            later with UnboundLocalError on raw_hypothesis).
    """
    model.eval()
    with torch.no_grad():
        model_hypotheses = []
        references = []

        device = torch.device(
            "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
        val_dl = DataLoader(dev_data,
                            batch_size=config["batch_size_eval"],
                            shuffle=False,
                            num_workers=4)
        # val_dl = BucketingParallelDataLoader(val_dl)
        for sentences_x, sentences_y in tqdm(val_dl):
            # `is None`, not `== None` (PEP 8 identity comparison).  The
            # duplicated unsqueeze was hoisted out of the branches.
            if direction is None or direction == "xy":
                sentences_x, sentences_y, sort_keys = sort_sentences(
                    sentences_x, sentences_y)
                x_in, _, x_mask, x_len = create_batch(sentences_x, vocab_src,
                                                      device)
            else:
                sentences_y, sentences_x, sort_keys = sort_sentences(
                    sentences_y, sentences_x)
                x_in, _, x_mask, x_len = create_batch(sentences_y, vocab_src,
                                                      device)
            x_mask = x_mask.unsqueeze(1)

            if config["model_type"] == "coaevnmt":
                # Deterministic decoding: condition on the posterior mean.
                qz = model.inference(x_in, x_mask, x_len)
                z = qz.mean

                enc_output, enc_hidden = model.encode(x_in, x_len, z)
                dec_hidden = model.init_decoder(enc_output, enc_hidden, z)

                raw_hypothesis = beam_search(model.decoder,
                                             model.emb_tgt,
                                             model.generate_tm,
                                             enc_output,
                                             dec_hidden,
                                             x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN],
                                             config,
                                             z=z)
            elif config["model_type"] == "conmt":
                enc_output, enc_hidden = model.encode(x_in, x_len)
                dec_hidden = model.decoder.initialize(enc_output, enc_hidden)

                raw_hypothesis = beam_search(model.decoder, model.emb_tgt,
                                             model.generate_tm, enc_output,
                                             dec_hidden, x_mask,
                                             vocab_tgt.size(),
                                             vocab_tgt[SOS_TOKEN],
                                             vocab_tgt[EOS_TOKEN],
                                             vocab_tgt[PAD_TOKEN], config)
            else:
                raise ValueError(
                    "Unknown model_type: {}".format(config["model_type"]))

            hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
            # Undo the length-sort so hypotheses align with the references.
            inverse_sort_keys = np.argsort(sort_keys)
            model_hypotheses += hypothesis[inverse_sort_keys].tolist()

            if direction is None or direction == "xy":
                references += sentences_y.tolist()
            else:
                references += sentences_x.tolist()

        model_hypotheses, references = clean_sentences(model_hypotheses,
                                                       references, config)
        bleu = sacrebleu.raw_corpus_bleu(model_hypotheses, [references]).score
        print(bleu)