Example #1
def monolingual_step(model_xy, model_yx, cycle_iterate_dl, mono_train_fn,
                     optimizers, vocab_src, vocab_tgt, config, step, device):
    # For the co-training model, optionally freeze the language-model
    # parameters of model_xy so the monolingual update leaves them untouched.
    if config["model_type"] == "coaevnmt":
        lm_switch = RequiresGradSwitch(model_xy.lm_parameters())
        if not config["update_lm"]:
            lm_switch.requires_grad(False)

    # Draw the next monolingual target-side batch from the cycled data iterator.
    sentences_y = next(cycle_iterate_dl)

    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])

    y_mask = y_mask.unsqueeze(1)

    # Monolingual training loss for this batch, as computed by mono_train_fn.
    y_mono_loss = mono_train_fn(model_xy, model_yx, y_in, y_len, y_mask, y_out,
                                vocab_src, config, step)
    y_mono_loss.backward()

    # Clip gradients and update the generative parameters; for the
    # co-training model also update the inference network.
    optimizer_step(model_xy.generative_parameters(), optimizers['gen'],
                   config["max_gradient_norm"])
    if config["model_type"] == "coaevnmt":
        optimizer_step(model_xy.inference_parameters(), optimizers['inf'],
                       config["max_gradient_norm"])

        if not config["update_lm"]:
            # Restore the original requires_grad flags of the LM parameters.
            lm_switch.restore()
    return y_mono_loss.item()
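
Both examples call helpers (RequiresGradSwitch, optimizer_step, create_noisy_batch) that are defined elsewhere in the repository. The sketch below is only an assumption of what RequiresGradSwitch and optimizer_step do, inferred from how they are used above: the switch temporarily overrides requires_grad on a parameter group and later restores the original flags, and optimizer_step clips the gradient norm, takes an optimizer step, and clears the gradients. The actual implementations may differ in detail.

import torch

class RequiresGradSwitch:
    # Sketch of the assumed helper: temporarily override requires_grad on a
    # group of parameters and restore the original flags later.
    def __init__(self, parameters):
        self.parameters = list(parameters)
        self.original_flags = None

    def requires_grad(self, value):
        # Remember the current flags, then overwrite them with `value`.
        self.original_flags = [p.requires_grad for p in self.parameters]
        for p in self.parameters:
            p.requires_grad = value

    def restore(self):
        # Put every parameter back to the flag it had before the override.
        for p, flag in zip(self.parameters, self.original_flags):
            p.requires_grad = flag
        self.original_flags = None


def optimizer_step(parameters, optimizer, max_gradient_norm):
    # Clip the gradient norm of the parameter group, update the weights,
    # and clear the gradients for the next batch.
    torch.nn.utils.clip_grad_norm_(list(parameters), max_gradient_norm)
    optimizer.step()
    optimizer.zero_grad()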
Example #2
def bilingual_step(model_xy, model_yx, sentences_x, sentences_y, bi_train_fn,
                   optimizers_xy, optimizers_yx, vocab_src, vocab_tgt, config,
                   step, device):
    # Build padded (and noise-augmented) batches for both sides of the
    # parallel data.
    x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
        sentences_x, vocab_src, device, word_dropout=config["word_dropout"])
    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])

    x_mask = x_mask.unsqueeze(1)
    y_mask = y_mask.unsqueeze(1)

    # Bilingual (supervised) loss on the parallel batch, covering both
    # translation directions.
    bi_loss = bi_train_fn(model_xy, model_yx, x_in, x_noisy_in, x_out, x_len,
                          x_mask, y_in, y_noisy_in, y_out, y_len, y_mask, step)
    bi_loss.backward()

    # Clip gradients and step the generative parameters of both directions;
    # for the co-training model also update the inference networks.
    optimizer_step(model_xy.generative_parameters(), optimizers_xy['gen'],
                   config["max_gradient_norm"])
    optimizer_step(model_yx.generative_parameters(), optimizers_yx['gen'],
                   config["max_gradient_norm"])
    if config["model_type"] == "coaevnmt":
        optimizer_step(model_xy.inference_parameters(), optimizers_xy['inf'],
                       config["max_gradient_norm"])
        optimizer_step(model_yx.inference_parameters(), optimizers_yx['inf'],
                       config["max_gradient_norm"])

    return bi_loss.item()
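
create_noisy_batch is also defined elsewhere; judging from its use above it returns padded decoder inputs and outputs, a mask, lengths, and a noisy copy of the inputs produced by word dropout. The helper below, word_dropout_sketch, is a hypothetical, minimal illustration of that noising step (pad_idx and unk_idx are assumed parameter names, not taken from the repository): each non-padding token is replaced by the unknown token with probability word_dropout.

import torch

def word_dropout_sketch(y_in, pad_idx, unk_idx, word_dropout):
    # Replace each non-padding token with <unk> with probability
    # `word_dropout`; this mimics the y_noisy_in tensor that
    # create_noisy_batch is assumed to produce.
    drop = torch.rand_like(y_in, dtype=torch.float) < word_dropout
    not_pad = y_in.ne(pad_idx)
    return y_in.masked_fill(drop & not_pad, unk_idx)

# Example: a single padded sentence (0 = padding, 3 = <unk> in this toy setup).
y_in = torch.tensor([[2, 5, 7, 1, 0, 0]])
y_noisy_in = word_dropout_sketch(y_in, pad_idx=0, unk_idx=3, word_dropout=0.2)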