Example #1
def bilingual_step(model_xy, model_yx, sentences_x, sentences_y, bi_train_fn,
                   optimizers_xy, optimizers_yx, vocab_src, vocab_tgt, config,
                   step, device):
    x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
        sentences_x, vocab_src, device, word_dropout=config["word_dropout"])
    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])

    # Add a broadcast dimension so the masks can be applied to attention scores.
    x_mask = x_mask.unsqueeze(1)
    y_mask = y_mask.unsqueeze(1)

    # Bilingual (supervised) loss on the parallel batch
    bi_loss = bi_train_fn(model_xy, model_yx, x_in, x_noisy_in, x_out, x_len,
                          x_mask, y_in, y_noisy_in, y_out, y_len, y_mask, step)
    bi_loss.backward()

    # Update the generative parameters of both translation directions; the
    # co-generative model (coaevnmt) also updates its inference networks below.
    optimizer_step(model_xy.generative_parameters(), optimizers_xy['gen'],
                   config["max_gradient_norm"])
    optimizer_step(model_yx.generative_parameters(), optimizers_yx['gen'],
                   config["max_gradient_norm"])
    if config["model_type"] == "coaevnmt":
        optimizer_step(model_xy.inference_parameters(), optimizers_xy['inf'],
                       config["max_gradient_norm"])
        optimizer_step(model_yx.inference_parameters(), optimizers_yx['inf'],
                       config["max_gradient_norm"])

    return bi_loss.item()
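
The steps above call an optimizer_step helper that is not shown in these examples. Below is a minimal sketch of what such a helper could look like, assuming it clips the gradient norm of the given parameter group, applies one update, and clears the gradients; the repository's actual implementation may differ.

from torch.nn.utils import clip_grad_norm_


def optimizer_step(parameters, optimizer, max_gradient_norm):
    # Hypothetical sketch of the helper used above, not the repository's code.
    # Clip the gradient norm of this parameter group, update, and reset grads.
    parameters = list(parameters)  # generators can only be consumed once
    if max_gradient_norm > 0:
        clip_grad_norm_(parameters, max_gradient_norm)
    optimizer.step()
    optimizer.zero_grad()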
Example #2
def monolingual_step(model_xy, model_yx, cycle_iterate_dl, mono_train_fn,
                     optimizers, vocab_src, vocab_tgt, config, step, device):
    if config["model_type"] == "coaevnmt":
        lm_switch = RequiresGradSwitch(model_xy.lm_parameters())
        if not config["update_lm"]:
            # Freeze the source LM so this monolingual step does not update it.
            lm_switch.requires_grad(False)

    sentences_y = next(cycle_iterate_dl)

    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])

    y_mask = y_mask.unsqueeze(1)

    y_mono_loss = mono_train_fn(model_xy, model_yx, y_in, y_len, y_mask, y_out,
                                vocab_src, config, step)
    y_mono_loss.backward()

    optimizer_step(model_xy.generative_parameters(), optimizers['gen'],
                   config["max_gradient_norm"])
    if config["model_type"] == "coaevnmt":
        optimizer_step(model_xy.inference_parameters(), optimizers['inf'],
                       config["max_gradient_norm"])

        if not config["update_lm"]:  # so we restore switches for source LM
            lm_switch.restore()
    return y_mono_loss.item()
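
RequiresGradSwitch is likewise used without being shown. A minimal sketch, assuming it records the current requires_grad flags of a parameter group, overrides them, and restores them on request; the actual class may differ.

class RequiresGradSwitch:
    """Hypothetical sketch: temporarily toggle requires_grad for parameters."""

    def __init__(self, parameters):
        self.parameters = list(parameters)
        self.saved_flags = None

    def requires_grad(self, flag):
        # Remember the current flags, then force every parameter to `flag`.
        self.saved_flags = [p.requires_grad for p in self.parameters]
        for p in self.parameters:
            p.requires_grad = flag

    def restore(self):
        # Put back the flags recorded by the last requires_grad() call.
        for p, flag in zip(self.parameters, self.saved_flags):
            p.requires_grad = flag
        self.saved_flags = None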
Example #3
def bilingual_step(model, sentences_x, sentences_y, train_fn, optimizer,
                   vocab_src, vocab_tgt, config, step, device):
    x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
        sentences_x, vocab_src, device, word_dropout=config["word_dropout"])
    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])

    x_mask = x_mask.unsqueeze(1)
    y_mask = y_mask.unsqueeze(1)
    optimizer.zero_grad()

    loss = train_fn(model, x_in, x_noisy_in, x_out, x_len, x_mask, y_in,
                    y_noisy_in, y_out, step)
    loss.backward()

    if config["max_gradient_norm"] > 0:
        clip_grad_norm_(model.parameters(), config["max_gradient_norm"])

    optimizer.step()
    return loss.item()
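
Every example builds its batches with create_noisy_batch, which is also not shown. The sketch below is only an assumption about what it returns (decoder inputs, targets, a padding mask, lengths, and a word-dropout copy of the inputs); the token names <s>, </s>, <pad>, <unk> and the vocabulary lookup API are inferred from how the examples index the vocabulary, not taken from the repository.

import torch

SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN = "<s>", "</s>", "<pad>", "<unk>"


def create_noisy_batch(sentences, vocab, device, word_dropout=0.0):
    # Hypothetical sketch of the batching helper used in all examples above.
    # Returns decoder inputs (SOS + tokens), targets (tokens + EOS), a padding
    # mask, sequence lengths, and a noisy copy of the inputs in which tokens
    # are replaced by UNK with probability `word_dropout`.
    pad, sos, eos, unk = (vocab[PAD_TOKEN], vocab[SOS_TOKEN],
                          vocab[EOS_TOKEN], vocab[UNK_TOKEN])
    token_ids = [[vocab[w] for w in s.split()] for s in sentences]
    lengths = torch.tensor([len(t) + 1 for t in token_ids], device=device)
    max_len = int(lengths.max())

    batch_in = torch.full((len(token_ids), max_len), pad, dtype=torch.long,
                          device=device)
    batch_out = torch.full_like(batch_in, pad)
    for i, ids in enumerate(token_ids):
        batch_in[i, :len(ids) + 1] = torch.tensor([sos] + ids, device=device)
        batch_out[i, :len(ids) + 1] = torch.tensor(ids + [eos], device=device)

    mask = batch_in != pad
    drop = (torch.rand_like(batch_in, dtype=torch.float) < word_dropout) & mask
    noisy_in = batch_in.masked_fill(drop & (batch_in != sos), unk)
    return batch_in, batch_out, mask, lengths, noisy_in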
Example #4
def mono_train_fn(model_xy, model_yx, y_in, y_noisy_in, y_len, y_mask, y_out,
                  vocab_src, config, step):
    device = torch.device(
        "cpu") if config["device"] == "cpu" else torch.device("cuda:0")
    # Back-translate: sample synthetic source sentences from the reverse
    # model (y -> x) without tracking gradients.
    with torch.no_grad():
        qz_y = model_yx.inference(y_in, y_mask, y_len)
        z_y = qz_y.sample()

        enc_output, enc_final = model_yx.encode(y_in, y_len, z_y)
        dec_hidden = model_yx.init_decoder(enc_output, enc_final, z_y)

        if config["decoding_method"] == "beam_search":
            x_samples = beam_search(model_yx.decoder,
                                    model_yx.emb_tgt,
                                    model_yx.generate_tm,
                                    enc_output,
                                    dec_hidden,
                                    y_mask,
                                    vocab_src.size(),
                                    vocab_src[SOS_TOKEN],
                                    vocab_src[EOS_TOKEN],
                                    vocab_src[PAD_TOKEN],
                                    config,
                                    beam_width=config["decoding_beam_width"],
                                    z=z_y)
        else:
            greedy = config["decoding_method"] != "ancestral"
            x_samples = ancestral_sample(model_yx.decoder,
                                         model_yx.emb_tgt,
                                         model_yx.generate_tm,
                                         enc_output,
                                         dec_hidden,
                                         y_mask,
                                         vocab_src[SOS_TOKEN],
                                         vocab_src[EOS_TOKEN],
                                         vocab_src[PAD_TOKEN],
                                         config,
                                         greedy=greedy,
                                         z=z_y)
        x_samples = batch_to_sentences(x_samples, vocab_src)
        x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
            x_samples, vocab_src, device, word_dropout=config["word_dropout"])
        x_mask = x_mask.unsqueeze(1)

    # Supervised update on the synthetic pair (x, y) with the forward model.
    qz_x = model_xy.inference(x_in, x_mask, x_len)
    z_x = qz_x.rsample()
    tm_logits, lm_logits, _, _ = model_xy(x_noisy_in, x_len, x_mask,
                                          y_noisy_in, z_x)

    loss = model_xy.loss(tm_logits, lm_logits, None, None, y_out, x_out, qz_x,
                         step)
    return loss
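
batch_to_sentences turns the sampled ID matrix back into strings before it is re-batched as synthetic source data. A minimal sketch, assuming rows are cut at the first EOS and that the vocabulary exposes an id-to-token lookup (named word() here purely for illustration).

SOS_TOKEN, EOS_TOKEN, PAD_TOKEN = "<s>", "</s>", "<pad>"


def batch_to_sentences(batch, vocab):
    # Hypothetical sketch: map a (batch, time) tensor of token ids back to
    # whitespace-joined strings, cutting each row at the first EOS and
    # skipping SOS/PAD. `vocab.word(idx)` is an assumed reverse lookup.
    sentences = []
    for row in batch.tolist():
        words = []
        for idx in row:
            word = vocab.word(idx)
            if word == EOS_TOKEN:
                break
            if word in (SOS_TOKEN, PAD_TOKEN):
                continue
            words.append(word)
        sentences.append(" ".join(words))
    return sentences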