def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, summary_writer=None):
    """Evaluate ``model`` on ``val_data`` and report validation metrics.

    Runs a bucketed, unshuffled pass over the validation set to compute
    perplexity/NLL (``_evaluate_perplexity``) and BLEU (``_evaluate_bleu``),
    then prints those metrics together with one randomly chosen
    source/reference/prediction triple. When ``summary_writer`` is given, the
    scalar metrics and the attention weights of the first validation sentence
    are also logged.

    :param model: model to evaluate; it is switched to eval mode.
    :param val_data: indexable dataset of (source, target) sentence pairs.
    :param vocab_src: source vocabulary.
    :param vocab_tgt: target vocabulary.
    :param device: device the batches are created on.
    :param hparams: hyperparameters; ``batch_size`` is read here and the whole
        object is forwarded to ``_evaluate_bleu``.
    :param step: global step used to tag the summaries.
    :param summary_writer: optional summary writer; None disables summaries.
    :return: dict with keys 'bleu', 'likelihood' (the negated NLL), 'nll'
        and 'ppl'.
    """
    model.eval()

    # Order does not matter for validation, so a deterministic bucketed pass suffices.
    loader = BucketingParallelDataLoader(
        DataLoader(val_data, batch_size=hparams.batch_size, shuffle=False, num_workers=4))

    ppl, nll = _evaluate_perplexity(model, loader, vocab_src, vocab_tgt, device)
    bleu, inputs, refs, hyps = _evaluate_bleu(model, loader, vocab_src, vocab_tgt,
                                              device, hparams)

    # Show one random example so translation quality can be eyeballed in the log.
    pick = np.random.choice(len(inputs))
    print(f"validation perplexity = {ppl:,.2f}"
          f" -- validation NLL = {nll:,.2f}"
          f" -- validation BLEU = {bleu:.2f}\n"
          f"- Source: {inputs[pick]}\n"
          f"- Target: {refs[pick]}\n"
          f"- Prediction: {hyps[pick]}")

    metrics = {'bleu': bleu, 'likelihood': -nll, 'nll': nll, 'ppl': ppl}
    if summary_writer is None:
        return metrics

    scalars = (
        ("validation/NLL", nll),
        ("validation/BLEU", bleu),
        ("validation/perplexity", ppl),
    )
    for tag, value in scalars:
        summary_writer.add_scalar(tag, value, step)

    # Attention weights of the first validation sentence.
    with torch.no_grad():
        src_sentence, tgt_sentence = val_data[0]
        x_in, _, seq_mask_x, seq_len_x = create_batch([src_sentence], vocab_src, device)
        y_in, y_out, _, _ = create_batch([tgt_sentence], vocab_tgt, device)
        _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in)
        att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
        # NOTE(review): labels are passed target-first here; confirm this matches
        # the (rows, cols) layout attention_summary expects for this model's weights.
        attention_summary(tgt_labels, src_labels, att_weights, summary_writer,
                          "validation/attention", step)

    return metrics
def train(model, optimizers, lr_schedulers, training_data, val_data, vocab_src, vocab_tgt,
          device, out_dir, train_step, validate, hparams):
    """Train ``model`` on bucketed mini-batches with periodic validation.

    Training runs for at least ``hparams.num_epochs`` epochs and continues
    afterwards for as long as ``ckpt.no_improvement(hparams.criterion)`` stays
    below ``hparams.patience``. Models are checkpointed under
    ``out_dir / "model"`` with respect to BLEU and likelihood. At the end the
    best model for ``hparams.criterion`` is reloaded and validated once more
    without writing summaries.

    :param model: the model to train.
    :param optimizers: dict of optimizers. "gen" is required; "inf_z" and
        "lagrangian" are optional. The "lagrangian" group is maximized: its
        gradients are negated before its step.
    :param lr_schedulers: forwarded to ``lr_scheduler_step``.
    :param training_data: dataset of (source, target) sentence pairs.
    :param val_data: validation data, forwarded to ``validate``.
    :param vocab_src: source vocabulary.
    :param vocab_tgt: target vocabulary.
    :param device: device the batches are created on.
    :param out_dir: path-like output directory (supports ``/``), used for the
        TensorBoard logs and the checkpoints.
    :param train_step: function that performs a single training step. It is
        called as ``train_step(model, x_in, x_out, seq_mask_x, seq_len_x,
        noisy_x_in, y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in, hparams,
        step, summary_writer=summary_writer)``. It must return a dict with at
        least "loss" (a scalar tensor to backpropagate) and "raw_KL". Entries
        whose key starts with "lm/" or "tm/" are printed (negated mean) with
        the periodic training stats.
    :param validate: function that performs validation. It is called as
        ``validate(model, val_data, vocab_src, vocab_tgt, device, hparams,
        step, summary_writer=...)``, where summary_writer is None if no
        summaries should be made. It must return a dict of metrics containing
        at least "bleu", "likelihood" and ``hparams.criterion``. It should
        perform all evaluation, write summaries and print validation metrics.
    :param hparams: hyperparameters (batch_size, num_epochs, patience,
        criterion, word_dropout, max_gradient_norm, print_every,
        evaluate_every, src, tgt, ...).
    """
    # Create a dataloader that buckets the batches.
    dl = DataLoader(training_data, batch_size=hparams.batch_size, shuffle=True, num_workers=4)
    bucketing_dl = BucketingParallelDataLoader(dl)

    # Checkpoint the model with respect to development BLEU and likelihood.
    ckpt = CheckPoint(model_dir=out_dir / "model", metrics=['bleu', 'likelihood'])

    # Keep track of some stuff in TensorBoard.
    summary_writer = SummaryWriter(log_dir=str(out_dir))

    # Training statistics; the token/loss counters are reset after every print.
    tokens_start = time.time()
    num_tokens = 0
    total_train_loss = 0.
    num_sentences = 0
    step = 0
    epoch_num = 1

    def run_evaluation():
        # Validate, step the LR scheduler on the selection criterion, and hand
        # the BLEU/likelihood scores to the checkpointer.
        model.eval()
        metrics = validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step,
                           summary_writer=summary_writer)
        lr_scheduler_step(lr_schedulers, hparams, val_score=metrics[hparams.criterion])
        ckpt.update(
            epoch_num, step, {f"{hparams.src}-{hparams.tgt}": model},
            # we save with respect to BLEU and likelihood
            bleu=metrics['bleu'], likelihood=metrics['likelihood'])

    try:
        # At least num_epochs epochs, then continue until patience runs out.
        while (epoch_num <= hparams.num_epochs) or (
                ckpt.no_improvement(hparams.criterion) < hparams.patience):
            # Train for 1 epoch.
            for sentences_x, sentences_y in bucketing_dl:
                # run_evaluation() switches to eval mode, so re-enable training each batch.
                model.train()

                # Perform a forward pass through the model.
                x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in = create_noisy_batch(
                    sentences_x, vocab_src, device, word_dropout=hparams.word_dropout)
                y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in = create_noisy_batch(
                    sentences_y, vocab_tgt, device, word_dropout=hparams.word_dropout)
                return_dict = train_step(model, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
                                         y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in,
                                         hparams, step, summary_writer=summary_writer)
                loss = return_dict["loss"]

                # Backpropagate and update parameters.
                loss.backward()
                if hparams.max_gradient_norm > 0:
                    # TODO: do we need separate norms?
                    nn.utils.clip_grad_norm_(model.parameters(), hparams.max_gradient_norm)
                optimizers["gen"].step()
                if "inf_z" in optimizers:
                    optimizers["inf_z"].step()
                if "lagrangian" in optimizers:
                    # This group is maximized rather than minimized, so flip the
                    # gradients before the (descent) step. Parameters that received
                    # no gradient have grad None and must be skipped, otherwise the
                    # negation raises a TypeError.
                    for group in optimizers["lagrangian"].param_groups:
                        for p in group["params"]:
                            if p.grad is not None:
                                p.grad.neg_()
                    optimizers["lagrangian"].step()

                # Update statistics.
                num_tokens += (seq_len_x.sum() + seq_len_y.sum()).item()
                num_sentences += x_in.size(0)
                total_train_loss += loss.item() * x_in.size(0)

                # Print training stats every now and again.
                if step % hparams.print_every == 0:
                    elapsed = time.time() - tokens_start
                    tokens_per_sec = num_tokens / elapsed if step != 0 else 0
                    # use skip_null=False if you prefer exceptions for null grads
                    grad_norm = gradient_norm(model, skip_null=True)
                    displaying = f"raw KL = {return_dict['raw_KL'].mean().item():,.2f}"
                    # -log P(x|z) for the source LM decoders first, then
                    # -log P(y|z,x) for the translation decoders.
                    for prefix in ('lm/', 'tm/'):
                        for comp_name, comp_value in sorted(return_dict.items()):
                            if comp_name.startswith(prefix):
                                displaying += f" -- {comp_name} = {-comp_value.mean().item():,.2f}"
                    print(
                        f"({epoch_num}) step {step}: "
                        f"training loss = {total_train_loss/num_sentences:,.2f} -- "
                        f"{displaying} -- "
                        f"{tokens_per_sec:,.0f} tokens/s -- "
                        f"gradient norm = {grad_norm:.2f}")
                    summary_writer.add_scalar("train/loss", total_train_loss / num_sentences, step)
                    num_tokens = 0
                    tokens_start = time.time()
                    total_train_loss = 0.
                    num_sentences = 0

                # Zero the gradient buffers.
                optimizers["gen"].zero_grad()
                if "inf_z" in optimizers:
                    optimizers["inf_z"].zero_grad()
                if "lagrangian" in optimizers:
                    optimizers["lagrangian"].zero_grad()

                # Update the learning rate scheduler if needed.
                lr_scheduler_step(lr_schedulers, hparams)

                # Run evaluation every evaluate_every steps if set.
                if hparams.evaluate_every > 0 and step > 0 and step % hparams.evaluate_every == 0:
                    run_evaluation()

                step += 1

            print(f"Finished epoch {epoch_num}")

            # If evaluate_every is not set, we evaluate after every epoch.
            if hparams.evaluate_every <= 0:
                run_evaluation()

            epoch_num += 1

        print("Finished training.")
    finally:
        # Close the TensorBoard writer even if training raises.
        summary_writer.close()

    # Load the best model and run validation again, making sure not to write summaries.
    best_model_info = ckpt.load_best({f"{hparams.src}-{hparams.tgt}": model}, hparams.criterion)
    print(f"Loaded best model (wrt {hparams.criterion}) found at step "
          f"{best_model_info['step']} (epoch {best_model_info['epoch']}).")
    model.eval()
    validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, summary_writer=None)
def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy',
             summary_writer=None):
    """Evaluate ``model`` on ``val_data`` for one translation direction.

    Prints:
      * perplexity and BLEU;
      * the total KL, plus one KL per component when the prior is a
        ``ProductOfDistributions``;
      * the NLL of every decoder whose name starts with "lm/" or "tm/";
      * one random source/reference/prediction triple.

    If ``hparams.draw_translations > 0`` it also prints the output of
    ``_draw_translations`` for one random validation pair repeated that many
    times. When ``summary_writer`` is given, the scalar metrics and the
    attention weights of the first validation sentence are logged under
    ``{title}/validation/``.

    :param title: direction label used in the printout and summary tags.
    :param summary_writer: optional summary writer; None disables summaries.
    :return: dict with 'bleu', 'likelihood' (negated "joint/main" NLL), 'nll'
        and 'ppl'.
    """
    model.eval()

    # Deterministic bucketed pass over the validation data.
    loader = BucketingParallelDataLoader(
        DataLoader(val_data, batch_size=hparams.batch_size, shuffle=False, num_workers=4))

    ppl, kl, nlls = _evaluate_perplexity(model, loader, vocab_src, vocab_tgt, device)
    main_nll = nlls['joint/main']
    bleu, inputs, refs, hyps = _evaluate_bleu(model, loader, vocab_src, vocab_tgt,
                                              device, hparams)
    pick = np.random.choice(len(inputs))

    # Decoder NLLs: source LMs ("lm/...") first, then translation models
    # ("tm/..."), each group in sorted order.
    nll_str = "".join(
        f" -- {name} = {value:,.2f}"
        for prefix in ('lm/', 'tm/')
        for name, value in sorted(nlls.items())
        if name.startswith(prefix))

    kl_str = f"-- KL = {kl.sum():.2f}"
    if isinstance(model.prior(), ProductOfDistributions):
        kl_str += "".join(f" -- KL{idx} = {kl[idx]:.2f}"
                          for idx, _ in enumerate(model.prior().distributions))

    print(f"direction = {title}\n"
          f"validation perplexity = {ppl:,.2f}"
          f" -- BLEU = {bleu:.2f}"
          f" {kl_str}"
          f" {nll_str}\n"
          f"- Source: {inputs[pick]}\n"
          f"- Target: {refs[pick]}\n"
          f"- Prediction: {hyps[pick]}")

    if hparams.draw_translations > 0:
        # Repeat one random validation pair draw_translations times; presumably
        # _draw_translations yields one sampled translation per copy.
        pick = np.random.choice(len(inputs))
        repeated = [val_data[pick] for _ in range(hparams.draw_translations)]
        sample_loader = BucketingParallelDataLoader(
            DataLoader(repeated, batch_size=hparams.batch_size, shuffle=False, num_workers=4))
        sample_inputs, sample_refs, sample_hyps = _draw_translations(
            model, sample_loader, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {sample_inputs[0]}")
        print(f"- Reference: {sample_refs[0]}")
        for hyp in sample_hyps:
            print(f"- Translation: {hyp}")

    metrics = {'bleu': bleu, 'likelihood': -main_nll, 'nll': main_nll, 'ppl': ppl}
    if summary_writer is None:
        return metrics

    # Write validation summaries.
    summary_writer.add_scalar(f"{title}/validation/BLEU", bleu, step)
    summary_writer.add_scalar(f"{title}/validation/perplexity", ppl, step)
    summary_writer.add_scalar(f"{title}/validation/KL", kl.sum(), step)
    if isinstance(model.prior(), ProductOfDistributions):
        for idx, _ in enumerate(model.prior().distributions):
            summary_writer.add_scalar(f"{title}/validation/KL{idx}", kl[idx], step)
    for name, value in nlls.items():
        summary_writer.add_scalar(f"{title}/validation/NLL/{name}", value, step)

    # Attention weights of the first validation sentence, decoded with a z
    # sampled from the approximate posterior.
    with torch.no_grad():
        src_sentence, tgt_sentence = val_data[0]
        x_in, _, seq_mask_x, seq_len_x = create_batch([src_sentence], vocab_src, device)
        y_in, y_out, seq_mask_y, seq_len_y = create_batch([tgt_sentence], vocab_tgt, device)
        z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                        y_in, seq_mask_y, seq_len_y).sample()
        _, _, state, _, _ = model(x_in, seq_mask_x, seq_len_x, y_in, z)
        att_weights = state['att_weights'].squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return metrics
def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy',
             summary_writer=None):
    """Evaluate ``model`` on ``val_data`` for one translation direction.

    Prints perplexity, NLL, BLEU and KL (from ``_evaluate_perplexity`` and
    ``_evaluate_bleu``) followed by one random source/reference/prediction
    triple. If ``hparams.draw_translations > 0`` it also prints the output of
    ``_draw_translations`` for one random validation pair repeated that many
    times. When ``summary_writer`` is given, the scalar metrics and the
    attention weights of the first validation sentence are logged under
    ``{title}/validation/``.

    :param title: direction label used in the printout and summary tags.
    :param summary_writer: optional summary writer; None disables summaries.
    :return: dict with 'bleu', 'likelihood' (the negated NLL), 'nll' and 'ppl'.
    """
    model.eval()

    # Deterministic bucketed pass over the validation data.
    loader = BucketingParallelDataLoader(
        DataLoader(val_data, batch_size=hparams.batch_size, shuffle=False, num_workers=4))

    ppl, nll, kl = _evaluate_perplexity(model, loader, vocab_src, vocab_tgt, device)
    bleu, inputs, refs, hyps = _evaluate_bleu(model, loader, vocab_src, vocab_tgt,
                                              device, hparams)

    pick = np.random.choice(len(inputs))
    print(f"direction = {title}\n"
          f"validation perplexity = {ppl:,.2f}"
          f" -- validation NLL = {nll:,.2f}"
          f" -- validation BLEU = {bleu:.2f}"
          f" -- validation KL = {kl:.2f}\n"
          f"- Source: {inputs[pick]}\n"
          f"- Target: {refs[pick]}\n"
          f"- Prediction: {hyps[pick]}")

    if hparams.draw_translations > 0:
        # Repeat one random validation pair draw_translations times; presumably
        # _draw_translations yields one sampled translation per copy.
        pick = np.random.choice(len(inputs))
        repeated = [val_data[pick] for _ in range(hparams.draw_translations)]
        sample_loader = BucketingParallelDataLoader(
            DataLoader(repeated, batch_size=hparams.batch_size, shuffle=False, num_workers=4))
        sample_inputs, sample_refs, sample_hyps = _draw_translations(
            model, sample_loader, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {sample_inputs[0]}")
        print(f"- Reference: {sample_refs[0]}")
        for hyp in sample_hyps:
            print(f"- Translation: {hyp}")

    metrics = {'bleu': bleu, 'likelihood': -nll, 'nll': nll, 'ppl': ppl}
    if summary_writer is None:
        return metrics

    # Write validation summaries.
    scalars = (
        ("NLL", nll),
        ("BLEU", bleu),
        ("perplexity", ppl),
        ("KL", kl),
    )
    for metric_name, value in scalars:
        summary_writer.add_scalar(f"{title}/validation/{metric_name}", value, step)

    # Attention weights of the first validation sentence, decoded with a z
    # sampled from the approximate posterior conditioned on the source only.
    with torch.no_grad():
        src_sentence, tgt_sentence = val_data[0]
        x_in, _, seq_mask_x, seq_len_x = create_batch([src_sentence], vocab_src, device)
        y_in, y_out, _, _ = create_batch([tgt_sentence], vocab_tgt, device)
        z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x).sample()
        _, _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in, z)
        att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return metrics