Example #1
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots,
                            cfg.num_heads,
                            cfg.head_size,
                            cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim,
                            cfg.max_seq_len,
                            cfg.num_rep,
                            cfg.vocab_size,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)

        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size))
Example #2
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots,
                            cfg.num_heads,
                            cfg.head_size,
                            cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim,
                            cfg.max_seq_len,
                            cfg.num_rep,
                            cfg.vocab_size,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)

        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)
Example #3
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.adv_criterion = nn.BCEWithLogitsLoss()

        # DataLoader
        self.gen_data = GenDataIter(self.gen.sample(cfg.batch_size, cfg.batch_size))

        # Metrics
        self.bleu = BLEU(test_text=tensor_to_tokens(self.gen_data.target, self.index_word_dict),
                         real_text=tensor_to_tokens(self.test_data.target, self.test_data.index_word_dict),
                         gram=[2, 3, 4, 5])
        self.self_bleu = BLEU(test_text=tensor_to_tokens(self.gen_data.target, self.index_word_dict),
                              real_text=tensor_to_tokens(self.gen_data.target, self.index_word_dict),
                              gram=3)
Example #4
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)
        norm = opt.norm
        assert norm in ['none', 'spectral', 'gradnorm', 'gp']
        # generator, discriminator
        print('norm ', norm)
        self.norm = norm
        self.gen = RelGAN_G(cfg.mem_slots,
                            cfg.num_heads,
                            cfg.head_size,
                            cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim,
                            cfg.max_seq_len,
                            cfg.num_rep,
                            cfg.vocab_size,
                            cfg.padding_idx,
                            gpu=cfg.CUDA,
                            norm=norm).cuda()
        if norm == 'gradnorm':
            print('use gradnorm')
            self.dis = GradNorm(self.dis).cuda()

        self.init_model()

        # Optimizer
        if norm == 'gradnorm':
            self.gen_opt = optim.Adam(self.gen.parameters(),
                                      lr=cfg.gen_lr,
                                      betas=(0.5, 0.999))
            self.gen_adv_opt = optim.Adam(self.gen.parameters(),
                                          lr=cfg.gen_adv_lr,
                                          betas=(0.5, 0.999))
            self.dis_opt = optim.Adam(self.dis.parameters(),
                                      lr=cfg.dis_lr,
                                      betas=(0.5, 0.999))
        else:
            self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
            self.gen_adv_opt = optim.Adam(self.gen.parameters(),
                                          lr=cfg.gen_adv_lr)
            self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        os.makedirs(cfg.log_filename.replace('.txt', ''), exist_ok=True)

        self.logger = SummaryWriter(
            cfg.log_filename.replace('.txt', '') + '_' + norm)
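
Example #4 differs from the others only in the norm option, which selects how the discriminator is regularized ('spectral', 'gradnorm', 'gp', or 'none'). As a hedged illustration of the 'spectral' case only: the sketch below shows how a layer is typically wrapped with PyTorch's built-in torch.nn.utils.spectral_norm; whether RelGAN_D applies it exactly this way is an assumption, not something taken from the repository.

import torch.nn as nn
from torch.nn.utils import spectral_norm

def maybe_spectral(layer, norm):
    # Wrap the layer with spectral normalization only when norm == 'spectral';
    # otherwise return it unchanged.
    return spectral_norm(layer) if norm == 'spectral' else layer

# hypothetical usage inside a discriminator
conv = maybe_spectral(nn.Conv2d(1, 64, kernel_size=3), norm='spectral')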
Example #5
class RelGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA)

        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()

        # DataLoader
        self.gen_data = GenDataIter(self.gen.sample(cfg.batch_size, cfg.batch_size))

    def _run(self):
        # =====PRE-TRAINING (GENERATOR)=====
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path))

        self.log.info('Initial generator: %s' % (self.cal_metrics(fmt_str=True)))

        # # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        progress = tqdm(range(cfg.ADV_train_epoch))
        for adv_epoch in progress:
            self.sig.update()
            if self.sig.adv_sig:
                g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator
                d_loss = self.adv_train_discriminator(cfg.ADV_d_step)  # Discriminator
                self.update_temperature(adv_epoch, cfg.ADV_train_epoch)  # update temperature

                progress.set_description(
                    'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature))

                # TEST
                if adv_epoch % cfg.adv_log_step == 0:
                    self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % (
                        adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info('>>> Stop by adv_signal! Finishing adversarial training...')
                progress.close()
                break

    def _test(self):
        print('>>> Begin test...')

        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # =====Train=====
                pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info('>>> Stop by pre signal, skip to adversarial training...')
                break
        if cfg.if_save and not cfg.if_test:
            self._save('MLE', epoch)

    def adv_train_generator(self, g_step):
        total_loss = 0
        for step in range(g_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0

    def adv_train_discriminator(self, d_step):
        total_loss = 0
        for step in range(d_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        """Add clip_grad_norm_"""
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
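
Every example builds its adversarial losses with get_losses(d_out_real, d_out_fake, cfg.loss_type). A minimal sketch of the relativistic standard GAN variant ('rsgan') that RelGAN commonly uses, written here as a standalone function and only an approximation of whatever the repository's get_losses actually implements:

import torch
import torch.nn.functional as F

def rsgan_losses(d_out_real, d_out_fake):
    # Relativistic standard GAN: score each batch relative to the other one.
    ones = torch.ones_like(d_out_fake)
    g_loss = F.binary_cross_entropy_with_logits(d_out_fake - d_out_real, ones)
    d_loss = F.binary_cross_entropy_with_logits(d_out_real - d_out_fake, ones)
    return g_loss, d_loss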
Example #6
class TRGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(TRGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelGAN_G(cfg.mem_slots, cfg.num_heads, cfg.head_size, cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.dis_D = RelGAN_D(cfg.dis_embed_dim, cfg.max_seq_len, cfg.num_rep, cfg.vocab_size, cfg.padding_idx,
                            gpu=cfg.CUDA)

        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)
        self.dis_D_opt = optim.Adam(self.dis_D.parameters(), lr=cfg.dis_D_lr)

    def init_model(self):
        if cfg.oracle_pretrain:
            if not os.path.exists(cfg.oracle_state_dict_path):
                create_oracle()
            self.oracle.load_state_dict(torch.load(cfg.oracle_state_dict_path))

        if cfg.dis_pretrain:
            self.log.info(
                'Load pretrained discriminator: {}'.format(cfg.pretrained_dis_path))
            self.dis.load_state_dict(torch.load(cfg.pretrained_dis_path))
        if cfg.gen_pretrain:
            self.log.info('Load MLE pretrained generator gen: {}'.format(cfg.pretrained_gen_path))
            self.gen.load_state_dict(torch.load(cfg.pretrained_gen_path, map_location='cuda:{}'.format(cfg.device)))

        if cfg.CUDA:
            self.oracle = self.oracle.cuda()
            self.gen = self.gen.cuda()
            self.dis = self.dis.cuda()
            self.dis_D = self.dis_D.cuda()

    def _run(self):
        # ===PRE-TRAINING (GENERATOR)===
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(cfg.pretrained_gen_path))

        # # ===ADVERSARIAL TRAINING===
        self.log.info('Starting Adversarial Training...')
        progress = tqdm(range(cfg.ADV_train_epoch))
        for adv_epoch in progress:
            self.sig.update()
            if self.sig.adv_sig:
                g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator
                d_loss = self.adv_train_discriminator(cfg.ADV_d_step)  # Discriminator
                self.update_temperature(adv_epoch, cfg.ADV_train_epoch)  # update temperature

                progress.set_description(
                    'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' % (g_loss, d_loss, self.gen.temperature))

                # TEST
                if adv_epoch % cfg.adv_log_step == 0:
                    self.log.info('[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f, %s' % (
                        adv_epoch, g_loss, d_loss, self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info('>>> Stop by adv_signal! Finishing adversarial training...')
                progress.close()
                break

    def _test(self):
        print('>>> Begin test...')

        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # ===Train===
                pre_loss = self.train_gen_epoch(self.gen, self.oracle_data.loader, self.mle_criterion, self.gen_opt)

                # ===Test===
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' % (epoch, pre_loss, self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info('>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step):
        criterion = nn.BCELoss()
        total_loss = 0
        with torch.no_grad():
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                gen_samples = gen_samples.cuda()
            D0 = torch.sigmoid(self.dis_D(gen_samples))
            # (1 - D) / D approximates p_data / p_gen for the snapshot generator
            P0 = (1. - D0) / torch.clamp(D0, min=1e-7)

        for step in range(g_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            real_label = torch.full((D0.shape[0],), 1.)
            fake_label = torch.full((D0.shape[0],), 0.)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
                real_label, fake_label = real_label.cuda(), fake_label.cuda()
            # update the auxiliary discriminator dis_D with a standard BCE loss
            errDD_real = criterion(torch.sigmoid(self.dis_D(real_samples)), real_label)
            errDD_fake = criterion(torch.sigmoid(self.dis_D(gen_samples.detach())), fake_label)
            self.optimize(self.dis_D_opt, errDD_real + errDD_fake, self.dis_D)

            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True).cuda()
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float().cuda()
            D1 = torch.sigmoid(self.dis_D(gen_samples))
            P1 = 1. - D1
            # PPO-style importance ratio between the current generator and the snapshot taken before this loop
            ratio = P1 / torch.clamp(D1 * P0, min=1e-7)
            ratio_clipped = torch.clamp(ratio, 1.0 - cfg.clip_param, 1.0 + cfg.clip_param)
            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            surr1 = ratio * d_out_fake
            surr2 = ratio_clipped * d_out_fake
            # keep the pessimistic of the clipped and unclipped surrogates, as in PPO
            target = torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
            g_loss, _ = get_losses(d_out_real, target, cfg.loss_type)
            # g_loss = -d_out_fake.mean()

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0

    def calc_gradient_penalty(self, real_data, fake_data):
        BATCH_SIZE = real_data.shape[0]
        alpha = torch.rand(BATCH_SIZE, 1)
        alpha = alpha.expand(BATCH_SIZE, real_data.nelement()//BATCH_SIZE).contiguous().view(real_data.shape)
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        interpolates = interpolates.cuda()
        interpolates.requires_grad_(True)

        disc_interpolates = self.dis(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.contiguous().view(gradients.size(0), -1)

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def adv_train_discriminator(self, d_step):
        total_loss = 0
        for step in range(d_step):
            real_samples = F.one_hot(self.oracle_data.random_batch()['target'], cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            if cfg.GP:
                gradient_penalty = self.calc_gradient_penalty(real_samples.data, gen_samples.data)
                d_loss = d_loss + cfg.LAMBDA * gradient_penalty
            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt)

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        """Add clip_grad_norm_"""
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
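
The generator step in Example #6 re-weights d_out_fake with an importance ratio computed from the auxiliary discriminator dis_D and clips it PPO-style with cfg.clip_param. A minimal standalone sketch of that clipping step, with illustrative names rather than the repository's:

import torch

def clipped_surrogate(d_out_fake, ratio, clip_param=0.2):
    # Clip the importance ratio, then keep the pessimistic of the two surrogates:
    # the minimum when the objective term is positive, the maximum when it is negative.
    ratio_clipped = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    surr1 = ratio * d_out_fake
    surr2 = ratio_clipped * d_out_fake
    return torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))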
Example #7
class RelGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(RelGANInstructor, self).__init__(opt)
        norm = opt.norm
        assert norm in ['none', 'spectral', 'gradnorm', 'gp']
        # generator, discriminator
        print('norm ', norm)
        self.norm = norm
        self.gen = RelGAN_G(cfg.mem_slots,
                            cfg.num_heads,
                            cfg.head_size,
                            cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.dis = RelGAN_D(cfg.dis_embed_dim,
                            cfg.max_seq_len,
                            cfg.num_rep,
                            cfg.vocab_size,
                            cfg.padding_idx,
                            gpu=cfg.CUDA,
                            norm=norm).cuda()
        if norm == 'gradnorm':
            print('use gradnorm')
            self.dis = GradNorm(self.dis).cuda()

        self.init_model()

        # Optimizer
        if norm == 'gradnorm':
            self.gen_opt = optim.Adam(self.gen.parameters(),
                                      lr=cfg.gen_lr,
                                      betas=(0.5, 0.999))
            self.gen_adv_opt = optim.Adam(self.gen.parameters(),
                                          lr=cfg.gen_adv_lr,
                                          betas=(0.5, 0.999))
            self.dis_opt = optim.Adam(self.dis.parameters(),
                                      lr=cfg.dis_lr,
                                      betas=(0.5, 0.999))
        else:
            self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
            self.gen_adv_opt = optim.Adam(self.gen.parameters(),
                                          lr=cfg.gen_adv_lr)
            self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        os.makedirs(cfg.log_filename.replace('.txt', ''), exist_ok=True)

        self.logger = SummaryWriter(
            cfg.log_filename.replace('.txt', '') + '_' + norm)

    def _run(self):
        # ===PRE-TRAINING (GENERATOR)===
        if os.path.exists(cfg.pretrained_gen_path):
            checkpoint = torch.load(cfg.pretrained_gen_path)
            generation_weights = self.gen.state_dict()
            match = True
            for key, value in checkpoint.items():
                if key not in generation_weights:
                    match = False
                elif generation_weights[key].shape != checkpoint[key].shape:
                    match = False
            if match:
                self.gen.load_state_dict(checkpoint)
                print('Load pre-trained generator: {}'.format(
                    cfg.pretrained_gen_path))
        elif not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(
                    cfg.pretrained_gen_path))

        # # ===ADVERSARIAL TRAINING===
        self.log.info('Starting Adversarial Training...')

        progress = tqdm(range(cfg.ADV_train_epoch), dynamic_ncols=True)

        for adv_epoch in progress:
            self.sig.update()
            if self.sig.adv_sig:
                start = time()
                g_loss = self.adv_train_generator(cfg.ADV_g_step)  # Generator
                d_loss = self.adv_train_discriminator(
                    cfg.ADV_d_step)  # Discriminator
                self.update_temperature(
                    adv_epoch, cfg.ADV_train_epoch)  # update temperature

                progress.set_description(
                    'g_loss: %.4f, d_loss: %.4f, temperature: %.4f' %
                    (g_loss, d_loss, self.gen.temperature))
                if adv_epoch % 10 == 0:
                    self.logger.add_scalar('train/d_loss', float(d_loss),
                                           adv_epoch)
                    self.logger.add_scalar('train/g_loss', float(g_loss),
                                           adv_epoch)
                    self.logger.add_scalar('train/temperature',
                                           self.gen.temperature, adv_epoch)

                # TEST
                if adv_epoch % cfg.adv_log_step == 0:
                    metrics = self.cal_metrics(fmt_str=False)
                    for key, value in metrics.items():
                        if isinstance(value, list):
                            for idx, v in enumerate(value):
                                self.logger.add_scalar(
                                    'train/' + key + '/' + str(idx), v,
                                    adv_epoch)
                        else:
                            self.logger.add_scalar('train/' + key, value,
                                                   adv_epoch)

                    self.logger.flush()
                    self.log.info(
                        '[ADV] epoch %d: g_loss: %.4f, d_loss: %.4f' %
                        (adv_epoch, g_loss, d_loss))

                    if cfg.if_save and not cfg.if_test:
                        self._save('GEN', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                progress.close()
                break

    def _test(self):
        print('>>> Begin test...')

        self._run()

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # ===Train===
                pre_loss = self.train_gen_epoch(self.gen,
                                                self.train_data.loader,
                                                self.mle_criterion,
                                                self.gen_opt)

                # ===Test===
                if (epoch % cfg.pre_log_step == 0
                        or epoch == epochs - 1) and epoch > 0:
                    metrics = self.cal_metrics(fmt_str=False)
                    for key, value in metrics.items():
                        if isinstance(value, list):
                            for idx, v in enumerate(value):
                                self.logger.add_scalar(
                                    'pretrain/' + key + '/' + str(idx), v,
                                    epoch)
                        else:
                            self.logger.add_scalar('pretrain/' + key, value,
                                                   epoch)
                    self.logger.add_scalar('pretrain/loss', pre_loss, epoch)
                    self.logger.flush()
                    self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f' %
                                  (epoch, pre_loss))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step):
        total_loss = 0
        for step in range(g_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            total_loss += g_loss.item()

        return total_loss / g_step if g_step != 0 else 0

    def adv_train_discriminator(self, d_step):
        total_loss = 0
        for step in range(d_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # ===Train===
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)
            if self.norm == 'gp':
                gp_loss = cacl_gradient_penalty(self.dis, real_samples,
                                                gen_samples)
                d_loss += gp_loss * 10

            self.optimize(self.dis_opt, d_loss, self.dis)
            total_loss += d_loss.item()

        return total_loss / d_step if d_step != 0 else 0

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N,
                                                     cfg.temp_adpt)

    @staticmethod
    def optimize(opt, loss, model=None, retain_graph=False):
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
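
All of the instructors anneal the Gumbel-Softmax temperature through get_fixed_temperature(cfg.temperature, i, N, cfg.temp_adpt). A minimal sketch of two common schedules of this kind (linear and exponential ramps from 1 toward the target temperature); the exact formulas used by the repository may differ:

def fixed_temperature(max_temp, i, N, adapt='exp'):
    # Anneal the temperature from 1 toward max_temp as epoch i approaches N.
    if adapt == 'no':
        return max_temp
    if adapt == 'lin':
        return 1.0 + (max_temp - 1.0) * i / N
    if adapt == 'exp':
        return max_temp ** (i / N)
    raise ValueError('unknown temperature schedule: {}'.format(adapt))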