Code example #1
def validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data,
                        batch_size=hparams.batch_size,
                        shuffle=False,
                        num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL = _evaluate_perplexity(model, val_dl, vocab_src,
                                            vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src,
                                                  vocab_tgt, device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar("validation/NLL", val_NLL, step)
        summary_writer.add_scalar("validation/BLEU", val_bleu, step)
        summary_writer.add_scalar("validation/perplexity", val_ppl, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x],
                                                          vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt,
                                             device)
            _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in)
            att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src,
                                        no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt,
                                        no_filter=True)[0].split()
        attention_summary(tgt_labels, src_labels, att_weights, summary_writer,
                          "validation/attention", step)

    return {
        'bleu': val_bleu,
        'likelihood': -val_NLL,
        'nll': val_NLL,
        'ppl': val_ppl
    }
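
The helpers _evaluate_perplexity and _evaluate_bleu are not part of this excerpt. As a rough, self-contained illustration of the first quantity reported above, perplexity is conventionally the exponential of the average per-token negative log-likelihood; the token NLL values below are invented purely for the sketch and are not from the project.

import math

# Hypothetical per-token negative log-likelihoods on a tiny validation set.
token_nlls = [2.3, 1.7, 3.1, 0.9, 2.2]

total_nll = sum(token_nlls)            # summed validation NLL
avg_nll = total_nll / len(token_nlls)  # average NLL per token
ppl = math.exp(avg_nll)                # perplexity = exp(mean token NLL)

print(f"validation NLL = {total_nll:,.2f} -- validation perplexity = {ppl:,.2f}")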
Code example #2
File: train.py  Project: Martin-top/AEVNMT.pt
def train(model, optimizers, lr_schedulers, training_data, val_data, vocab_src,
          vocab_tgt, device, out_dir, train_step, validate, hparams):
    """
    :param train_step: function that performs a single training step and returns
                       training loss. Takes as inputs: model, x_in, x_out,
                       seq_mask_x, seq_len_x, y_in, y_out, seq_mask_y,
                       seq_len_y, hparams, step.
    :param validate: function that performs validation and returns validation
                     BLEU, used for model selection. Takes as inputs: model,
                     val_data, vocab, device, hparams, step, summary_writer.
                     summary_writer can be None if no summaries should be made.
                     This function should perform all evaluation, write
                     summaries and write any validation metrics to the
                     standard out.
    """

    # Create a dataloader that buckets the batches.
    dl = DataLoader(training_data,
                    batch_size=hparams.batch_size,
                    shuffle=True,
                    num_workers=4)
    bucketing_dl = BucketingParallelDataLoader(dl)

    # Save the best model based on development BLEU.
    ckpt = CheckPoint(model_dir=out_dir / "model",
                      metrics=['bleu', 'likelihood'])

    # Keep track of some stuff in TensorBoard.
    summary_writer = SummaryWriter(log_dir=str(out_dir))

    # Define training statistics to keep track of.
    tokens_start = time.time()
    num_tokens = 0
    total_train_loss = 0.
    num_sentences = 0
    step = 0
    epoch_num = 1

    # Define the evaluation function.
    def run_evaluation():
        # Perform model validation, keep track of validation BLEU for model
        # selection.
        model.eval()
        metrics = validate(model,
                           val_data,
                           vocab_src,
                           vocab_tgt,
                           device,
                           hparams,
                           step,
                           summary_writer=summary_writer)

        # Update the learning rate scheduler.
        lr_scheduler_step(lr_schedulers,
                          hparams,
                          val_score=metrics[hparams.criterion])

        ckpt.update(
            epoch_num,
            step,
            {f"{hparams.src}-{hparams.tgt}": model},
            # we save with respect to BLEU and likelihood
            bleu=metrics['bleu'],
            likelihood=metrics['likelihood'])

    # Start the training loop.
    while (epoch_num <= hparams.num_epochs) or (ckpt.no_improvement(
            hparams.criterion) < hparams.patience):

        # Train for 1 epoch.
        for sentences_x, sentences_y in bucketing_dl:
            model.train()

            # Perform a forward pass through the model
            x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in = create_noisy_batch(
                sentences_x,
                vocab_src,
                device,
                word_dropout=hparams.word_dropout)
            y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in = create_noisy_batch(
                sentences_y,
                vocab_tgt,
                device,
                word_dropout=hparams.word_dropout)
            return_dict = train_step(model,
                                     x_in,
                                     x_out,
                                     seq_mask_x,
                                     seq_len_x,
                                     noisy_x_in,
                                     y_in,
                                     y_out,
                                     seq_mask_y,
                                     seq_len_y,
                                     noisy_y_in,
                                     hparams,
                                     step,
                                     summary_writer=summary_writer)
            loss = return_dict["loss"]

            # Backpropagate and update gradients.
            loss.backward()
            if hparams.max_gradient_norm > 0:
                # TODO: do we need separate norms?
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hparams.max_gradient_norm)
            optimizers["gen"].step()
            if "inf_z" in optimizers: optimizers["inf_z"].step()
            if "lagrangian" in optimizers:
                # We are doing maximization for this parameter group rather than minimization. Thus we
                # invert the direction of the gradients.
                for group in optimizers["lagrangian"].param_groups:
                    for p in group["params"]:
                        p.grad = -1 * p.grad
                optimizers["lagrangian"].step()

            # Update statistics.
            num_tokens += (seq_len_x.sum() + seq_len_y.sum()).item()
            num_sentences += x_in.size(0)
            total_train_loss += loss.item() * x_in.size(0)

            # Print training stats every now and again.
            if step % hparams.print_every == 0:
                elapsed = time.time() - tokens_start
                tokens_per_sec = num_tokens / elapsed if step != 0 else 0
                grad_norm = gradient_norm(
                    model, skip_null=True
                )  # use False if you prefer exceptions for null grad

                displaying = f"raw KL = {return_dict['raw_KL'].mean().item():,.2f}"
                # - log P(x|z) for the various source LM decoders
                for comp_name, comp_value in sorted(return_dict.items()):
                    if comp_name.startswith('lm/'):
                        displaying += f" -- {comp_name} = {-comp_value.mean().item():,.2f}"
                # - log P(y|z,x) for the various translation decoders
                for comp_name, comp_value in sorted(return_dict.items()):
                    if comp_name.startswith('tm/'):
                        displaying += f" -- {comp_name} = {-comp_value.mean().item():,.2f}"
                print(
                    f"({epoch_num}) step {step}: "
                    f"training loss = {total_train_loss/num_sentences:,.2f} -- "
                    f"{displaying} -- "
                    f"{tokens_per_sec:,.0f} tokens/s -- "
                    f"gradient norm = {grad_norm:.2f}")
                summary_writer.add_scalar("train/loss",
                                          total_train_loss / num_sentences,
                                          step)
                num_tokens = 0
                tokens_start = time.time()
                total_train_loss = 0.
                num_sentences = 0

            # Zero the gradient buffer.
            optimizers["gen"].zero_grad()
            if "inf_z" in optimizers: optimizers["inf_z"].zero_grad()
            if "lagrangian" in optimizers: optimizers["lagrangian"].zero_grad()

            # Update the learning rate scheduler if needed.
            lr_scheduler_step(lr_schedulers, hparams)

            # Run evaluation every evaluate_every steps if set.
            if hparams.evaluate_every > 0 and step > 0 and step % hparams.evaluate_every == 0:
                run_evaluation()

            step += 1

        print(f"Finished epoch {epoch_num}")

        # If evaluate_every is not set, we evaluate after every epoch.
        if hparams.evaluate_every <= 0:
            run_evaluation()

        epoch_num += 1

    print(f"Finished training.")
    summary_writer.close()

    # Load the best model and run validation again, make sure to not write
    # summaries.
    best_model_info = ckpt.load_best({f"{hparams.src}-{hparams.tgt}": model},
                                     hparams.criterion)
    print(
        f"Loaded best model (wrt {hparams.criterion}) found at step {best_model_info['step']} (epoch {best_model_info['epoch']})."
    )
    model.eval()
    validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             summary_writer=None)
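
The train() loop treats train_step as a pluggable callable: it is invoked with the clean and noisy batches plus hparams, step, and a summary_writer keyword, and the loop reads the keys 'loss' (for backpropagation and logging) and 'raw_KL' (for printing) from the returned dict. A minimal stub that satisfies this interface, purely as a sketch of the expected contract and not the project's actual implementation (the name dummy_train_step is made up):

import torch

def dummy_train_step(model, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
                     y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in,
                     hparams, step, summary_writer=None):
    # Placeholder loss: a differentiable scalar so that loss.backward() works.
    loss = torch.zeros(1, requires_grad=True).sum()
    # The loop also prints any 'lm/...' and 'tm/...' entries if present.
    return {"loss": loss, "raw_KL": torch.zeros(1)}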
Code example #3
def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy', summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data, batch_size=hparams.batch_size,
                        shuffle=False, num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_KL, val_NLLs = _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device)
    val_NLL = val_NLLs['joint/main']
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src, vocab_tgt,
                                                  device, hparams)

    random_idx = np.random.choice(len(inputs))
    #nll_str = ' '.join('-- validation NLL {} = {:.2f}'.format(comp_name, comp_value)  for comp_name, comp_value in sorted(val_NLLs.items()))
    nll_str = ""
    # - log P(x|z) for the various source LM decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('lm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    # - log P(y|z,x) for the various translation decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('tm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    
    kl_str = f"-- KL = {val_KL.sum():.2f}"
    if isinstance(model.prior(), ProductOfDistributions):
        for i, p in enumerate(model.prior().distributions):
            kl_str += f" -- KL{i} = {val_KL[i]:.2f}" 
        
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- BLEU = {val_bleu:.2f}"
          f" {kl_str}"
          f" {nll_str}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader([val_data[random_idx] for _ in range(hparams.draw_translations)], batch_size=hparams.batch_size, shuffle=False, num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")
    
    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl, step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL.sum(), step)
        if isinstance(model.prior(), ProductOfDistributions):
            for i, _ in enumerate(model.prior().distributions):
                summary_writer.add_scalar(f"{title}/validation/KL{i}", val_KL[i], step)
        for comp_name, comp_value in val_NLLs.items():
            summary_writer.add_scalar(f"{title}/validation/NLL/{comp_name}", comp_value, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x], vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch([val_sentence_y], vocab_tgt, device)
            z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x, y_in, seq_mask_y, seq_len_y).sample()
            _, _, state, _, _ = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = state['att_weights'].squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return {'bleu': val_bleu, 'likelihood': -val_NLL, 'nll': val_NLL, 'ppl': val_ppl}
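
attention_summary is another project helper that does not appear in these excerpts. Assuming it renders an attention heat map into TensorBoard, a self-contained sketch with matplotlib and SummaryWriter.add_figure could look like the following; the function name, tag, and the shape convention (target tokens on rows, source tokens on columns) are assumptions, not the project's code.

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

def attention_heatmap(tgt_labels, src_labels, att_weights, writer, tag, step):
    # att_weights is assumed to have shape [len(tgt_labels), len(src_labels)].
    fig, ax = plt.subplots()
    ax.imshow(att_weights, aspect="auto")
    ax.set_xticks(np.arange(len(src_labels)))
    ax.set_xticklabels(src_labels, rotation=90)
    ax.set_yticks(np.arange(len(tgt_labels)))
    ax.set_yticklabels(tgt_labels)
    writer.add_figure(tag, fig, global_step=step)
    plt.close(fig)

# Demo call with random weights; a real call would pass the tokenised
# source/target sentences and the model's attention matrix.
writer = SummaryWriter(log_dir="runs/attention-demo")
attention_heatmap(["the", "cat"], ["le", "chat"],
                  np.random.rand(2, 2), writer, "validation/attention", step=0)
writer.close()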
Code example #4
File: aevnmt_helper.py  Project: wilkeraziz/AEVNMT.pt
def validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             title='xy',
             summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data,
                        batch_size=hparams.batch_size,
                        shuffle=False,
                        num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL, val_KL = _evaluate_perplexity(model, val_dl, vocab_src,
                                                    vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src,
                                                  vocab_tgt, device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}"
          f" -- validation KL = {val_KL:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader(
            [val_data[random_idx] for _ in range(hparams.draw_translations)],
            batch_size=hparams.batch_size,
            shuffle=False,
            num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device,
                                      hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/NLL", val_NLL, step)
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl,
                                  step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x],
                                                          vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt,
                                             device)
            z = model.approximate_posterior(x_in, seq_mask_x,
                                            seq_len_x).sample()
            _, _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src,
                                        no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt,
                                        no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return {
        'bleu': val_bleu,
        'likelihood': -val_NLL,
        'nll': val_NLL,
        'ppl': val_ppl
    }
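
All four validate variants return the same metrics dict ('bleu', 'likelihood', 'nll', 'ppl'), and the training loop in code example #2 feeds 'bleu' and 'likelihood' into CheckPoint.update for model selection. CheckPoint itself is not shown in these excerpts; a simplified sketch of picking the best step by a chosen criterion, using invented metric values:

# Hypothetical validation results collected over training (values invented).
history = [
    {"step": 1000, "bleu": 18.4, "likelihood": -95.2},
    {"step": 2000, "bleu": 21.7, "likelihood": -91.8},
    {"step": 3000, "bleu": 21.1, "likelihood": -90.5},
]

criterion = "bleu"  # e.g. hparams.criterion; 'likelihood' (= -NLL) also works
best = max(history, key=lambda m: m[criterion])
print(f"best {criterion} = {best[criterion]:.2f} at step {best['step']}")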