import torch
from torch.nn.utils import clip_grad_norm_

# Project-local helpers used below (create_noisy_batch, optimizer_step,
# RequiresGradSwitch, beam_search, ancestral_sample, batch_to_sentences and the
# SOS_TOKEN/EOS_TOKEN/PAD_TOKEN constants) are assumed to be imported from the
# repository's utility modules.


def bilingual_step(model_xy, model_yx, sentences_x, sentences_y, bi_train_fn,
                   optimizers_xy, optimizers_yx, vocab_src, vocab_tgt, config,
                   step, device):
    # Build padded source and target batches, plus word-dropout-noised decoder inputs.
    x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
        sentences_x, vocab_src, device, word_dropout=config["word_dropout"])
    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])
    x_mask = x_mask.unsqueeze(1)
    y_mask = y_mask.unsqueeze(1)

    # Bilingual loss on the parallel batch, covering both translation directions.
    bi_loss = bi_train_fn(model_xy, model_yx, x_in, x_noisy_in, x_out, x_len,
                          x_mask, y_in, y_noisy_in, y_out, y_len, y_mask, step)
    bi_loss.backward()

    # Update the generative parameters of both directions; the co-training model
    # (coaevnmt) additionally has inference networks with their own optimizers.
    optimizer_step(model_xy.generative_parameters(), optimizers_xy['gen'],
                   config["max_gradient_norm"])
    optimizer_step(model_yx.generative_parameters(), optimizers_yx['gen'],
                   config["max_gradient_norm"])
    if config["model_type"] == "coaevnmt":
        optimizer_step(model_xy.inference_parameters(), optimizers_xy['inf'],
                       config["max_gradient_norm"])
        optimizer_step(model_yx.inference_parameters(), optimizers_yx['inf'],
                       config["max_gradient_norm"])
    return bi_loss.item()
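# `optimizer_step` is used above but not defined in this section. A minimal sketch
# of what it is assumed to do, judging from the call sites (clip the gradients of
# one parameter group, step its optimizer, clear the gradients); the repository's
# actual helper may differ.
def optimizer_step(parameters, optimizer, max_gradient_norm):
    parameters = list(parameters)
    if max_gradient_norm > 0:
        clip_grad_norm_(parameters, max_gradient_norm)
    optimizer.step()
    optimizer.zero_grad()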
def monolingual_step(model_xy, model_yx, cycle_iterate_dl, mono_train_fn,
                     optimizers, vocab_src, vocab_tgt, config, step, device):
    if config["model_type"] == "coaevnmt":
        # Optionally freeze the language-model parameters while training on
        # monolingual data; the switch records the original requires_grad flags.
        lm_switch = RequiresGradSwitch(model_xy.lm_parameters())
        if not config["update_lm"]:
            lm_switch.requires_grad(False)

    # Draw a monolingual target-side batch from the cycled data loader.
    sentences_y = next(cycle_iterate_dl)
    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])
    y_mask = y_mask.unsqueeze(1)

    y_mono_loss = mono_train_fn(model_xy, model_yx, y_in, y_len, y_mask, y_out,
                                vocab_src, config, step)
    y_mono_loss.backward()

    optimizer_step(model_xy.generative_parameters(), optimizers['gen'],
                   config["max_gradient_norm"])
    if config["model_type"] == "coaevnmt":
        optimizer_step(model_xy.inference_parameters(), optimizers['inf'],
                       config["max_gradient_norm"])

        if not config["update_lm"]:
            # Restore the requires_grad switches of the (source) language model.
            lm_switch.restore()
    return y_mono_loss.item()
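# `RequiresGradSwitch` is likewise only used, not defined, in this section. A
# plausible minimal implementation consistent with the calls above
# (requires_grad(False) freezes the parameters, restore() puts the recorded flags
# back); treat it as a sketch, not the repository's actual class.
class RequiresGradSwitch:
    def __init__(self, parameters):
        self.parameters = list(parameters)
        self.saved_flags = None

    def requires_grad(self, value):
        # Record the current flags before overwriting them.
        self.saved_flags = [p.requires_grad for p in self.parameters]
        for p in self.parameters:
            p.requires_grad = value

    def restore(self):
        # Reinstate the flags recorded by the last requires_grad() call.
        if self.saved_flags is not None:
            for p, flag in zip(self.parameters, self.saved_flags):
                p.requires_grad = flag
            self.saved_flags = None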
def bilingual_step(model, sentences_x, sentences_y, train_fn, optimizer,
                   vocab_src, vocab_tgt, config, step, device):
    x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
        sentences_x, vocab_src, device, word_dropout=config["word_dropout"])
    y_in, y_out, y_mask, y_len, y_noisy_in = create_noisy_batch(
        sentences_y, vocab_tgt, device, word_dropout=config["word_dropout"])
    x_mask = x_mask.unsqueeze(1)
    y_mask = y_mask.unsqueeze(1)

    optimizer.zero_grad()
    loss = train_fn(model, x_in, x_noisy_in, x_out, x_len, x_mask,
                    y_in, y_noisy_in, y_out, step)
    loss.backward()
    if config["max_gradient_norm"] > 0:
        clip_grad_norm_(model.parameters(), config["max_gradient_norm"])
    optimizer.step()
    return loss.item()
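# `create_noisy_batch` is assumed to behave like the usual AEVNMT batching helper:
# it turns raw sentences into padded id tensors (decoder input, target output,
# mask, lengths) and additionally returns a "noisy" decoder input in which each
# real token is replaced by UNK with probability `word_dropout`. The noising step
# might look roughly like the hypothetical helper below; names and padding details
# are assumptions.
def apply_word_dropout(token_ids, lengths, unk_idx, word_dropout):
    # token_ids: LongTensor [batch, time]; lengths: LongTensor [batch].
    if word_dropout <= 0.:
        return token_ids
    batch_size, time = token_ids.size()
    # Only corrupt real token positions, never the padding.
    positions = torch.arange(time, device=token_ids.device).unsqueeze(0)
    real = positions < lengths.unsqueeze(1)
    drop = (torch.rand(batch_size, time, device=token_ids.device) < word_dropout) & real
    return token_ids.masked_fill(drop, unk_idx)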
def mono_train_fn(model_xy, model_yx, y_in, y_noisy_in, y_len, y_mask, y_out,
                  vocab_src, config, step):
    device = torch.device("cpu") if config["device"] == "cpu" else torch.device("cuda:0")

    # Back-translate the monolingual target batch: sample synthetic source
    # sentences x from the reverse model y->x, without tracking gradients.
    with torch.no_grad():
        qz_y = model_yx.inference(y_in, y_mask, y_len)
        z_y = qz_y.sample()
        enc_output, enc_final = model_yx.encode(y_in, y_len, z_y)
        dec_hidden = model_yx.init_decoder(enc_output, enc_final, z_y)

        if config["decoding_method"] == "beam_search":
            x_samples = beam_search(model_yx.decoder, model_yx.emb_tgt,
                                    model_yx.generate_tm, enc_output, dec_hidden,
                                    y_mask, vocab_src.size(), vocab_src[SOS_TOKEN],
                                    vocab_src[EOS_TOKEN], vocab_src[PAD_TOKEN],
                                    config, beam_width=config["decoding_beam_width"],
                                    z=z_y)
        else:
            # "ancestral" samples from the decoder distribution; any other setting
            # falls back to greedy decoding.
            greedy = config["decoding_method"] != "ancestral"
            x_samples = ancestral_sample(model_yx.decoder, model_yx.emb_tgt,
                                         model_yx.generate_tm, enc_output, dec_hidden,
                                         y_mask, vocab_src[SOS_TOKEN],
                                         vocab_src[EOS_TOKEN], vocab_src[PAD_TOKEN],
                                         config, greedy=greedy, z=z_y)
        x_samples = batch_to_sentences(x_samples, vocab_src)

    # Re-encode the synthetic sources and train the forward model x->y on the
    # pseudo-parallel pair (x_samples, y).
    x_in, x_out, x_mask, x_len, x_noisy_in = create_noisy_batch(
        x_samples, vocab_src, device, word_dropout=config["word_dropout"])
    x_mask = x_mask.unsqueeze(1)

    qz_x = model_xy.inference(x_in, x_mask, x_len)
    z_x = qz_x.rsample()
    tm_logits, lm_logits, _, _ = model_xy(x_noisy_in, x_len, x_mask, y_noisy_in, z_x)
    loss = model_xy.loss(tm_logits, lm_logits, None, None, y_out, x_out, qz_x, step)
    return loss
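# A sketch of how the pieces above might be driven inside one semi-supervised
# training epoch: each parallel batch feeds `bilingual_step`, followed by one
# monolingual step per direction, with the model and vocabulary arguments swapped
# so that each direction is trained on back-translations of its own target
# language. The data loaders, train functions, and optimizer dicts are
# placeholders, and the repository's actual loop may be organised differently.
def train_epoch_sketch(model_xy, model_yx, bilingual_dl, cycle_mono_dl_x,
                       cycle_mono_dl_y, bi_train_fn, mono_train_fn,
                       optimizers_xy, optimizers_yx, vocab_src, vocab_tgt,
                       config, step, device):
    for sentences_x, sentences_y in bilingual_dl:
        bilingual_step(model_xy, model_yx, sentences_x, sentences_y, bi_train_fn,
                       optimizers_xy, optimizers_yx, vocab_src, vocab_tgt,
                       config, step, device)
        # Monolingual target-language batch: model_yx back-translates y->x and
        # model_xy is trained on the synthetic pair.
        monolingual_step(model_xy, model_yx, cycle_mono_dl_y, mono_train_fn,
                         optimizers_xy, vocab_src, vocab_tgt, config, step, device)
        # Monolingual source-language batch: the roles are mirrored, so the
        # vocabularies are passed in swapped order.
        monolingual_step(model_yx, model_xy, cycle_mono_dl_x, mono_train_fn,
                         optimizers_yx, vocab_tgt, vocab_src, config, step, device)
        step += 1
    return step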