def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data, batch_size=hparams.batch_size,
                        shuffle=False, num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL = _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src, vocab_tgt,
                                                  device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar("validation/NLL", val_NLL, step)
        summary_writer.add_scalar("validation/BLEU", val_bleu, step)
        summary_writer.add_scalar("validation/perplexity", val_ppl, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x], vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt, device)
            _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in)
            att_weights = att_weights.squeeze().cpu().numpy()
            src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
            tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
            attention_summary(tgt_labels, src_labels, att_weights, summary_writer,
                              "validation/attention", step)

    return {'bleu': val_bleu, 'likelihood': -val_NLL, 'nll': val_NLL, 'ppl': val_ppl}

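# Hedged usage sketch, not part of the original code base: one way `validate`
# might be called periodically from a training loop. `hparams.num_steps` and
# `hparams.evaluate_every` are hypothetical hyperparameter fields assumed here.
def _demo_validation_schedule(model, val_data, vocab_src, vocab_tgt, device,
                              hparams, summary_writer=None):
    best_bleu = float("-inf")
    for step in range(1, hparams.num_steps + 1):
        # ... one gradient update on a training batch would go here ...
        if step % hparams.evaluate_every == 0:
            metrics = validate(model, val_data, vocab_src, vocab_tgt, device,
                               hparams, step, summary_writer=summary_writer)
            best_bleu = max(best_bleu, metrics['bleu'])
            model.train()  # validate() leaves the model in eval mode
    return best_bleu
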
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)
        encoder_outputs, encoder_final = model.encode(x_in, seq_len_x)
        hidden = model.init_decoder(encoder_outputs, encoder_final)
        if hparams.sample_decoding:
            raw_hypothesis = sampling_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.max_decoding_length)
        elif hparams.beam_width <= 1:
            raw_hypothesis = greedy_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.max_decoding_length)
        else:
            raw_hypothesis = beam_search(
                model.decoder, model.tgt_embed, model.generate,
                vocab_tgt.size(), hidden, encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor, hparams.max_decoding_length)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis

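# Hedged usage sketch, not part of the original code base: translating a small
# batch of pre-tokenized sentences with `translate` above. The sentences are
# illustrative; the decoding strategy is selected via `hparams.sample_decoding`
# and `hparams.beam_width`, exactly as in the function.
def _demo_translate(model, vocab_src, vocab_tgt, device, hparams):
    sentences = ["das ist ein kleiner Test .", "wie geht es dir ?"]
    hypotheses = translate(model, sentences, vocab_src, vocab_tgt, device, hparams)
    for src, hyp in zip(sentences, hypotheses):
        print(f"- Source: {src}\n- Hypothesis: {hyp}")
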
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams, deterministic=True):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)

        # For translation we use the approximate posterior mean.
        qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x)
        z = qz.mean if deterministic else qz.sample()

        encoder_outputs, encoder_final = model.encode(x_in, seq_len_x, z)
        hidden = model.init_decoder(encoder_outputs, encoder_final, z)
        if hparams.beam_width <= 1:
            raw_hypothesis = greedy_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.max_decoding_length)
        else:
            raw_hypothesis = beam_search(
                model.decoder, model.tgt_embed, model.generate,
                vocab_tgt.size(), hidden, encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor, hparams.max_decoding_length)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis

def generate_senvae(model, input_sentences, num_samples, vocab_src, device, hparams, deterministic=True):
    model.eval()
    with torch.no_grad():
        if input_sentences is not None:
            x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)
            qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                             y=x_in, seq_mask_y=seq_mask_x, seq_len_y=seq_len_x)
        else:
            qz = model.prior().expand((num_samples,))

        # TODO: restore some form of deterministic decoding
        # z = qz.mean if deterministic else qz.sample()
        z = qz.sample()

        if isinstance(model.language_model, TransformerLM):
            hidden = None
        else:
            hidden = model.language_model.init(z)

        if hparams.decoding.sample:
            raw_hypothesis = model.language_model.sample(
                z, max_len=hparams.decoding.max_length, greedy=False)
        elif hparams.decoding.beam_width <= 1:
            raw_hypothesis = model.language_model.sample(
                z, max_len=hparams.decoding.max_length, greedy=True)
        else:
            raise NotImplementedError
            # Beam search over the latent-conditioned LM is not supported yet:
            # raw_hypothesis = beam_search(
            #     model.language_model, model.language_model.embedder,
            #     model.language_model.generate, vocab_src.size(),
            #     None, None, None, None, None,
            #     vocab_src[SOS_TOKEN], vocab_src[EOS_TOKEN], vocab_src[PAD_TOKEN],
            #     hparams.decoding.beam_width, hparams.decoding.length_penalty_factor,
            #     hparams.decoding.max_length, z)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_src)
    return hypothesis

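# Hedged usage sketch, not part of the original code base: `generate_senvae`
# above supports two modes, and this helper exercises both. The example
# sentence is illustrative.
def _demo_generate_senvae(model, vocab_src, device, hparams):
    # Unconditional generation: decode num_samples latent codes z ~ p(z).
    from_prior = generate_senvae(model, None, 8, vocab_src, device, hparams)
    # Reconstruction-style generation: decode z ~ q(z|x) for given sentences
    # (num_samples is ignored in this branch).
    from_posterior = generate_senvae(model, ["a small test sentence ."], 1,
                                     vocab_src, device, hparams)
    return from_prior, from_posterior
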
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams, deterministic=True):
    # TODO: this code should be in the translation model class
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)

        # For translation we use the approximate posterior mean.
        qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                         y=x_in, seq_mask_y=seq_mask_x, seq_len_y=seq_len_x)
        # TODO: here we need a prediction net!
        # TODO: restore some form of deterministic decoding
        # z = qz.mean if deterministic else qz.sample()
        z = qz.sample()

        encoder_outputs, encoder_final = model.translation_model.encode(x_in, seq_len_x, z)
        hidden = model.translation_model.init_decoder(encoder_outputs, encoder_final, z)
        if hparams.sample_decoding:
            # TODO: we could use the new version below
            # raw_hypothesis = model.translation_model.sample(
            #     x_in, seq_mask_x, seq_len_x, z,
            #     max_len=hparams.max_decoding_length, greedy=False)
            raw_hypothesis = sampling_decode(
                model.translation_model.decoder, model.translation_model.tgt_embed,
                model.translation_model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.max_decoding_length,
                z if hparams.feed_z else None)
        elif hparams.beam_width <= 1:
            # TODO: we could use the new version below
            # raw_hypothesis = model.translation_model.sample(
            #     x_in, seq_mask_x, seq_len_x, z,
            #     max_len=hparams.max_decoding_length, greedy=True)
            raw_hypothesis = greedy_decode(
                model.translation_model.decoder, model.translation_model.tgt_embed,
                model.translation_model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.max_decoding_length,
                z if hparams.feed_z else None)
        else:
            raw_hypothesis = beam_search(
                model.translation_model.decoder, model.translation_model.tgt_embed,
                model.translation_model.generate, vocab_tgt.size(),
                hidden, encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.beam_width, hparams.length_penalty_factor,
                hparams.max_decoding_length,
                z if hparams.feed_z else None)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis

def _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device):
    model.eval()
    with torch.no_grad():
        num_predictions = 0
        num_sentences = 0
        val_NLL = 0.
        for sentences_x, sentences_y in val_dl:
            x_in, _, seq_mask_x, seq_len_x = create_batch(sentences_x, vocab_src, device)
            y_in, y_out, _, seq_len_y = create_batch(sentences_y, vocab_tgt, device)

            # Do a forward pass and compute the validation loss of this batch.
            logits, _ = model(x_in, seq_mask_x, seq_len_x, y_in)
            batch_NLL = model.loss(logits, y_out, reduction="sum")["loss"]
            val_NLL += batch_NLL.item()
            num_sentences += x_in.size(0)
            num_predictions += seq_len_y.sum().item()

    val_perplexity = np.exp(val_NLL / num_predictions)
    return val_perplexity, val_NLL / num_sentences

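# A minimal sketch, not part of the original code base, of the identity the
# function above relies on: perplexity is the exponentiated per-token negative
# log-likelihood, ppl = exp(NLL_total / num_predictions). The log-probabilities
# below are made-up numbers for illustration.
def _demo_perplexity():
    # Per-token log-probabilities (base e) a model assigned to 6 reference tokens.
    token_log_probs = np.array([-1.2, -0.3, -2.1, -0.8, -1.5, -0.6])
    total_nll = -token_log_probs.sum()              # summed NLL, as accumulated above
    ppl = np.exp(total_nll / len(token_log_probs))  # normalize per token, then exponentiate
    return ppl  # ~2.95: on average as uncertain as a choice among ~3 tokens
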
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)
        if isinstance(model.translation_model, TransformerTM):
            encoder_outputs, seq_len_x = model.translation_model.encode(
                x_in, seq_len_x, None)
            encoder_final = None
            hidden = None
        else:
            encoder_outputs, encoder_final = model.translation_model.encode(
                x_in, seq_len_x, None)
            hidden = model.translation_model.init_decoder(
                encoder_outputs, encoder_final, None)

        if hparams.decoding.sample:
            raw_hypothesis = model.translation_model.sample(
                x_in, seq_mask_x, seq_len_x, None,
                max_len=hparams.decoding.max_length, greedy=False)
        elif hparams.decoding.beam_width <= 1:
            raw_hypothesis = model.translation_model.sample(
                x_in, seq_mask_x, seq_len_x, None,
                max_len=hparams.decoding.max_length, greedy=True)
        else:
            raw_hypothesis = beam_search(
                model.translation_model.decoder, model.translation_model.tgt_embed,
                model.translation_model.generate, vocab_tgt.size(),
                hidden, encoder_outputs, encoder_final, seq_mask_x, seq_len_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN], vocab_tgt[PAD_TOKEN],
                hparams.decoding.beam_width, hparams.decoding.length_penalty_factor,
                hparams.decoding.max_length, None)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis

def _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device):
    model.eval()
    with torch.no_grad():
        num_predictions = 0
        num_sentences = 0
        log_marginal = defaultdict(float)
        total_KL = 0.
        n_samples = 10
        for sentences_x, sentences_y in val_dl:
            x_in, x_out, seq_mask_x, seq_len_x = create_batch(sentences_x, vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch(sentences_y, vocab_tgt, device)

            # Infer q(z|x) for this batch.
            qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                             y_in, seq_mask_y, seq_len_y)
            pz = model.prior()
            if isinstance(qz, ProductOfDistributions):
                total_KL += torch.cat(
                    [kl_divergence(qi, pi).sum(0).unsqueeze(-1)
                     for qi, pi in zip(qz.distributions, pz.distributions)], -1)
            else:
                total_KL += kl_divergence(qz, pz).sum(0)

            # Take s importance samples from q(z|x):
            # log int{p(x, y, z) dz} ~= log sum_z{p(x, y, z) / q(z|x)} where z ~ q(z|x)
            batch_size = x_in.size(0)
            batch_log_marginals = defaultdict(lambda: torch.zeros(n_samples, batch_size))
            for s in range(n_samples):
                # z ~ q(z|x)
                z = qz.sample()

                # Compute the likelihoods according to this sample of z.
                tm_likelihood, lm_likelihood, _, aux_lm_likelihoods, aux_tm_likelihoods = \
                    model(x_in, seq_mask_x, seq_len_x, y_in, z)

                # Compute log P(y|x, z_s)
                log_tm_prob = model.translation_model.log_prob(tm_likelihood, y_out)

                # Compute log P(x|z_s)
                log_lm_prob = model.language_model.log_prob(lm_likelihood, x_out)

                # Compute prior probability log P(z_s) and importance weight q(z_s|x)
                log_pz = pz.log_prob(z)
                log_qz = qz.log_prob(z)

                # Estimate the importance weighted estimate of (the log of) P(x, y)
                batch_log_marginals['joint/main'][s] = log_tm_prob + log_lm_prob + log_pz - log_qz
                batch_log_marginals['lm/main'][s] = log_lm_prob + log_pz - log_qz
                batch_log_marginals['tm/main'][s] = log_tm_prob + log_pz - log_qz

                for aux_comp, aux_px_z in aux_lm_likelihoods.items():
                    batch_log_marginals['lm/' + aux_comp][s] = \
                        model.log_likelihood_lm(aux_comp, aux_px_z, x_out) + log_pz - log_qz
                for aux_comp, aux_py_xz in aux_tm_likelihoods.items():
                    batch_log_marginals['tm/' + aux_comp][s] = \
                        model.log_likelihood_tm(aux_comp, aux_py_xz, y_out) + log_pz - log_qz

            for comp_name, log_marginals in batch_log_marginals.items():
                # Average over all samples.
                batch_avg = torch.logsumexp(log_marginals, dim=0) - \
                    torch.log(torch.Tensor([n_samples]))
                log_marginal[comp_name] = log_marginal[comp_name] + batch_avg.sum().item()

            num_sentences += batch_size
            num_predictions += (seq_len_x.sum() + seq_len_y.sum()).item()

    val_NLL = -log_marginal['joint/main']
    val_perplexity = np.exp(val_NLL / num_predictions)
    NLLs = {comp_name: -value / num_sentences for comp_name, value in log_marginal.items()}
    return val_perplexity, total_KL / num_sentences, NLLs

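# Small numerical sketch, not from the original code base: the per-component
# average above, log (1/S) sum_s exp(log w_s), is computed in log space as
# logsumexp(log_w) - log S. Working in log space avoids the underflow a naive
# mean of exp(log_w) would hit for the very negative log-weights typical of
# long sentences.
def _demo_logsumexp_average():
    log_w = torch.tensor([[-1000.0], [-1001.0], [-1002.0]])  # [n_samples=3, batch=1]
    n_samples = log_w.size(0)
    avg = torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(n_samples)))
    # Naive alternative: torch.log(torch.exp(log_w).mean(0)) underflows to -inf here.
    return avg  # ~= -1000.69
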
def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy', summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data, batch_size=hparams.batch_size,
                        shuffle=False, num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_KL, val_NLLs = _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device)
    val_NLL = val_NLLs['joint/main']
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src, vocab_tgt,
                                                  device, hparams)

    random_idx = np.random.choice(len(inputs))

    # nll_str = ' '.join('-- validation NLL {} = {:.2f}'.format(comp_name, comp_value)
    #                    for comp_name, comp_value in sorted(val_NLLs.items()))
    nll_str = ""
    # - log P(x|z) for the various source LM decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('lm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    # - log P(y|z,x) for the various translation decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('tm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"

    kl_str = f"-- KL = {val_KL.sum():.2f}"
    if isinstance(model.prior(), ProductOfDistributions):
        for i, p in enumerate(model.prior().distributions):
            kl_str += f" -- KL{i} = {val_KL[i]:.2f}"

    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- BLEU = {val_bleu:.2f}"
          f" {kl_str}"
          f" {nll_str}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader([val_data[random_idx] for _ in range(hparams.draw_translations)],
                        batch_size=hparams.batch_size, shuffle=False, num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl, step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL.sum(), step)
        if isinstance(model.prior(), ProductOfDistributions):
            for i, _ in enumerate(model.prior().distributions):
                summary_writer.add_scalar(f"{title}/validation/KL{i}", val_KL[i], step)
        for comp_name, comp_value in val_NLLs.items():
            summary_writer.add_scalar(f"{title}/validation/NLL/{comp_name}", comp_value, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x], vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch([val_sentence_y], vocab_tgt, device)
            z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                            y_in, seq_mask_y, seq_len_y).sample()
            _, _, state, _, _ = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = state['att_weights'].squeeze().cpu().numpy()
            src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
            tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
            attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                              f"{title}/validation/attention", step)

    return {'bleu': val_bleu, 'likelihood': -val_NLL, 'nll': val_NLL, 'ppl': val_ppl}

def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy', summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data, batch_size=hparams.batch_size,
                        shuffle=False, num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL, val_KL = _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src, vocab_tgt,
                                                  device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}"
          f" -- validation KL = {val_KL:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader([val_data[random_idx] for _ in range(hparams.draw_translations)],
                        batch_size=hparams.batch_size, shuffle=False, num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/NLL", val_NLL, step)
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl, step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x], vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt, device)
            z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x).sample()
            _, _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = att_weights.squeeze().cpu().numpy()
            src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
            tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
            attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                              f"{title}/validation/attention", step)

    return {'bleu': val_bleu, 'likelihood': -val_NLL, 'nll': val_NLL, 'ppl': val_ppl}

def _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device):
    model.eval()
    with torch.no_grad():
        num_predictions = 0
        num_sentences = 0
        log_marginal = 0.
        total_KL = 0.
        n_samples = 10
        for sentences_x, sentences_y in val_dl:
            x_in, x_out, seq_mask_x, seq_len_x = create_batch(sentences_x, vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch(sentences_y, vocab_tgt, device)

            # Infer q(z|x) for this batch.
            qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x)
            pz = model.prior().expand(qz.mean.size())
            total_KL += torch.distributions.kl.kl_divergence(qz, pz).sum().item()

            # Take s importance samples from q(z|x):
            # log int{p(x, y, z) dz} ~= log sum_z{p(x, y, z) / q(z|x)} where z ~ q(z|x)
            batch_size = x_in.size(0)
            batch_log_marginals = torch.zeros(n_samples, batch_size)
            for s in range(n_samples):
                # z ~ q(z|x)
                z = qz.sample()

                # Compute the logits according to this sample of z.
                tm_logits, lm_logits, _ = model(x_in, seq_mask_x, seq_len_x, y_in, z)

                # Compute log P(y|x, z_s)
                log_tm_prob = F.log_softmax(tm_logits, dim=-1)
                log_tm_prob = torch.gather(log_tm_prob, 2, y_out.unsqueeze(-1)).squeeze()
                log_tm_prob = (seq_mask_y.type_as(log_tm_prob) * log_tm_prob).sum(dim=1)

                # Compute log P(x|z_s)
                log_lm_prob = F.log_softmax(lm_logits, dim=-1)
                log_lm_prob = torch.gather(log_lm_prob, 2, x_out.unsqueeze(-1)).squeeze()
                log_lm_prob = (seq_mask_x.type_as(log_lm_prob) * log_lm_prob).sum(dim=1)

                # Compute prior probability log P(z_s) and importance weight q(z_s|x)
                log_pz = pz.log_prob(z).sum(dim=1)  # [B, latent_size] -> [B]
                log_qz = qz.log_prob(z).sum(dim=1)

                # Estimate the importance weighted estimate of (the log of) P(x, y)
                batch_log_marginals[s] = log_tm_prob + log_lm_prob + log_pz - log_qz

            # Average over all samples.
            batch_log_marginal = torch.logsumexp(batch_log_marginals, dim=0) - \
                torch.log(torch.Tensor([n_samples]))
            log_marginal += batch_log_marginal.sum().item()  # [B] -> []

            num_sentences += batch_size
            num_predictions += (seq_len_x.sum() + seq_len_y.sum()).item()

    val_NLL = -log_marginal
    val_perplexity = np.exp(val_NLL / num_predictions)
    return val_perplexity, val_NLL / num_sentences, total_KL / num_sentences

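# A self-contained sketch, not part of the original code base, of the
# importance-sampling estimator used by the `_evaluate_perplexity` variants
# above: log p(x) ~= logsumexp_s[log p(x, z_s) - log q(z_s|x)] - log S with
# z_s ~ q(z|x). Here p and q are toy 1-D Gaussians, so the exact marginal is
# available in closed form for comparison.
def _demo_importance_sampled_log_marginal(n_samples=1000):
    from torch.distributions import Normal
    pz = Normal(0.0, 1.0)                   # prior p(z)
    qz = Normal(0.5, 0.8)                   # an (imperfect) approximate posterior q(z|x)
    x = torch.tensor(1.0)
    z = qz.sample((n_samples,))             # z_s ~ q(z|x)
    log_px_z = Normal(z, 1.0).log_prob(x)   # likelihood p(x|z) = N(x; z, 1)
    log_w = log_px_z + pz.log_prob(z) - qz.log_prob(z)
    estimate = torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(n_samples)))
    # Exact marginal: x ~ N(0, 1 + 1), so log p(x) = log N(x; 0, sqrt(2)).
    exact = Normal(0.0, 2.0 ** 0.5).log_prob(x)
    return estimate.item(), exact.item()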