Code example #1
def validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data,
                        batch_size=hparams.batch_size,
                        shuffle=False,
                        num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL = _evaluate_perplexity(model, val_dl, vocab_src,
                                            vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src,
                                                  vocab_tgt, device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar("validation/NLL", val_NLL, step)
        summary_writer.add_scalar("validation/BLEU", val_bleu, step)
        summary_writer.add_scalar("validation/perplexity", val_ppl, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x],
                                                          vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt,
                                             device)
            _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in)
            att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src,
                                        no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt,
                                        no_filter=True)[0].split()
        attention_summary(tgt_labels, src_labels, att_weights, summary_writer,
                          "validation/attention", step)

    return {
        'bleu': val_bleu,
        'likelihood': -val_NLL,
        'nll': val_NLL,
        'ppl': val_ppl
    }
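
A minimal sketch of how this validate() helper might be driven from a training loop, assuming model, val_data, the vocabularies, device, hparams and step already exist; the SummaryWriter log directory, the checkpoint filename, and the use of BLEU for model selection are assumptions for illustration, not taken from the example above.

# Hypothetical call site; "runs/validation" and "best_model.pt" are assumed names.
from torch.utils.tensorboard import SummaryWriter

summary_writer = SummaryWriter("runs/validation")
best_bleu = float("-inf")

scores = validate(model, val_data, vocab_src, vocab_tgt, device, hparams,
                  step, summary_writer=summary_writer)
if scores['bleu'] > best_bleu:
    best_bleu = scores['bleu']
    torch.save(model.state_dict(), "best_model.pt")
model.train()  # validate() leaves the model in eval mode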
Code example #2
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences,
                                                      vocab_src, device)
        encoder_outputs, encoder_final = model.encode(x_in, seq_len_x)
        hidden = model.init_decoder(encoder_outputs, encoder_final)
        if hparams.sample_decoding:
            raw_hypothesis = sampling_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length)
        elif hparams.beam_width <= 1:
            raw_hypothesis = greedy_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length)
        else:
            raw_hypothesis = beam_search(
                model.decoder, model.tgt_embed, model.generate,
                vocab_tgt.size(), hidden, encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor, hparams.max_decoding_length)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
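
A brief usage sketch for the translate() helper above, assuming model, the vocabularies and hparams are already set up; the example sentences and the device selection are illustrative.

# Hypothetical driver; input sentences are assumed to be pre-tokenized strings,
# matching what create_batch() expects.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sentences = ["ein kleiner Test .", "noch ein Beispiel ."]
hypotheses = translate(model, sentences, vocab_src, vocab_tgt, device, hparams)
for src, hyp in zip(sentences, hypotheses):
    print(f"{src}  ->  {hyp}")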
Code example #3
File: aevnmt_helper.py  Project: wilkeraziz/AEVNMT.pt
def translate(model,
              input_sentences,
              vocab_src,
              vocab_tgt,
              device,
              hparams,
              deterministic=True):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences,
                                                      vocab_src, device)

        # For translation we use the approximate posterior mean.
        qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x)
        z = qz.mean if deterministic else qz.sample()

        encoder_outputs, encoder_final = model.encode(x_in, seq_len_x, z)
        hidden = model.init_decoder(encoder_outputs, encoder_final, z)
        if hparams.beam_width <= 1:
            raw_hypothesis = greedy_decode(
                model.decoder, model.tgt_embed, model.generate, hidden,
                encoder_outputs, encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length)
        else:
            raw_hypothesis = beam_search(
                model.decoder, model.tgt_embed, model.generate,
                vocab_tgt.size(), hidden, encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor, hparams.max_decoding_length)
    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
Code example #4
def generate_senvae(model,
                    input_sentences,
                    num_samples,
                    vocab_src,
                    device,
                    hparams,
                    deterministic=True):

    model.eval()
    with torch.no_grad():

        if input_sentences is not None:
            x_in, _, seq_mask_x, seq_len_x = create_batch(
                input_sentences, vocab_src, device)
            qz = model.approximate_posterior(x_in,
                                             seq_mask_x,
                                             seq_len_x,
                                             y=x_in,
                                             seq_mask_y=seq_mask_x,
                                             seq_len_y=seq_len_x)
        else:
            qz = model.prior().expand((num_samples, ))

        # TODO: restore some form of deterministic decoding
        #z = qz.mean if deterministic else qz.sample()
        z = qz.sample()

        if isinstance(model.language_model, TransformerLM):
            hidden = None
        else:
            hidden = model.language_model.init(z)

        if hparams.decoding.sample:
            raw_hypothesis = model.language_model.sample(
                z, max_len=hparams.decoding.max_length, greedy=False)

        elif hparams.decoding.beam_width <= 1:
            raw_hypothesis = model.language_model.sample(
                z, max_len=hparams.decoding.max_length, greedy=True)

        else:
            raise NotImplementedError
            """
            raw_hypothesis = beam_search(
                model.language_model,
                model.language_model.embedder,
                model.language_model.generate,
                vocab_src.size(), None, None,
                None, None, None,
                vocab_src[SOS_TOKEN], vocab_src[EOS_TOKEN],
                vocab_src[PAD_TOKEN], hparams.decoding.beam_width,
                hparams.decoding.length_penalty_factor,
                hparams.decoding.max_length,
                z)
            """

    hypothesis = batch_to_sentences(raw_hypothesis, vocab_src)
    return hypothesis
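
A hedged sketch of the two calling modes of generate_senvae() above: reconstruction from given input sentences versus unconditional sampling from the prior; the example sentence and sample count are illustrative.

# With input sentences, z is sampled from the approximate posterior q(z|x).
reconstructions = generate_senvae(model, ["a small test sentence ."], 1,
                                  vocab_src, device, hparams)

# With input_sentences=None, num_samples latents are drawn from the prior p(z).
samples = generate_senvae(model, None, 8, vocab_src, device, hparams)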
Code example #5
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams, deterministic=True):
    # TODO: this code should be in the translation model class
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences, vocab_src, device)

        # For translation we use the approximate posterior mean.
        qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                y=x_in, seq_mask_y=seq_mask_x, seq_len_y=seq_len_x) # TODO: here we need a prediction net!
        # TODO: restore some form of deterministic decoding
        #z = qz.mean if deterministic else qz.sample()
        z = qz.sample()

        encoder_outputs, encoder_final = model.translation_model.encode(x_in, seq_len_x, z)
        hidden = model.translation_model.init_decoder(encoder_outputs, encoder_final, z)

        if hparams.sample_decoding:
            # TODO: we could use the new version below
            #raw_hypothesis = model.translation_model.sample(x_in, seq_mask_x, seq_len_x, z, 
            #    max_len=hparams.max_decoding_length, greedy=False)
            raw_hypothesis = sampling_decode(
                model.translation_model.decoder, 
                model.translation_model.tgt_embed,
                model.translation_model.generate, hidden,
                encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length,
                z if hparams.feed_z else None)
        elif hparams.beam_width <= 1:
            # TODO: we could use the new version below
            #raw_hypothesis = model.translation_model.sample(x_in, seq_mask_x, seq_len_x, z, 
            #    max_len=hparams.max_decoding_length, greedy=True)
            raw_hypothesis = greedy_decode(
                model.translation_model.decoder, 
                model.translation_model.tgt_embed,
                model.translation_model.generate, hidden,
                encoder_outputs, encoder_final,
                seq_mask_x, vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.max_decoding_length,
                z if hparams.feed_z else None)
        else:
            raw_hypothesis = beam_search(
                model.translation_model.decoder, 
                model.translation_model.tgt_embed, 
                model.translation_model.generate,
                vocab_tgt.size(), hidden, encoder_outputs,
                encoder_final, seq_mask_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.beam_width,
                hparams.length_penalty_factor,
                hparams.max_decoding_length,
                z if hparams.feed_z else None)

    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
Code example #6
def translate(model, input_sentences, vocab_src, vocab_tgt, device, hparams):
    model.eval()
    with torch.no_grad():
        x_in, _, seq_mask_x, seq_len_x = create_batch(input_sentences,
                                                      vocab_src, device)
        if isinstance(model.translation_model, TransformerTM):
            encoder_outputs, seq_len_x = model.translation_model.encode(
                x_in, seq_len_x, None)
            encoder_final = None
            hidden = None
        else:
            encoder_outputs, encoder_final = model.translation_model.encode(
                x_in, seq_len_x, None)
            hidden = model.translation_model.init_decoder(
                encoder_outputs, encoder_final, None)
        if hparams.decoding.sample:
            raw_hypothesis = model.translation_model.sample(
                x_in,
                seq_mask_x,
                seq_len_x,
                None,
                max_len=hparams.decoding.max_length,
                greedy=False)
        elif hparams.decoding.beam_width <= 1:
            raw_hypothesis = model.translation_model.sample(
                x_in,
                seq_mask_x,
                seq_len_x,
                None,
                max_len=hparams.decoding.max_length,
                greedy=True)
        else:
            raw_hypothesis = beam_search(
                model.translation_model.decoder,
                model.translation_model.tgt_embed,
                model.translation_model.generate, vocab_tgt.size(), hidden,
                encoder_outputs, encoder_final, seq_mask_x, seq_len_x,
                vocab_tgt[SOS_TOKEN], vocab_tgt[EOS_TOKEN],
                vocab_tgt[PAD_TOKEN], hparams.decoding.beam_width,
                hparams.decoding.length_penalty_factor,
                hparams.decoding.max_length, None)

    hypothesis = batch_to_sentences(raw_hypothesis, vocab_tgt)
    return hypothesis
Code example #7
def validate(model, val_data, vocab_src, vocab_tgt, device, hparams, step, title='xy', summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data, batch_size=hparams.batch_size,
                        shuffle=False, num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_KL, val_NLLs = _evaluate_perplexity(model, val_dl, vocab_src, vocab_tgt, device)
    val_NLL = val_NLLs['joint/main']
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src, vocab_tgt,
                                                  device, hparams)

    random_idx = np.random.choice(len(inputs))
    #nll_str = ' '.join('-- validation NLL {} = {:.2f}'.format(comp_name, comp_value)  for comp_name, comp_value in sorted(val_NLLs.items()))
    nll_str = f""
    # - log P(x|z) for the various source LM decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('lm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    # - log P(y|z,x) for the various translation decoders
    for comp_name, comp_nll in sorted(val_NLLs.items()):
        if comp_name.startswith('tm/'):
            nll_str += f" -- {comp_name} = {comp_nll:,.2f}"
    
    kl_str = f"-- KL = {val_KL.sum():.2f}"
    if isinstance(model.prior(), ProductOfDistributions):
        for i, p in enumerate(model.prior().distributions):
            kl_str += f" -- KL{i} = {val_KL[i]:.2f}" 
        
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- BLEU = {val_bleu:.2f}"
          f" {kl_str}"
          f" {nll_str}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader(
            [val_data[random_idx] for _ in range(hparams.draw_translations)],
            batch_size=hparams.batch_size,
            shuffle=False,
            num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device, hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")
    
    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl, step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL.sum(), step)
        if isinstance(model.prior(), ProductOfDistributions):
            for i, _ in enumerate(model.prior().distributions):
                summary_writer.add_scalar(f"{title}/validation/KL{i}", val_KL[i], step)
        for comp_name, comp_value in val_NLLs.items():
            summary_writer.add_scalar(f"{title}/validation/NLL/{comp_name}", comp_value, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x], vocab_src, device)
            y_in, y_out, seq_mask_y, seq_len_y = create_batch([val_sentence_y], vocab_tgt, device)
            z = model.approximate_posterior(x_in, seq_mask_x, seq_len_x, y_in, seq_mask_y, seq_len_y).sample()
            _, _, state, _, _ = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = state['att_weights'].squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src, no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt, no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return {'bleu': val_bleu, 'likelihood': -val_NLL, 'nll': val_NLL, 'ppl': val_ppl}
Code example #8
File: aevnmt_helper.py  Project: wilkeraziz/AEVNMT.pt
def validate(model,
             val_data,
             vocab_src,
             vocab_tgt,
             device,
             hparams,
             step,
             title='xy',
             summary_writer=None):
    model.eval()

    # Create the validation dataloader. We can just bucket.
    val_dl = DataLoader(val_data,
                        batch_size=hparams.batch_size,
                        shuffle=False,
                        num_workers=4)
    val_dl = BucketingParallelDataLoader(val_dl)

    val_ppl, val_NLL, val_KL = _evaluate_perplexity(model, val_dl, vocab_src,
                                                    vocab_tgt, device)
    val_bleu, inputs, refs, hyps = _evaluate_bleu(model, val_dl, vocab_src,
                                                  vocab_tgt, device, hparams)

    random_idx = np.random.choice(len(inputs))
    print(f"direction = {title}\n"
          f"validation perplexity = {val_ppl:,.2f}"
          f" -- validation NLL = {val_NLL:,.2f}"
          f" -- validation BLEU = {val_bleu:.2f}"
          f" -- validation KL = {val_KL:.2f}\n"
          f"- Source: {inputs[random_idx]}\n"
          f"- Target: {refs[random_idx]}\n"
          f"- Prediction: {hyps[random_idx]}")

    if hparams.draw_translations > 0:
        random_idx = np.random.choice(len(inputs))
        dl = DataLoader(
            [val_data[random_idx] for _ in range(hparams.draw_translations)],
            batch_size=hparams.batch_size,
            shuffle=False,
            num_workers=4)
        dl = BucketingParallelDataLoader(dl)
        i, r, hs = _draw_translations(model, dl, vocab_src, vocab_tgt, device,
                                      hparams)
        print("Posterior samples")
        print(f"- Input: {i[0]}")
        print(f"- Reference: {r[0]}")
        for h in hs:
            print(f"- Translation: {h}")

    # Write validation summaries.
    if summary_writer is not None:
        summary_writer.add_scalar(f"{title}/validation/NLL", val_NLL, step)
        summary_writer.add_scalar(f"{title}/validation/BLEU", val_bleu, step)
        summary_writer.add_scalar(f"{title}/validation/perplexity", val_ppl,
                                  step)
        summary_writer.add_scalar(f"{title}/validation/KL", val_KL, step)

        # Log the attention weights of the first validation sentence.
        with torch.no_grad():
            val_sentence_x, val_sentence_y = val_data[0]
            x_in, _, seq_mask_x, seq_len_x = create_batch([val_sentence_x],
                                                          vocab_src, device)
            y_in, y_out, _, _ = create_batch([val_sentence_y], vocab_tgt,
                                             device)
            z = model.approximate_posterior(x_in, seq_mask_x,
                                            seq_len_x).sample()
            _, _, att_weights = model(x_in, seq_mask_x, seq_len_x, y_in, z)
            att_weights = att_weights.squeeze().cpu().numpy()
        src_labels = batch_to_sentences(x_in, vocab_src,
                                        no_filter=True)[0].split()
        tgt_labels = batch_to_sentences(y_out, vocab_tgt,
                                        no_filter=True)[0].split()
        attention_summary(src_labels, tgt_labels, att_weights, summary_writer,
                          f"{title}/validation/attention", step)

    return {
        'bleu': val_bleu,
        'likelihood': -val_NLL,
        'nll': val_NLL,
        'ppl': val_ppl
    }