def train_step(model, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
               y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in,
               hparams, step, summary_writer=None):
    # Use the approximate posterior q(z|x, y) to sample z during training.
    qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                     y_in, seq_mask_y, seq_len_y)
    z = qz.rsample()

    # Compute the translation and language model logits.
    tm_likelihood, lm_likelihood, _, aux_lm_likelihoods, aux_tm_likelihoods = model(
        noisy_x_in, seq_mask_x, seq_len_x, noisy_y_in, z)

    # Linearly anneal the KL weight over kl.annealing_steps, if set.
    KL_weight = hparams.kl.weight
    if hparams.kl.annealing_steps > 0:
        KL_weight = min(KL_weight, (KL_weight / hparams.kl.annealing_steps) * step)

    # Compute the loss. loss_cfg can restrict training to a subset of loss
    # components for a staged schedule (e.g. bag-of-words or MADE auxiliary
    # losses first, then the main losses); it is disabled here.
    loss_cfg = None
    loss = model.loss(tm_likelihood, lm_likelihood, y_out, x_out, qz,
                      free_nats=hparams.kl.free_nats,
                      KL_weight=KL_weight,
                      mmd_weight=hparams.loss.mmd_weight,
                      reduction="mean",
                      smoothing_x=hparams.gen.lm.label_smoothing,
                      smoothing_y=hparams.gen.tm.label_smoothing,
                      aux_lm_likelihoods=aux_lm_likelihoods,
                      aux_tm_likelihoods=aux_tm_likelihoods,
                      loss_cfg=loss_cfg)

    if summary_writer and step % hparams.print_every == 0:
        summary_writer.add_histogram("posterior/z", z, step)
        for param_name, param_value in get_named_params(qz):
            summary_writer.add_histogram("posterior/%s" % param_name, param_value, step)

        pz = model.prior()
        # This part is perhaps not necessary for a simple prior (e.g. Gaussian),
        # but it's useful for more complex priors (e.g. mixtures and NFs).
        prior_sample = pz.sample(torch.Size([z.size(0)]))
        summary_writer.add_histogram("prior/z", prior_sample, step)
        for param_name, param_value in get_named_params(pz):
            summary_writer.add_histogram("prior/%s" % param_name, param_value, step)

    return loss
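# The functions here rely on a get_named_params helper to log the parameters of
# the posterior and prior distributions. The sketch below is an assumption about
# its behaviour (not the original helper): it yields (name, tensor) pairs for
# common torch.distributions objects.
import torch.distributions as torchdist


def get_named_params(distribution):
    """Yield (name, value) pairs for the parameters of a torch distribution."""
    if isinstance(distribution, torchdist.Normal):
        yield "loc", distribution.loc
        yield "scale", distribution.scale
    elif isinstance(distribution, torchdist.Independent):
        # Unwrap Independent (e.g. a diagonal Gaussian over z) and recurse.
        yield from get_named_params(distribution.base_dist)
    else:
        # Fall back to whatever constructor arguments the distribution exposes.
        for name in getattr(distribution, "arg_constraints", {}):
            if hasattr(distribution, name):
                yield name, getattr(distribution, name)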
def forward(self, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
            y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in, step):
    # Inference model: sample z from the approximate posterior q(z|x, y).
    qz = self.model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                          y_in, seq_mask_y, seq_len_y)
    z = qz.rsample()

    # Generative model: translation and language model likelihoods.
    tm_likelihood, lm_likelihood, _, aux_lm_likelihoods, aux_tm_likelihoods = self.model(
        noisy_x_in, seq_mask_x, seq_len_x, noisy_y_in, z)

    # Loss with linear KL annealing over kl.annealing_steps, if set.
    if self.hparams.kl.annealing_steps > 0:
        KL_weight = min(1., (1. / self.hparams.kl.annealing_steps) * step)
    else:
        KL_weight = 1.

    return_dict = self.model.loss(tm_likelihood, lm_likelihood, y_out, x_out, qz,
                                  free_nats=self.hparams.kl.free_nats,
                                  KL_weight=KL_weight,
                                  reduction="none",
                                  smoothing_x=self.hparams.gen.lm.label_smoothing,
                                  smoothing_y=self.hparams.gen.tm.label_smoothing,
                                  aux_lm_likelihoods=aux_lm_likelihoods,
                                  aux_tm_likelihoods=aux_tm_likelihoods,
                                  loss_cfg=None)

    # Add posterior values to the return dict for TensorBoard.
    return_dict['posterior/z'] = z.detach()
    for param_name, param_value in get_named_params(qz):
        return_dict["posterior/%s" % param_name] = param_value.detach()

    return return_dict
def aevnmt_train_parallel(model, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
                          y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in,
                          hparams, step, summary_writer=None):
    return_dict = model(x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
                        y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in, step)
    return_dict['loss'] = return_dict['loss'].mean()

    # To keep the summary a reasonable size, only save histograms every print_every steps.
    if summary_writer and step % hparams.print_every == 0:
        for comp_name, comp_value in sorted(return_dict.items()):
            if comp_name.startswith('posterior/'):
                summary_writer.add_histogram(comp_name, comp_value, step)

        pz = model.module.model.prior()
        # This part is perhaps not necessary for a simple prior (e.g. Gaussian),
        # but it's useful for more complex priors (e.g. mixtures and NFs).
        prior_sample = pz.sample(torch.Size([hparams.batch_size]))
        summary_writer.add_histogram("prior/z", prior_sample, step)
        for param_name, param_value in get_named_params(pz):
            summary_writer.add_histogram("prior/%s" % param_name, param_value, step)

    return return_dict
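# A minimal sketch (an assumption, not part of the original code base) of how the
# forward() method above and aevnmt_train_parallel() are expected to fit together:
# the AEVNMT model sits inside an nn.Module wrapper whose forward() returns a dict
# of unreduced loss components, and that wrapper is itself wrapped in
# torch.nn.DataParallel, which is why aevnmt_train_parallel reaches the prior via
# model.module.model.prior(). The class name AEVNMTParallelWrapper and the usage
# below are illustrative only.
import torch.nn as nn


class AEVNMTParallelWrapper(nn.Module):
    """Holds the AEVNMT model and hparams; forward() is the method shown above."""

    def __init__(self, model, hparams):
        super().__init__()
        self.model = model      # the AEVNMT model (approximate_posterior, loss, ...)
        self.hparams = hparams  # hyperparameters (kl.*, gen.*, loss.*, ...)


# Illustrative usage with an existing model, hparams, and prepared batch:
#   wrapper = AEVNMTParallelWrapper(aevnmt, hparams)
#   parallel_model = nn.DataParallel(wrapper)  # replicate over the available GPUs
#   return_dict = aevnmt_train_parallel(parallel_model, x_in, x_out, seq_mask_x,
#                                       seq_len_x, noisy_x_in, y_in, y_out,
#                                       seq_mask_y, seq_len_y, noisy_y_in,
#                                       hparams, step, summary_writer)
#   return_dict['loss'].backward()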
def step(self, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
         y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in,
         step, summary_writer=None):
    q_z = self.model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                           y_in, seq_mask_y, seq_len_y)
    p_z = self.model.prior()

    if self.loss.num_samples == 1:
        z = q_z.rsample()
    else:
        z = q_z.rsample([self.loss.num_samples])
        z = self._flatten_samples(z)  # [K * B, ...]

    # Expand required inputs to [K * B, ...].
    noisy_x_in = self._repeat_over_samples(noisy_x_in, self.loss.num_samples, flatten=True)
    noisy_y_in = self._repeat_over_samples(noisy_y_in, self.loss.num_samples, flatten=True)
    seq_mask_x = self._repeat_over_samples(seq_mask_x, self.loss.num_samples, flatten=True)
    seq_len_x = self._repeat_over_samples(seq_len_x, self.loss.num_samples, flatten=True)

    # Expand the targets as well.
    x_out = self._repeat_over_samples(x_out, self.loss.num_samples, flatten=True)
    y_out = self._repeat_over_samples(y_out, self.loss.num_samples, flatten=True)

    # TODO: the auxiliary likelihoods are not used in the new loss functions.
    tm_likelihood, lm_likelihood, _, _, _ = self.model(
        noisy_x_in, seq_mask_x, seq_len_x, noisy_y_in, z)

    loss_dict = self.loss(tm_likelihood, lm_likelihood, y_out, x_out,
                          q_z, p_z, z, step, self.model, reduction='mean')

    if summary_writer and step % self.histogram_every == 0:
        # Generate histograms for the posterior and prior.
        summary_writer.add_histogram("posterior/z", z, step)
        for param_name, param_value in get_named_params(q_z):
            summary_writer.add_histogram("posterior/%s" % param_name, param_value, step)

        prior_sample = p_z.sample(torch.Size([z.size(0)]))
        summary_writer.add_histogram("prior/z", prior_sample, step)
        for param_name, param_value in get_named_params(p_z):
            summary_writer.add_histogram("prior/%s" % param_name, param_value, step)

    return loss_dict
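# step() above relies on two sample-reshaping helpers that are not shown here.
# The sketch below is an assumption about their behaviour (in the trainer class
# they would be methods, i.e. self._flatten_samples / self._repeat_over_samples):
# _flatten_samples folds K posterior samples of shape [K, B, ...] into
# [K * B, ...], and _repeat_over_samples tiles a [B, ...] batch tensor K times so
# it lines up with the flattened samples.
def _flatten_samples(z):
    """[K, B, ...] -> [K * B, ...]."""
    K, B = z.size(0), z.size(1)
    return z.reshape(K * B, *z.shape[2:])


def _repeat_over_samples(t, num_samples, flatten=True):
    """Repeat a [B, ...] tensor num_samples times; optionally flatten to [K * B, ...]."""
    repeated = t.unsqueeze(0).expand(num_samples, *t.shape)  # [K, B, ...]
    if flatten:
        return repeated.reshape(num_samples * t.size(0), *t.shape[1:])
    return repeated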