Example #1
def validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data,
                        batch_size=hparams.batch_size,
                        shuffle=False,
                        num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL = _evaluate_perplexity(model, val_dl, vocab_src,
                                            vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src,
                                                  vocab_tgt, device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar("validation/NLL", val_NLL, step)
        summary_writer.add_scalar("validation/BLEU", val_bleu, step)
        summary_writer.add_scalar("validation/perplexity", val_ppl, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x],
                                                          vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt,
                                             device)
            _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in)
            att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src,
                                        no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt,
                                        no_filter=True)[0].split()
        attention_summary(tgt_labels, src_labels, att_weights, summary_writer,
                          "validation/attention", step)

    return {
        'bleu': val_bleu,
        'likelihood': -val_NLL,
        'nll': val_NLL,
        'ppl': val_ppl
    }
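
A hypothetical call site for the `validate` function above, wrapped in a helper so the snippet stays importable. The model, validation data, vocabularies, and hparams are assumed to be constructed elsewhere in the project, and the writer is assumed to be a `torch.utils.tensorboard` `SummaryWriter` (the project may use a different summary writer).

from torch.utils.tensorboard import SummaryWriter

def run_validation(model, val_data, vocab_src, vocab_tgt, device, hparams,
                   global_step, log_dir="runs/example"):
    # Hypothetical wrapper: all arguments are assumed to come from the
    # project's own training script.
    writer = SummaryWriter(log_dir=log_dir)
    scores = validate(model, val_data, vocab_src, vocab_tgt, device, hparams,
                      step=global_step, summary_writer=writer)
    print(scores["bleu"], scores["ppl"])
    return scores
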
Example #2
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences,
                                                      vocab_src, device)
        encoder_outputs, encoder_final = model.encode(x_in, seq_len_x)
        hidden = model.init_decoder(encoder_outputs, encoder_final)
        if hparams.sample_decoding:
            raw_hypothesis = sampling_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length)
        elif hparams.beam_width <= 1:
            raw_hypothesis = greedy_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length)
        else:
            raw_hypothesis = beam_search(
                model.decoder, model.tgt_embed, model.generate,
                vocab_tgt.size(), hidden, encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor, hparams.max_decoding_length)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
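
A hypothetical invocation of the `translate` function above. `hparams` is sketched as a `SimpleNamespace` carrying only the attributes the function reads (`sample_decoding`, `beam_width`, `length_penalty_factor`, `max_decoding_length`); the real project presumably has its own hyperparameter object, and the model, vocabularies, and device are assumed to exist.

from types import SimpleNamespace

def translate_with_beam(model, vocab_src, vocab_tgt, device, sentences):
    # Hypothetical helper: beam search with width 5 is selected because
    # sample_decoding is False and beam_width > 1.
    hparams = SimpleNamespace(sample_decoding=False,
                              beam_width=5,
                              length_penalty_factor=1.0,
                              max_decoding_length=50)
    return translate(model, sentences, vocab_src, vocab_tgt, device, hparams)
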
Example #3
def translate(model,
              input_sentences,
              vocab_src,
              vocab_tgt,
              device,
              hparams,
              deterministic=True):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences,
                                                      vocab_src, device)

        # For translation we use the approximate posterior mean.
        qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x)
        z = qz.mean if deterministic else qz.sample()

        encoder_outputs, encoder_final = model.encode(x_in, seq_len_x, z)
        hidden = model.init_decoder(encoder_outputs, encoder_final, z)
        if hparams.beam_width <= 1:
            raw_hypothesis = greedy_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length)
        else:
            raw_hypothesis = beam_search(
                model.decoder, model.tgt_embed, model.generate,
                vocab_tgt.size(), hidden, encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor, hparams.max_decoding_length)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
Example #4
def generate_senvae(model,
                    input_sentences,
                    num_samples,
                    vocab_src,
                    device,
                    hparams,
                    deterministic=True):

    model.eval()
    with torch.no_grad():

        if input_sentences is not None:
            x_in, _, seq_mask_x, seq_len_x = create_batch(
                input_sentences, vocab_src, device)
            qz = model.approximate_posterior(x_in,
                                             seq_mask_x,
                                             seq_len_x,
                                             y=x_in,
                                             seq_mask_y=seq_mask_x,
                                             seq_len_y=seq_len_x)
        else:
            qz = model.prior().expand((num_samples, ))

        # TODO: restore some form of deterministic decoding
        #z = qz.mean if deterministic else qz.sample()
        z = qz.sample()

        if isinstance(model.language_model, TransformerLM):
            hidden = None
        else:
            hidden = model.language_model.init(z)

        if hparams.decoding.sample:
            raw_hypothesis = model.language_model.sample(
                z, max_len=hparams.decoding.max_length, greedy=False)

        elif hparams.decoding.beam_width <= 1:
            raw_hypothesis = model.language_model.sample(
                z, max_len=hparams.decoding.max_length, greedy=True)

        else:
            raise NotImplementedError
            """
            raw_hypothesis = beam_search(
                model.language_model,
                model.language_model.embedder,
                model.language_model.generate,
                vocab_src.size(), None, None,
                None, None, None,
                vocab_src[SOS_TOKEN], vocab_src[EOS_TOKEN],
                vocab_src[PAD_TOKEN], hparams.decoding.beam_width,
                hparams.decoding.length_penalty_factor,
                hparams.decoding.max_length,
                z)
            """

    hypothesis = batch_to_sentences(raw_hypothesis, vocab_src)
    return hypothesis
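
A hypothetical use of `generate_senvae` above for unconditional generation: passing `input_sentences=None` makes the function expand the prior over `num_samples` and decode from prior samples. The nested `decoding` namespace only mimics the attributes read inside the function; the actual hyperparameter object is assumed to live elsewhere.

from types import SimpleNamespace

def sample_from_prior(model, vocab_src, device, num_samples=8):
    # Hypothetical helper: greedy decoding from prior samples
    # (decoding.sample=False and beam_width<=1 select the greedy branch above).
    hparams = SimpleNamespace(
        decoding=SimpleNamespace(sample=False, beam_width=1, max_length=50))
    return generate_senvae(model, None, num_samples, vocab_src, device, hparams)
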
Example #5
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams, deterministic=True):
    # TODO: this code should be in the translation model class
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)

        # Infer the approximate posterior q(z|x); for now we decode from a
        # sample of it rather than its mean (see the TODO below).
        qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                         y=x_in, seq_mask_y=seq_mask_x,
                                         seq_len_y=seq_len_x)  # TODO: here we need a prediction net!
        # TODO: restore some form of deterministic decoding
        #z = qz.mean if deterministic else qz.sample()
        z = qz.sample()

        encoder_outputs, encoder_final = model.translation_model.encode(x_in, seq_len_x, z)
        hidden = model.translation_model.init_decoder(encoder_outputs, encoder_final, z)

        if hparams.sample_decoding:
            # TODO: we could use the new version below
            #raw_hypothesis = model.translation_model.sample(x_in, seq_mask_x, seq_len_x, z, 
            #    max_len=hparams.max_decoding_length, greedy=False)
            raw_hypothesis = sampling_decode(
                model.translation_model.decoder, 
                model.translation_model.tgt_embed,
                model.translation_model.generate, hidden,
                encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length,
                z if hparams.feed_z else None)
        elif hparams.beam_width <= 1:
            # TODO: we could use the new version below
            #raw_hypothesis = model.translation_model.sample(x_in, seq_mask_x, seq_len_x, z, 
            #    max_len=hparams.max_decoding_length, greedy=True)
            raw_hypothesis = greedy_decode(
                model.translation_model.decoder, 
                model.translation_model.tgt_embed,
                model.translation_model.generate, hidden,
                encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length,
                z if hparams.feed_z else None)
        else:
            raw_hypothesis = beam_search(
                model.translation_model.decoder, 
                model.translation_model.tgt_embed, 
                model.translation_model.generate,
                vocab_tgt.size(), hidden, encoder_outputs,
                encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor,
                hparams.max_decoding_length,
                z if hparams.feed_z else None)

    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
Example #6
def _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device):
    model.eval()
    with torch.no_grad():
        num_predictions = 0
        num_sentences = 0
        val_NLL = 0.
        for sentences_x, sentences_y in val_dl:
            x_in, _, seq_mask_x, seq_len_x = create_batch(sentences_x, vocab_src, device)
            y_in, y_out, _, seq_len_y = create_batch(sentences_y, vocab_tgt, device)

            # Do a forward pass and compute the validation loss of this batch.
            logits, _ = model(x_in, seq_mask_x, seq_len_x, y_in)
            batch_NLL = model.loss(logits, y_out, reduction="sum")["loss"]
            val_NLL += batch_NLL.item()

            num_sentences += x_in.size(0)
            num_predictions += seq_len_y.sum().item()

    val_perplexity = np.exp(val_NLL / num_predictions)
    return val_perplexity, val_NLL/num_sentences
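
A small numeric illustration (with invented numbers) of the two quantities returned above: perplexity is the exponentiated NLL per target token, while the second return value is the average NLL per sentence.

import numpy as np

val_NLL = 3500.0        # summed validation NLL over all batches (invented)
num_predictions = 1200  # total number of target tokens
num_sentences = 100     # number of validation sentence pairs

print(np.exp(val_NLL / num_predictions))  # corpus-level perplexity
print(val_NLL / num_sentences)            # average NLL per sentence
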
Example #7
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences,
                                                      vocab_src, device)
        if isinstance(model.translation_model, TransformerTM):
            encoder_outputs, seq_len_x = model.translation_model.encode(
                x_in, seq_len_x, None)
            encoder_final = None
            hidden = None
        else:
            encoder_outputs, encoder_final = model.translation_model.encode(
                x_in, seq_len_x, None)
            hidden = model.translation_model.init_decoder(
                encoder_outputs, encoder_final, None)
        if hparams.decoding.sample:
            raw_hypothesis = model.translation_model.sample(
                x_in,
                seq_mask_x,
                seq_len_x,
                None,
                max_len=hparams.decoding.max_length,
                greedy=False)
        elif hparams.decoding.beam_width <= 1:
            raw_hypothesis = model.translation_model.sample(
                x_in,
                seq_mask_x,
                seq_len_x,
                None,
                max_len=hparams.decoding.max_length,
                greedy=True)
        else:
            raw_hypothesis = beam_search(
                model.translation_model.decoder,
                model.translation_model.tgt_embed,
                model.translation_model.generate, vocab_tgt.size(), hidden,
                encoder_outputs, encoder_final, seq_mask_x, seq_len_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.decoding.beam_width,
                hparams.decoding.length_penalty_factor,
                hparams.decoding.max_length, None)

    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
Example #8
def _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device):
    model.eval()
    with torch.no_grad():
        num_predictions = 0
        num_sentences = 0
        log_marginal = defaultdict(float)
        total_KL = 0.
        n_samples = 10
        for sentences_x, sentences_y in val_dl:
            x_in, x_out, seq_mask_x, seq_len_x = create_batch(sentences_x, vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch(sentences_y, vocab_tgt, device)

            # Infer q(z|x) for this batch.
            qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x, y_in, seq_mask_y, seq_len_y)
            pz = model.prior()  
            if isinstance(qz, ProductOfDistributions):
                total_KL += torch.cat(
                    [kl_divergence(qi, pi).sum(0).unsqueeze(-1) for qi, pi in zip(qz.distributions, pz.distributions)], 
                    -1)
            else:
                total_KL += kl_divergence(qz, pz).sum(0)

            # Take S importance samples from q(z|x):
            # log int{p(x, y, z) dz} ~= log (1/S) sum_s{p(x, y, z_s) / q(z_s|x)} where z_s ~ q(z|x)
            batch_size = x_in.size(0)
            batch_log_marginals = defaultdict(lambda: torch.zeros(n_samples, batch_size))

            for s in range(n_samples):

                # z ~ q(z|x)
                z = qz.sample()

                # Compute the logits according to this sample of z.
                tm_likelihood, lm_likelihood, _, aux_lm_likelihoods, aux_tm_likelihoods = model(x_in, seq_mask_x, seq_len_x, y_in, z)

                # Compute log P(y|x, z_s)
                log_tm_prob = model.translation_model.log_prob(tm_likelihood, y_out)

                # Compute log P(x|z_s)
                log_lm_prob = model.language_model.log_prob(lm_likelihood, x_out)
                
                # Compute prior probability log P(z_s) and importance weight q(z_s|x)
                log_pz = pz.log_prob(z) 
                log_qz = qz.log_prob(z)

                # Estimate the importance weighted estimate of (the log of) P(x, y)
                batch_log_marginals['joint/main'][s] = log_tm_prob + log_lm_prob + log_pz - log_qz
                batch_log_marginals['lm/main'][s] = log_lm_prob + log_pz - log_qz
                batch_log_marginals['tm/main'][s] = log_tm_prob + log_pz - log_qz
                
                for aux_comp, aux_px_z in aux_lm_likelihoods.items():
                    batch_log_marginals['lm/' + aux_comp][s] = model.log_likelihood_lm(aux_comp, aux_px_z, x_out) + log_pz - log_qz
                for aux_comp, aux_py_xz in aux_tm_likelihoods.items():
                    batch_log_marginals['tm/' + aux_comp][s] = model.log_likelihood_tm(aux_comp, aux_py_xz, y_out) + log_pz - log_qz

            for comp_name, log_marginals in batch_log_marginals.items():
                # Average over all samples.
                batch_avg = torch.logsumexp(log_marginals, dim=0) - torch.log(torch.Tensor([n_samples]))
                log_marginal[comp_name] = log_marginal[comp_name] + batch_avg.sum().item()

            num_sentences += batch_size
            num_predictions += (seq_len_x.sum() + seq_len_y.sum()).item()

    val_NLL = -log_marginal['joint/main']
    val_perplexity = np.exp(val_NLL / num_predictions)

    NLLs = {comp_name: -value / num_sentences for comp_name, value in log_marginal.items()}

    return val_perplexity, total_KL/num_sentences, NLLs
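
A minimal, self-contained sketch (with made-up tensors) of the KL term accumulated above, assuming for illustration that q(z|x) and p(z) are both diagonal Gaussians; the two distributions below are stand-ins for `model.approximate_posterior(...)` and `model.prior()`.

import torch
from torch.distributions import Normal, kl_divergence

batch_size, latent_size = 4, 8
qz = Normal(torch.randn(batch_size, latent_size),
            torch.rand(batch_size, latent_size) + 0.1)  # stand-in for the posterior
pz = Normal(torch.zeros(latent_size), torch.ones(latent_size))  # stand-in for the prior

batch_KL = kl_divergence(qz, pz).sum(0)  # [latent_size], as added to total_KL above
print(batch_KL.sum().item())             # total KL for this batch
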
Example #9
def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy', summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data, batch_size=hparams.batch_size,
                        shuffle=False, num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_KL, val_NLLs = _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device)
    val_NLL = val_NLLs['joint/main']
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src, vocab_tgt,
                                                  device, hparams)

    random_idx = np.random.choice(len(inputs))
    #nll_str = ' '.join('-- validation NLL {} = {:.2f}'.format(comp_name, comp_value)  for comp_name, comp_value in sorted(val_NLLs.items()))
    nll_str = f""
    # - log P(x|z) for the various source LM decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('lm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    # - log P(y|z,x) for the various translation decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('tm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    
    kl_str = f"-- KL = {val_KL.sum():.2f}"
    if isinstance(model.prior(), ProductOfDistributions):
        for i, p in enumerate(model.prior().distributions):
            kl_str += f" -- KL{i} = {val_KL[i]:.2f}" 
        
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- BLEU = {val_bleu:.2f}"
          f" {kl_str}"
          f" {nll_str}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader([val_data[random_idx] for _ in range(hparams.draw_translations)], batch_size=hparams.batch_size, shuffle=False, num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")
    
    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl, step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL.sum(), step)
        if isinstance(model.prior(), ProductOfDistributions):
            for i, _ in enumerate(model.prior().distributions):
                summary_writer.add_scalar(f"{title}/validation/KL{i}", val_KL[i], step)
        for comp_name, comp_value in val_NLLs.items():
            summary_writer.add_scalar(f"{title}/validation/NLL/{comp_name}", comp_value, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x], vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch([val_sentence_y], vocab_tgt, device)
            z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x, y_in, seq_mask_y, seq_len_y).sample()
            _, _, state, _, _ = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = state['att_weights'].squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return {'bleu': val_bleu, 'likelihood': -val_NLL, 'nll': val_NLL, 'ppl': val_ppl}
Example #10
def validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             title='xy',
             summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data,
                        batch_size=hparams.batch_size,
                        shuffle=False,
                        num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL, val_KL = _evaluate_perplexity(model, val_dl, vocab_src,
                                                    vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src,
                                                  vocab_tgt, device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}"
          f" -- validation KL = {val_KL:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader(
            [val_data[random_idx] for _ in range(hparams.draw_translations)],
            batch_size=hparams.batch_size,
            shuffle=False,
            num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device,
                                      hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/NLL", val_NLL, step)
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl,
                                  step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x],
                                                          vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt,
                                             device)
            z = model.approximate_posterior(x_in, seq_mask_x,
                                            seq_len_x).sample()
            _, _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src,
                                        no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt,
                                        no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return {
        'bleu': val_bleu,
        'likelihood': -val_NLL,
        'nll': val_NLL,
        'ppl': val_ppl
    }
Example #11
def _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device):
    model.eval()
    with torch.no_grad():
        num_predictions = 0
        num_sentences = 0
        log_marginal = 0.
        total_KL = 0.
        n_samples = 10
        for sentences_x, sentences_y in val_dl:
            x_in, x_out, seq_mask_x, seq_len_x = create_batch(
                sentences_x, vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch(
                sentences_y, vocab_tgt, device)

            # Infer q(z|x) for this batch.
            qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x)
            pz = model.prior().expand(qz.mean.size())
            total_KL += torch.distributions.kl.kl_divergence(qz,
                                                             pz).sum().item()

            # Take S importance samples from q(z|x):
            # log int{p(x, y, z) dz} ~= log (1/S) sum_s{p(x, y, z_s) / q(z_s|x)} where z_s ~ q(z|x)
            batch_size = x_in.size(0)
            batch_log_marginals = torch.zeros(n_samples, batch_size)
            for s in range(n_samples):

                # z ~ q(z|x)
                z = qz.sample()

                # Compute the logits according to this sample of z.
                tm_logits, lm_logits, _ = model(x_in, seq_mask_x, seq_len_x,
                                                y_in, z)

                # Compute log P(y|x, z_s)
                log_tm_prob = F.log_softmax(tm_logits, dim=-1)
                log_tm_prob = torch.gather(log_tm_prob, 2,
                                           y_out.unsqueeze(-1)).squeeze()
                log_tm_prob = (seq_mask_y.type_as(log_tm_prob) *
                               log_tm_prob).sum(dim=1)

                # Compute log P(x|z_s)
                log_lm_prob = F.log_softmax(lm_logits, dim=-1)
                log_lm_prob = torch.gather(log_lm_prob, 2,
                                           x_out.unsqueeze(-1)).squeeze()
                log_lm_prob = (seq_mask_x.type_as(log_lm_prob) *
                               log_lm_prob).sum(dim=1)

                # Compute prior probability log P(z_s) and importance weight q(z_s|x)
                log_pz = pz.log_prob(z).sum(dim=1)  # [B, latent_size] -> [B]
                log_qz = qz.log_prob(z).sum(dim=1)

                # Estimate the importance weighted estimate of (the log of) P(x, y)
                batch_log_marginals[s] = (
                    log_tm_prob + log_lm_prob + log_pz - log_qz)

            # Average over all samples.
            batch_log_marginal = torch.logsumexp(batch_log_marginals, dim=0) - \
                                 torch.log(torch.Tensor([n_samples]))
            log_marginal += batch_log_marginal.sum().item()  # [B] -> []
            num_sentences += batch_size
            num_predictions += (seq_len_x.sum() + seq_len_y.sum()).item()

    val_NLL = -log_marginal
    val_perplexity = np.exp(val_NLL / num_predictions)
    return val_perplexity, val_NLL / num_sentences, total_KL / num_sentences
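
A toy, self-contained illustration of the importance-weighted estimator used in the loop above: log p(x, y) is approximated by the log of the average of S importance weights p(x, y, z_s)/q(z_s|x) with z_s ~ q(z|x). The log-weights below are random stand-ins for the per-sentence quantities computed inside the loop.

import torch

n_samples, batch_size = 10, 3
# Stand-ins for log p(y|x,z_s) + log p(x|z_s) + log p(z_s) - log q(z_s|x), one row per sample.
log_weights = torch.randn(n_samples, batch_size)

# log-mean-exp over the sample dimension gives the per-sentence log-marginal estimate.
batch_log_marginal = (torch.logsumexp(log_weights, dim=0)
                      - torch.log(torch.tensor(float(n_samples))))
print(batch_log_marginal.shape)  # torch.Size([3])
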