Example #1
    def adv_train_generator(self, g_step, current_k=0):
        """
        The gen is trained using policy gradients, using the reward from the discriminator.
        Training is done for num_batches batches.
        """

        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        adv_mana_loss = 0
        adv_work_loss = 0
        for step in range(g_step):
            with torch.no_grad():
                gen_samples = self.gen.sample(
                    cfg.batch_size, cfg.batch_size, self.dis,
                    train=True)  # note: the only call site that samples with train=True
                inp, target = GenDataIter.prepare(gen_samples, gpu=cfg.CUDA)

            # ===Train===
            rewards = rollout_func.get_reward_leakgan(
                target, cfg.rollout_num, self.dis,
                current_k).cpu()  # rewards from Monte Carlo rollouts
            mana_loss, work_loss = self.gen.adversarial_loss(
                target, rewards, self.dis)  # separate losses for the Manager and Worker modules

            # update parameters
            self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
            adv_mana_loss += mana_loss.data.item()
            adv_work_loss += work_loss.data.item()
        # ===Test===
        self.log.info(
            '[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' %
            (adv_mana_loss / g_step, adv_work_loss / g_step,
             self.cal_metrics(fmt_str=True)))
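For context, `get_reward_leakgan` (like the plain `get_reward` in the later examples) follows the Monte Carlo search idea from SeqGAN: every prefix of a sampled sequence is completed several times by the generator and scored by the discriminator, and the mean score becomes that step's reward. The sketch below only illustrates that idea; `complete_from` and `score` are hypothetical helpers, not the ROLLOUT API.

import torch

def mc_rollout_rewards(gen, dis, sequences, rollout_num):
    """Sketch of Monte Carlo rollout rewards (illustrative, not the ROLLOUT API).

    For every prefix y_{1:t} of each sampled sequence, the generator completes the
    sequence rollout_num times and the discriminator scores the completions; the
    average score is used as the reward for step t.
    """
    batch_size, seq_len = sequences.size()
    rewards = torch.zeros(batch_size, seq_len)
    for t in range(1, seq_len + 1):
        for _ in range(rollout_num):
            completed = gen.complete_from(sequences[:, :t], seq_len)  # hypothetical helper
            rewards[:, t - 1] += dis.score(completed)  # hypothetical: e.g. P(real) per sequence
    return rewards / rollout_num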
Example #2
    def adv_train_generator(self, g_step):
        """
        The gen is trained using policy gradients, using the reward from the discriminator.
        Training is done for num_batches batches.
        """
        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        total_g_loss = 0

        print('g_step in adv_train_generator->', g_step)

        for step in range(g_step):
            samples = self.gen.sample(cfg.batch_size, cfg.batch_size)
            print('samples ->', samples.size())

            inp, target = GenDataIter.prepare(samples, gpu=cfg.CUDA)
            print('inp ->', inp.size())
            print('target ->', target.size())

            # ===Train===
            rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis)
            print('rewards ->', rewards.size(), rewards)

            adv_loss = self.gen.batchPGLoss(inp, target, rewards)
            self.optimize(self.gen_adv_opt, adv_loss)
            total_g_loss += adv_loss.item()

        # ===Test===
        self.log.info('[ADV-GEN]: g_loss = %.4f, %s' % (total_g_loss, self.cal_metrics(fmt_str=True)))
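The `batchPGLoss` call above is the standard REINFORCE-style objective: the negative log-likelihood of each sampled token, weighted by its reward. Below is a minimal sketch of that objective, assuming the generator exposes per-step log-probabilities of shape (batch, seq_len, vocab); it is an illustration, not the repository's implementation.

import torch

def pg_loss(log_probs, target, rewards):
    """Policy-gradient loss: -mean over the batch of sum_t log pi(y_t | y_<t) * reward_t.

    log_probs: (batch, seq_len, vocab) log-probabilities from the generator
    target:    (batch, seq_len) sampled token ids
    rewards:   (batch, seq_len) per-step rewards from the rollout/discriminator
    """
    token_log_probs = log_probs.gather(2, target.unsqueeze(2)).squeeze(2)  # (batch, seq_len)
    return -(token_log_probs * rewards).sum() / target.size(0)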
Example #3
    def adv_train_generator(self, g_step):
        """
        The gen is trained using policy gradients, using the reward from the discriminator.
        Training is done for num_batches batches.
        """
        for i in range(cfg.k_label):
            rollout_func = rollout.ROLLOUT(self.gen_list[i], cfg.CUDA)
            total_g_loss = 0
            for step in range(g_step):
                samples = self.gen_list[i].sample(cfg.batch_size, cfg.batch_size)
                inp, target = GenDataIter.prepare(samples, gpu=cfg.CUDA)

                # ===Train===
                rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis)
                adv_loss = self.gen_list[i].batchPGLoss(inp, target, rewards)
                self.optimize(self.gen_opt_list[i], adv_loss)
                total_g_loss += adv_loss.item()

        # ===Test===
        self.log.info('[ADV-GEN]: %s', self.comb_metrics(fmt_str=True))
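All of these examples hand the loss to an `optimize` (or `optimize_multi`) helper rather than calling the optimizer directly. A minimal sketch of what such a helper usually does follows; the gradient-clipping value is an assumption, not taken from the source.

import torch

def optimize(opt, loss, model=None, retain_graph=False):
    """Generic update step: zero gradients, backpropagate, optionally clip, then step."""
    opt.zero_grad()
    loss.backward(retain_graph=retain_graph)
    if model is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # assumed clip value
    opt.step()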
Example #4
    def adv_train_generator(self, g_step):
        """
        The gen is trained using policy gradients, using the reward from the discriminator.
        Training is done for num_batches batches.
        """
        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        total_g_loss = 0
        for step in range(g_step):
            inp, target = GenDataIter.prepare(
                self.gen.sample(cfg.batch_size, cfg.batch_size),
                gpu=cfg.CUDA)  # split G's sampled sentences into shifted input/target pairs
            # open question: can this dataset also be used to train the generator?
            # ===Train===
            rewards = rollout_func.get_reward(target, cfg.rollout_num,
                                              self.dis)
            adv_loss = self.gen.batchPGLoss(inp, target, rewards)
            self.optimize(self.gen_adv_opt, adv_loss)
            total_g_loss += adv_loss.item()

        # ===Test===
        self.log.info('[ADV-GEN]: g_loss = %.4f, %s' %
                      (total_g_loss, self.cal_metrics(fmt_str=True)))
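For orientation, `adv_train_generator` is normally called from an outer adversarial loop that alternates generator and discriminator updates. A hedged sketch of that loop is below; the config names and the `adv_train_discriminator` counterpart are assumptions about the surrounding trainer, not code taken from the examples above.

# hypothetical outer loop around the methods shown above
for adv_epoch in range(cfg.ADV_train_epoch):          # epoch-count name assumed
    trainer.adv_train_generator(cfg.adv_g_step)       # policy-gradient updates for G (step-count name assumed)
    trainer.adv_train_discriminator(cfg.adv_d_step)   # discriminator updates (method name assumed)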