Пример #1
0
    def adv_train_discriminator(self, d_step, gp_weight=10):
        """Run ``d_step`` adversarial updates of the discriminator.

        Each step draws a random batch of real sequences (one-hot encoded here)
        and a batch of one-hot generator samples, computes the discriminator
        loss with ``get_losses`` and, when ``self.norm == 'gp'``, adds a
        weighted gradient penalty before updating ``self.dis``.

        :param d_step: number of discriminator updates to perform.
        :param gp_weight: multiplier for the gradient penalty when
            ``self.norm == 'gp'`` (default 10, the previously hard-coded value).
        :return: mean discriminator loss over the steps, or 0 if ``d_step`` is 0.
        """
        total_loss = 0
        for _ in range(d_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            # Real data arrives as token indices; one-hot encode it so it has the
            # same representation as the one-hot generator samples.
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)
            if self.norm == 'gp':
                gp_loss = cacl_gradient_penalty(self.dis, real_samples,
                                                gen_samples)
                # Out-of-place add: modifying a graph tensor in place can
                # invalidate values autograd saved for the backward pass.
                d_loss = d_loss + gp_loss * gp_weight

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0
Пример #2
0
    def adv_train_discriminator(self, d_step):
        """Run ``d_step`` adversarial updates of the discriminator.

        Each step draws a random batch of real sequences and a batch of one-hot
        generator samples, computes the discriminator loss with ``get_losses``
        and, when ``cfg.GP`` is set, adds ``cfg.LAMBDA`` times the gradient
        penalty from ``self.calc_gradient_penalty`` before updating ``self.dis``.

        :param d_step: number of discriminator updates to perform.
        :return: mean discriminator loss over the steps, or 0 if ``d_step`` is 0.
        """
        total_loss = 0
        for _ in range(d_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            # NOTE(review): real_samples is fed to self.dis exactly as returned by
            # the data iterator (no one-hot encoding here), while gen_samples is
            # one-hot — confirm this discriminator/data setup expects that.

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            if cfg.GP:
                # Pass graph-free views of the samples. `.detach()` replaces the
                # discouraged `.data`, which silently bypasses autograd's
                # in-place modification checks.
                gradient_penalty = self.calc_gradient_penalty(
                    real_samples.detach(), gen_samples.detach())
                d_loss = d_loss + cfg.LAMBDA * gradient_penalty

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0
Пример #3
0
    def adv_train_discriminator(self, d_step, adv_epoch):
        """Run ``d_step`` adversarial discriminator updates and log statistics.

        Each step draws a random batch of real sequences (one-hot encoded) and a
        batch of one-hot generator samples, then updates ``self.dis`` on the loss
        from ``get_losses``. Two statistics are accumulated: the loss, and the
        fraction of discriminator logits on the correct side of zero
        (real > 0, fake <= 0). Every ``cfg.adv_log_step`` epochs the averages
        are logged.

        :param d_step: number of discriminator updates to perform.
        :param adv_epoch: current adversarial epoch; only used to decide whether
            to log.
        """
        total_loss = 0
        total_correct = 0
        total_count = 0
        for _ in range(d_step):
            # TODO(ethanjiang) we may want to train a full epoch instead of a random batch
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

            # Accuracy bookkeeping does not need an autograd graph.
            with torch.no_grad():
                predictions = torch.cat((d_out_real, d_out_fake))
                labels = torch.cat(
                    (torch.ones_like(d_out_real), torch.zeros_like(d_out_fake)))
                total_correct += torch.sum(
                    (predictions > 0).float() == labels).item()
                # Count the logits actually produced. The discriminator may emit
                # several logits per sample, so batch_size * 2 is not a safe
                # denominator (it can push accuracy above 1).
                total_count += predictions.numel()

        # =====Summary=====
        avg_loss = total_loss / d_step if d_step != 0 else 0
        avg_acc = total_correct / total_count if total_count != 0 else 0
        if adv_epoch % cfg.adv_log_step == 0:
            self.log.info('[ADV-DIS] d_loss = %.4f, train_acc = %.4f,' %
                          (avg_loss, avg_acc))
Пример #4
0
    def adv_train_generator(self, g_step):
        """Adversarially train the generator with a clipped ratio surrogate.

        A reference ratio ``P0 = (1 - D0) / D0`` is computed once (without
        gradients) from the auxiliary discriminator ``self.dis_D`` on a batch of
        generator samples. Each of the ``g_step`` iterations then:

        1. updates ``self.dis_D`` with BCE on real (label 1) vs. generated
           (label 0) one-hot samples;
        2. draws fresh samples and computes ``ratio = (1 - D1) / (D1 * P0)``,
           plus a copy clipped to ``[1 - cfg.clip_param, 1 + cfg.clip_param]``;
        3. forms surrogates ``ratio * d_out_fake`` and
           ``ratio_clipped * d_out_fake`` from the main discriminator
           ``self.dis``. It keeps their element-wise min where the logit is
           positive and their max elsewhere, then updates the generator on
           ``get_losses`` of that target.

        :param g_step: number of generator updates.
        :return: mean generator loss over the steps, or 0 if ``g_step`` is 0.
        """
        criterion = nn.BCELoss()

        def to_device(tensor):
            # Honour cfg.CUDA for every tensor; an unconditional .cuda() breaks
            # CPU-only runs.
            return tensor.cuda() if cfg.CUDA else tensor

        def real_batch():
            # Real data are token indices: move them (cheap), then one-hot encode
            # so both discriminators see the same representation as the one-hot
            # generator samples.
            indices = to_device(self.train_data.random_batch()['target'])
            return F.one_hot(indices, cfg.vocab_size).float()

        def gen_batch():
            return to_device(self.gen.sample(cfg.batch_size,
                                             cfg.batch_size,
                                             one_hot=True))

        total_loss = 0
        with torch.no_grad():
            D0 = torch.sigmoid(self.dis_D(gen_batch()))
            P0 = (1. - D0) / torch.clamp(D0, min=1e-7)

        # BCE targets for dis_D are identical on every iteration; build them once.
        real_label = to_device(torch.full((D0.shape[0], ), 1.))
        fake_label = to_device(torch.full((D0.shape[0], ), 0.))

        for _ in range(g_step):
            # ---Update the auxiliary discriminator dis_D---
            real_samples = real_batch()
            gen_samples = gen_batch()
            errDD_real = criterion(torch.sigmoid(self.dis_D(real_samples)),
                                   real_label)
            errDD_fake = criterion(
                torch.sigmoid(self.dis_D(gen_samples.detach())), fake_label)
            self.optimize(self.dis_D_opt, errDD_real + errDD_fake, self.dis_D)

            # ---Ratio of a fresh batch against the reference P0---
            gen_samples = gen_batch()
            real_samples = real_batch()
            D1 = torch.sigmoid(self.dis_D(gen_samples))
            P1 = (1. - D1)
            ratio = (P1 / torch.clamp(D1 * P0, min=1e-7))
            ratio_clipped = torch.clamp(ratio, 1.0 - cfg.clip_param,
                                        1.0 + cfg.clip_param)

            # ===Train generator===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            surr1 = ratio * d_out_fake
            surr2 = ratio_clipped * d_out_fake
            target = torch.where(d_out_fake > 0, torch.min(surr1, surr2),
                                 torch.max(surr1, surr2))
            g_loss, _ = get_losses(d_out_real, target, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0
Пример #5
0
    def adv_train_generator(self, g_step):
        """Run ``g_step`` adversarial updates of the generator.

        Real sequences are drawn from ``self.oracle_data`` and one-hot encoded;
        generator samples are drawn one-hot. The generator loss returned by
        ``get_losses`` is minimised with ``self.gen_adv_opt``.

        :param g_step: number of generator updates.
        :return: mean generator loss over the steps, or 0 if ``g_step`` is 0.
        """
        total_loss = 0
        for _ in range(g_step):
            real_samples = self.oracle_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                # Move the compact index tensor and encode on the device, rather
                # than copying the vocab_size-times larger one-hot float tensor.
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0
Пример #6
0
    def _environment_function(self, input):
        """
        The environment function that computes the loss for the samples with respect to the true label.

        The samples are scored by ``self.discriminator``. For the ``'rsgan'``
        loss the discriminator output on ``self.real_samples`` is also computed;
        for other loss types ``None`` is passed as the real output. Losses are
        computed element-wise (``reduction='none'``). When
        ``self.num_rep != 1`` they are treated as ``num_rep`` contiguous values
        per sample and averaged per sample.

        :param input: input that will be evaluated. Shape: batch_size * seq_len * vocab_size
        :return g_loss: the loss for the samples with respect to the true label. Shape: batch_size
        """
        d_out_fake = self.discriminator(input)
        d_out_real = (self.discriminator(self.real_samples)
                      if cfg.loss_type == 'rsgan' else None)
        g_loss, _ = get_losses(d_out_real,
                               d_out_fake,
                               cfg.loss_type,
                               reduction='none')
        if self.num_rep != 1:
            # Infer the sample count from the loss itself instead of
            # self.batch_size, so inputs with a different batch size (e.g. a
            # final partial batch) are still averaged correctly.
            g_loss = torch.mean(g_loss.reshape(-1, self.num_rep), dim=1)
        return g_loss
Пример #7
0
    def evaluation(self, eval_type):
        """Score the current generator and return ``(Fq, Fd, score)``.

        ``cfg.eval_b_num * cfg.batch_size`` fresh samples are drawn from
        ``self.gen``. The cached ``self.eval_d_out_real`` /
        ``self.eval_d_out_fake`` are expected to come from the same evaluation
        data for every candidate being compared.

        * ``Fd``: NLL of ``self.gen`` on the fresh samples, computed only when
          ``cfg.lambda_fd != 0`` (otherwise 0).
        * ``Fq``: chosen by ``eval_type``:

          - ``'standard'``: mean of ``self.eval_d_out_fake``;
          - ``'rsgan'``: the RSGAN discriminator loss on the cached outputs;
          - any type containing ``'bleu'``: BLEU of the fresh samples, with
            the n-gram order taken from the last character (0 when
            ``cfg.lambda_fq == 0``);
          - any type containing ``'Ra'``: summed sigmoid of the fake outputs
            relative to the mean real output.
        * ``score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd``.

        Nothing is stored on ``self`` by this method itself. For BLEU types,
        ``self.bleu`` is reset with the fresh samples.

        :raises NotImplementedError: for an unrecognised ``eval_type``.
        """
        samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size,
                                  cfg.max_bn * cfg.batch_size)
        sample_iter = GenDataIter(samples)

        # --- Fd: NLL_div of the generator on its own samples ---
        fd_value = 0
        if cfg.lambda_fd != 0:
            fd_value = NLL.cal_nll(self.gen, sample_iter.loader,
                                   self.mle_criterion)

        # --- Fq: quality term selected by eval_type ---
        if eval_type == 'standard':
            fq_value = self.eval_d_out_fake.mean().cpu().item()
        elif eval_type == 'rsgan':
            _, rsgan_d_loss = get_losses(self.eval_d_out_real,
                                         self.eval_d_out_fake, 'rsgan')
            fq_value = rsgan_d_loss.item()
        elif 'bleu' in eval_type:
            tokens = tensor_to_tokens(samples, self.idx2word_dict)
            self.bleu.reset(test_text=tokens)
            fq_value = (self.bleu.get_score(given_gram=int(eval_type[-1]))
                        if cfg.lambda_fq != 0 else 0)
        elif 'Ra' in eval_type:
            relative = torch.sigmoid(self.eval_d_out_fake -
                                     torch.mean(self.eval_d_out_real))
            fq_value = relative.sum().item()
        else:
            raise NotImplementedError("Evaluation '%s' is not implemented" %
                                      eval_type)

        weighted = cfg.lambda_fq * fq_value + cfg.lambda_fd * fd_value
        return fq_value, fd_value, weighted
Пример #8
0
    def evaluation(self, eval_type):
        """Score the current generator and return ``(Fq, Fd, score)``.

        ``cfg.eval_b_num * cfg.batch_size`` fresh samples are drawn from
        ``self.gen``. The cached ``self.eval_d_out_real`` /
        ``self.eval_d_out_fake`` are expected to come from the same evaluation
        data for every candidate being compared.

        * ``Fd``: NLL of ``self.gen`` on the fresh samples, computed only when
          ``cfg.lambda_fd != 0`` (otherwise 0).
        * ``Fq``: chosen by ``eval_type`` (exact match):

          - ``'standard'``: mean of ``self.eval_d_out_fake``;
          - ``'rsgan'``: the RSGAN discriminator loss on the cached outputs;
          - ``'nll'``: negated NLL of ``self.oracle`` on the fresh samples
            (0 when ``cfg.lambda_fq == 0``);
          - ``'Ra'``: summed sigmoid of the fake outputs relative to the mean
            real output.
        * ``score = cfg.lambda_fq * Fq + cfg.lambda_fd * Fd``.

        Nothing is stored on ``self`` by this method itself.

        :raises NotImplementedError: for an unrecognised ``eval_type``.
        """
        samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size,
                                  cfg.max_bn * cfg.batch_size)
        sample_iter = GenDataIter(samples)

        # --- Fd: NLL_div of the generator on its own samples ---
        fd_value = 0
        if cfg.lambda_fd != 0:
            fd_value = NLL.cal_nll(self.gen, sample_iter.loader,
                                   self.mle_criterion)

        # --- Fq: quality term selected by eval_type ---
        if eval_type == 'standard':
            fq_value = self.eval_d_out_fake.mean().cpu().item()
        elif eval_type == 'rsgan':
            _, rsgan_d_loss = get_losses(self.eval_d_out_real,
                                         self.eval_d_out_fake, 'rsgan')
            fq_value = rsgan_d_loss.item()
        elif eval_type == 'nll':
            fq_value = 0
            if cfg.lambda_fq != 0:
                # NLL_Oracle, negated
                fq_value = -NLL.cal_nll(self.oracle, sample_iter.loader,
                                        self.mle_criterion)
        elif eval_type == 'Ra':
            relative = torch.sigmoid(self.eval_d_out_fake -
                                     torch.mean(self.eval_d_out_real))
            fq_value = relative.sum().item()
        else:
            raise NotImplementedError("Evaluation '%s' is not implemented" %
                                      eval_type)

        weighted = cfg.lambda_fq * fq_value + cfg.lambda_fd * fd_value
        return fq_value, fd_value, weighted
Пример #9
0
    def adv_train_generator(self, g_step, adv_epoch):
        """Run ``g_step`` adversarial generator updates and log statistics.

        Each step draws a random batch of real sequences (one-hot encoded) and a
        batch of one-hot generator samples, then computes the generator loss
        with ``get_losses``. It updates ``self.gen`` through ``self.optimize``,
        which also returns the theta gradient fetched by
        ``self.gen.get_theta_gradient``; its variance is measured with
        ``get_gradient_variance``. Every ``cfg.adv_log_step`` epochs the mean
        loss, the generator temperature and the gradient variance from the
        last step are logged.

        :param g_step: number of generator updates.
        :param adv_epoch: current adversarial epoch; only used to decide whether
            to log.
        """
        # true_ge = TrueGradientEstimator()  TODO

        total_loss = 0
        # Bound up front so the log call below cannot raise UnboundLocalError
        # when g_step == 0; NaN marks "no gradient measured".
        theta_gradient_log_var = float('nan')
        for _ in range(g_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # =====Train=====
            # vanilla_theta = self.gen.sample_vanilla_theta()
            # true_ge = true_ge.estimate_gradient(vanilla_theta...)  TODO

            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            theta_gradient = self.optimize(
                self.gen_adv_opt,
                g_loss,
                self.gen,
                theta_gradient_fetcher=self.gen.get_theta_gradient)
            theta_gradient_log_var = get_gradient_variance(theta_gradient)
            total_loss += g_loss.item()

        # =====Test=====
        avg_loss = total_loss / g_step if g_step != 0 else 0
        if adv_epoch % cfg.adv_log_step == 0:
            self.log.info(
                '[ADV-GEN] g_loss = %.4f, temperature = %.4f, theta_gradient_log_var = %.4f'
                % (avg_loss, self.gen.temperature, theta_gradient_log_var))