Example #1
# torch is needed for torch.Size below; get_named_params is assumed to be
# provided by the surrounding AEVNMT codebase.
import torch


def train_step(model, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in, y_in, y_out,
               seq_mask_y, seq_len_y, noisy_y_in, hparams, step, summary_writer=None):
    # Sample z from the approximate posterior q(z|x, y) for training
    # (rsample gives a reparameterised sample, so gradients flow through z).
    qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x, y_in, seq_mask_y, seq_len_y)
    z = qz.rsample()

    # Run the generative model to get the translation and language model likelihoods.
    tm_likelihood, lm_likelihood, _, aux_lm_likelihoods, aux_tm_likelihoods = model(
        noisy_x_in, seq_mask_x, seq_len_x, noisy_y_in, z)

    # Linearly anneal the KL weight over hparams.kl.annealing_steps, if set.
    KL_weight = hparams.kl.weight
    if hparams.kl.annealing_steps > 0:
        KL_weight = min(KL_weight, (KL_weight / hparams.kl.annealing_steps) * step)
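    # E.g. with KL_weight = 1.0 and annealing_steps = 80000 (illustrative
    # values only): the weight is 0.0 at step 0, 0.5 at step 40000, and is
    # capped at 1.0 from step 80000 onwards.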

    # Compute the loss. loss_cfg selects which loss components are active;
    # None uses the default configuration. Staged schedules were tried,
    # warming up on auxiliary objectives ('lm/bow'/'tm/bow' or the MADE
    # variants) before switching to the main ones, e.g.:
    #   if step < 3000:
    #       loss_cfg = {'lm/made', 'lm/made_count', 'tm/made', 'tm/made_count'}
    #   elif step < 4500:
    #       loss_cfg = {'lm/made', 'lm/made_count', 'lm/main',
    #                   'tm/made', 'tm/made_count', 'tm/main'}
    #   else:
    #       loss_cfg = {'lm/main', 'tm/main'}
    loss_cfg = None
    loss = model.loss(tm_likelihood, lm_likelihood, y_out, x_out, qz,
                      free_nats=hparams.kl.free_nats,
                      KL_weight=KL_weight,
                      mmd_weight=hparams.loss.mmd_weight,
                      reduction="mean",
                      smoothing_x=hparams.gen.lm.label_smoothing,
                      smoothing_y=hparams.gen.tm.label_smoothing,
                      aux_lm_likelihoods=aux_lm_likelihoods,
                      aux_tm_likelihoods=aux_tm_likelihoods,
                      loss_cfg=loss_cfg)

    if summary_writer and step % hparams.print_every == 0:
        summary_writer.add_histogram("posterior/z", z, step)
        for param_name, param_value in get_named_params(qz):
            summary_writer.add_histogram("posterior/%s" % param_name, param_value, step)
        pz = model.prior()
        # This part is perhaps not necessary for a simple prior (e.g. a Gaussian),
        # but it's useful for more complex priors (e.g. mixtures and normalizing flows).
        prior_sample = pz.sample(torch.Size([z.size(0)]))
        summary_writer.add_histogram("prior/z", prior_sample, step)
        for param_name, param_value in get_named_params(pz):
            summary_writer.add_histogram("prior/%s" % param_name, param_value, step)

    return loss
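
A minimal sketch of how train_step could be driven from a training loop; the run_training function, optimizer choice, and dataloader below are assumptions for illustration, not part of the original code:

import torch

def run_training(model, dataloader, hparams, writer=None):
    # Hypothetical driver: each batch is assumed to unpack into the
    # positional arguments of train_step, in the order of its signature.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder lr
    for step, batch in enumerate(dataloader):
        loss = train_step(model, *batch, hparams=hparams, step=step,
                          summary_writer=writer)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()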
Example #2
    def forward(self, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
                y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in, step):

        # Inference model: sample z from the approximate posterior q(z|x, y).
        qz = self.model.approximate_posterior(x_in, seq_mask_x, seq_len_x, y_in, seq_mask_y, seq_len_y)
        z = qz.rsample()

        # Generative model: translation and language model likelihoods.
        tm_likelihood, lm_likelihood, _, aux_lm_likelihoods, aux_tm_likelihoods = self.model(
            noisy_x_in, seq_mask_x, seq_len_x, noisy_y_in, z)

        # Loss: linearly anneal the KL weight over annealing_steps, if set.
        if self.hparams.kl.annealing_steps > 0:
            KL_weight = min(1., (1.0 / self.hparams.kl.annealing_steps) * step)
        else:
            KL_weight = 1.
            
        return_dict = self.model.loss(tm_likelihood, lm_likelihood, y_out, x_out, qz,
                                      free_nats=self.hparams.kl.free_nats,
                                      KL_weight=KL_weight,
                                      reduction="none",
                                      smoothing_x=self.hparams.gen.lm.label_smoothing,
                                      smoothing_y=self.hparams.gen.tm.label_smoothing,
                                      aux_lm_likelihoods=aux_lm_likelihoods,
                                      aux_tm_likelihoods=aux_tm_likelihoods,
                                      loss_cfg=None)

        # Add posterior values to the return dict for TensorBoard histograms.
        return_dict['posterior/z'] = z.detach()
        for param_name, param_value in get_named_params(qz):
            return_dict["posterior/%s" % param_name] = param_value.detach()
        return return_dict
Example #3
# torch is needed for torch.Size below; get_named_params is assumed to come
# from the surrounding AEVNMT codebase, as in example #1.
import torch


def aevnmt_train_parallel(model, x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in,
                          y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in,
                          hparams, step, summary_writer=None):
    return_dict = model(x_in, x_out, seq_mask_x, seq_len_x, noisy_x_in, y_in, y_out, seq_mask_y, seq_len_y, noisy_y_in, step)
    return_dict['loss'] = return_dict['loss'].mean()

    # To keep the summary a reasonable size, only save histograms every hparams.print_every steps.
    if summary_writer and step % hparams.print_every == 0:
        for comp_name, comp_value in sorted(return_dict.items()):
            if comp_name.startswith('posterior/'):
                summary_writer.add_histogram(comp_name, comp_value, step)
        # model is assumed to be an nn.DataParallel wrapper, hence .module.
        pz = model.module.model.prior()
        # This part is perhaps not necessary for a simple prior (e.g. a Gaussian),
        # but it's useful for more complex priors (e.g. mixtures and normalizing flows).
        prior_sample = pz.sample(torch.Size([hparams.batch_size]))
        summary_writer.add_histogram("prior/z", prior_sample, step)
        for param_name, param_value in get_named_params(pz):
            summary_writer.add_histogram("prior/%s" % param_name, param_value, step)
    return return_dict
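
The model.module.model.prior() access above suggests that model is an nn.DataParallel wrapper around the module from example #2, whose .model attribute holds the underlying AEVNMT model. A minimal sketch of that setup; everything except the torch API is an assumption for illustration:

import torch

def train_parallel_step(trainer, batch, hparams, step, summary_writer=None):
    # Hypothetical: `trainer` is the nn.Module from example #2; DataParallel
    # replicates its forward across all visible GPUs, so the wrapped AEVNMT
    # model is reached as parallel_model.module.model.
    parallel_model = torch.nn.DataParallel(trainer)
    return aevnmt_train_parallel(parallel_model, *batch, hparams, step,
                                 summary_writer=summary_writer)

In practice the DataParallel wrapper would be constructed once, outside the per-step function; it is built inline here only to keep the sketch self-contained.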
Example #4
    def step(self,
             x_in,
             x_out,
             seq_mask_x,
             seq_len_x,
             noisy_x_in,
             y_in,
             y_out,
             seq_mask_y,
             seq_len_y,
             noisy_y_in,
             step,
             summary_writer=None):
        q_z = self.model.approximate_posterior(x_in, seq_mask_x, seq_len_x,
                                               y_in, seq_mask_y, seq_len_y)
        p_z = self.model.prior()

        if self.loss.num_samples == 1:
            z = q_z.rsample()
        else:
            z = q_z.rsample([self.loss.num_samples])
            z = self._flatten_samples(z)  # [K * B, ...]
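            # E.g. for K = num_samples and batch size B, z goes from
            # [K, B, D] to [K * B, D], so the decoders see one large batch.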

            # Expand required inputs to [K * B, ...]
            noisy_x_in = self._repeat_over_samples(noisy_x_in,
                                                   self.loss.num_samples,
                                                   flatten=True)
            noisy_y_in = self._repeat_over_samples(noisy_y_in,
                                                   self.loss.num_samples,
                                                   flatten=True)
            seq_mask_x = self._repeat_over_samples(seq_mask_x,
                                                   self.loss.num_samples,
                                                   flatten=True)
            seq_len_x = self._repeat_over_samples(seq_len_x,
                                                  self.loss.num_samples,
                                                  flatten=True)

            # Expand targets
            x_out = self._repeat_over_samples(x_out,
                                              self.loss.num_samples,
                                              flatten=True)
            y_out = self._repeat_over_samples(y_out,
                                              self.loss.num_samples,
                                              flatten=True)

        # TODO: auxiliary likelihoods are not used by the new loss functions.
        tm_likelihood, lm_likelihood, _, _, _ = self.model(
            noisy_x_in, seq_mask_x, seq_len_x, noisy_y_in, z)

        loss_dict = self.loss(tm_likelihood,
                              lm_likelihood,
                              y_out,
                              x_out,
                              q_z,
                              p_z,
                              z,
                              step,
                              self.model,
                              reduction='mean')

        if summary_writer and step % self.histogram_every == 0:
            # Generate histograms for the posterior and prior.
            summary_writer.add_histogram("posterior/z", z, step)
            for param_name, param_value in get_named_params(q_z):
                summary_writer.add_histogram("posterior/%s" % param_name,
                                             param_value, step)
            prior_sample = p_z.sample(torch.Size([z.size(0)]))
            summary_writer.add_histogram("prior/z", prior_sample, step)
            for param_name, param_value in get_named_params(p_z):
                summary_writer.add_histogram("prior/%s" % param_name,
                                             param_value, step)

        return loss_dict
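
The _flatten_samples and _repeat_over_samples helpers are not shown above; a minimal sketch consistent with the [K * B, ...] shape comments, written as free functions rather than the original methods, and an assumption rather than the original implementation:

import torch

def _flatten_samples(z):
    # [K, B, ...] -> [K * B, ...]: merge the sample and batch dimensions.
    return z.reshape(-1, *z.shape[2:])

def _repeat_over_samples(x, num_samples, flatten=True):
    # Tile a [B, ...] tensor to [K, B, ...] and optionally flatten to [K * B, ...].
    x = x.unsqueeze(0).expand(num_samples, *x.shape)
    return x.reshape(-1, *x.shape[2:]) if flatten else x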