def train_g(self,
             epoch_num,
             mode='Adam',
             dataname='MNIST',
             logname='MNIST'):
     print(mode)
     if mode == 'SGD':
         g_optimizer = optim.SGD(self.G.parameters(),
                                 lr=self.lr,
                                 weight_decay=self.weight_decay)
         self.writer_init(logname=logname,
                          comments='SGD-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     elif mode == 'Adam':
         g_optimizer = optim.Adam(self.G.parameters(),
                                  lr=self.lr,
                                  weight_decay=self.weight_decay,
                                  betas=(0.5, 0.999))
         self.writer_init(logname=logname,
                          comments='ADAM-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     elif mode == 'RMSProp':
         g_optimizer = RMSprop(self.G.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
         self.writer_init(logname=logname,
                          comments='RMSProp-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     timer = time.time()
     for e in range(epoch_num):
          z = torch.randn((self.batchsize, self.z_dim),
                          device=self.device)  # sample latent noise directly on the target device
         fake_x = self.G(z)
         d_fake = self.D(fake_x)
         # G_loss = g_loss(d_fake)
         G_loss = self.criterion(
             d_fake, torch.ones(d_fake.shape, device=self.device))
         g_optimizer.zero_grad()
         zero_grad(self.D.parameters())
         G_loss.backward()
         g_optimizer.step()
          # L2 norms of the discriminator and generator gradients, for logging
          gd = torch.norm(torch.cat(
              [p.grad.contiguous().view(-1) for p in self.D.parameters()]),
                          p=2)
          gg = torch.norm(torch.cat(
              [p.grad.contiguous().view(-1) for p in self.G.parameters()]),
                          p=2)
         self.plot_param(G_loss=G_loss)
         self.plot_grad(gd=gd, gg=gg)
          if self.count % self.show_iter == 0:
              # D is fixed in this routine, so the generator loss is what gets reported
              self.show_info(timer=time.time() - timer, D_loss=G_loss)
             timer = time.time()
         self.count += 1
         if self.count % 5000 == 0:
             self.save_checkpoint('fixD_%s-%.5f_%d.pth' %
                                  (mode, self.lr, self.count),
                                  dataset=dataname)
     self.writer.close()
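
The zero_grad call inside train_g (and again in Example #3) is a module-level helper applied to an iterable of parameters; its definition is not part of this listing. A minimal sketch of what such a helper typically does, assumed rather than taken from the repository:

import torch
import torch.nn as nn

def zero_grad(params):
    # Assumed behavior: detach and zero the gradient of every parameter in the iterable.
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

# tiny self-check
net = nn.Linear(3, 1)
net(torch.randn(2, 3)).sum().backward()
zero_grad(net.parameters())
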
Example #2
 def zero_grad(self):
     # clear gradients for both the maximizing and minimizing parameter groups
     zero_grad(self.max_params)
     zero_grad(self.min_params)
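
Example #2 clears gradients for two parameter groups, max_params and min_params; the owning class is not shown. A self-contained sketch of the assumed context (class and variable names are hypothetical), with the discriminator as the maximizing player and the generator as the minimizing one:

import torch.nn as nn

class TwoPlayerSketch:
    def __init__(self, max_params, min_params):
        self.max_params = list(max_params)
        self.min_params = list(min_params)

    def zero_grad(self):
        # same structure as Example #2, with the helper's assumed behavior inlined
        for p in self.max_params + self.min_params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

D, G = nn.Linear(2, 1), nn.Linear(2, 2)
opt = TwoPlayerSketch(max_params=D.parameters(), min_params=G.parameters())
opt.zero_grad()  # no effect yet, since no backward pass has populated .grad
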
Example #3
    def train_d(self,
                epoch_num,
                mode='Adam',
                dataname='MNIST',
                logname='MNIST',
                overtrain_path=None,
                compare_weight=None,
                his_flag=False,
                info_time=100,
                optim_state=None):
        path = None
        if overtrain_path is not None:
            path = overtrain_path
        if compare_weight is not None:
            path = compare_weight
        if path is not None:
            discriminator = dc_D().to(self.device)
            model_weight = torch.load(path)
            discriminator.load_state_dict(model_weight['D'])
            model_vec = torch.cat(
                [p.contiguous().view(-1) for p in discriminator.parameters()])
            print('Load discriminator from %s' % path)

        print(mode)
        if mode == 'SGD':
            d_optimizer = optim.SGD(self.D.parameters(),
                                    lr=self.lr_d,
                                    weight_decay=self.weight_decay)
            self.writer_init(logname=logname,
                             comments='SGD-%.3f_%.3f' %
                             (self.lr_d, self.weight_decay))
        elif mode == 'Adam':
            d_optimizer = Adam(self.D.parameters(),
                               lr=self.lr_d,
                               weight_decay=self.weight_decay,
                               betas=(0.5, 0.999))
            self.writer_init(logname=logname,
                             comments='ADAM-%.3f_%.5f' %
                             (self.lr_d, self.weight_decay))
        elif mode == 'RMSProp':
            d_optimizer = RMSprop(self.D.parameters(),
                                  lr=self.lr_d,
                                  weight_decay=self.weight_decay)
            self.writer_init(logname=logname,
                             comments='RMSProp-%.3f_%.5f' %
                             (self.lr_d, self.weight_decay))

        if optim_state is not None:
            chk = torch.load(optim_state)
            d_optimizer.load_state_dict(chk['D_optim'])
            print('load optimizer state')
        timer = time.time()
        d_losses = []
        g_losses = []
        flag = False
        for e in range(epoch_num):
            tol_correct = 0
            tol_loss = 0
            tol_gloss = 0
            for real_x in self.dataloader:
                real_x = real_x[0].to(self.device)
                d_real = self.D(real_x)

                z = torch.randn((real_x.shape[0], self.z_dim),
                                device=self.device)  # sample noise matching the real batch size
                fake_x = self.G(z)
                d_fake = self.D(fake_x)

                # D_loss = gan_loss(d_real, d_fake)
                D_loss = self.criterion(d_real, torch.ones(d_real.shape, device=self.device)) + \
                         self.criterion(d_fake, torch.zeros(d_fake.shape, device=self.device))
                tol_loss += D_loss.item() * real_x.shape[0]
                # the generator loss is tracked for logging only, hence detached
                G_loss = self.criterion(
                    d_fake, torch.ones(d_fake.shape,
                                       device=self.device)).detach_()
                tol_gloss += G_loss.item() * fake_x.shape[0]
                if self.d_penalty != 0:
                    D_loss += self.l2penalty()
                if overtrain_path is not None and self.count % 2 == 0:
                    self.plot_proj(epoch=self.count,
                                   model_vec=model_vec,
                                   loss=D_loss)
                if compare_weight is not None and self.count % info_time == 0:
                    self.plot_diff(model_vec=model_vec)
                self.plot_param(D_loss=D_loss, G_loss=G_loss)
                d_optimizer.zero_grad()
                zero_grad(self.G.parameters())
                D_loss.backward()
                # flag = True if self.count % info_time == 0 else False
                # the discriminator is not updated during the first epoch; gradients are only logged
                if e != 0:
                    d_optimizer.step()
                # d_steps, d_updates = d_optimizer.step(info=flag)
                # if flag:
                #     self.plot_optim(d_steps=d_steps, d_updates=d_updates,
                #                     his=his_flag)

                # count correctly classified real (logit > 0) and fake (logit < 0) samples
                tol_correct += (d_real > 0).sum().item() + (d_fake < 0).sum().item()

                # L2 norms of the discriminator and generator gradients, for logging
                gd = torch.norm(torch.cat([
                    p.grad.contiguous().view(-1) for p in self.D.parameters()
                ]),
                                p=2)
                gg = torch.norm(torch.cat([
                    p.grad.contiguous().view(-1) for p in self.G.parameters()
                ]),
                                p=2)

                self.plot_grad(gd=gd, gg=gg)
                self.plot_d(d_real, d_fake)
                if self.count % self.show_iter == 0:
                    self.show_info(timer=time.time() - timer,
                                   D_loss=D_loss,
                                   G_loss=G_loss,
                                   logdir=logname)
                    timer = time.time()
                self.count += 1
            tol_loss /= len(self.dataset)
            tol_gloss /= len(self.dataset)
            d_losses.append(tol_loss)
            g_losses.append(tol_gloss)
            # tol_correct counts both real and fake predictions, i.e. 2 * len(dataset) samples in total
            acc = 50.0 * tol_correct / len(self.dataset)

            self.writer.add_scalar('Train/D_Loss', tol_loss, global_step=e)
            self.writer.add_scalar('Train/G_Loss', tol_gloss, global_step=e)
            self.writer.add_scalar('Train/Accuracy', acc, global_step=e)
            print('Epoch: {}/{}, Acc: {}/{} ({:.3f}%), '
                  'D Loss mean: {:.4f}, G Loss mean: {:.4f}'.format(
                      e, epoch_num, tol_correct, 2 * len(self.dataset), acc,
                      tol_loss, tol_gloss))
            self.save_checkpoint('fixG_%s-%.5f_%d.pth' % (mode, self.lr_d, e),
                                 dataset=dataname)
        self.writer.close()
        return d_losses, g_losses
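
Both routines log the L2 norm of each network's flattened gradient (gd and gg). That pattern can be factored into a small helper; a minimal runnable sketch (the helper name is an assumption, not taken from the source):

import torch
import torch.nn as nn

def grad_l2_norm(params):
    # Flatten every parameter's gradient into one vector and return its L2 norm,
    # mirroring the gd / gg computation above. Assumes .grad has been populated.
    return torch.norm(
        torch.cat([p.grad.contiguous().view(-1) for p in params]), p=2)

# tiny self-check
net = nn.Linear(4, 2)
net(torch.randn(8, 4)).sum().backward()
print(grad_l2_norm(net.parameters()))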