Example #1

import time

import torch
import torch.optim as optim
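
The methods below call a module-level zero_grad helper and are written as members of a class, neither of which appears in the excerpt. A minimal sketch of both, assuming the helper takes an iterable of parameters and using GANTrainer as a placeholder class name:

def zero_grad(params):
    # Manually clear gradients for an iterable of parameters, mirroring what
    # optimizer.zero_grad() does for parameters the optimizer does not own.
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()


class GANTrainer:  # placeholder name; the source shows only the methods below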
 def train_g(self,
             epoch_num,
             mode='Adam',
             dataname='MNIST',
             logname='MNIST'):
     print(mode)
     if mode == 'SGD':
         g_optimizer = optim.SGD(self.G.parameters(),
                                 lr=self.lr,
                                 weight_decay=self.weight_decay)
         self.writer_init(logname=logname,
                          comments='SGD-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     elif mode == 'Adam':
         g_optimizer = optim.Adam(self.G.parameters(),
                                  lr=self.lr,
                                  weight_decay=self.weight_decay,
                                  betas=(0.5, 0.999))  # beta1=0.5 is the common GAN setting
         self.writer_init(logname=logname,
                          comments='ADAM-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     elif mode == 'RMSProp':
         g_optimizer = optim.RMSprop(self.G.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay)
         self.writer_init(logname=logname,
                          comments='RMSProp-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     else:
         raise ValueError('Unknown optimizer mode: %s' % mode)
     timer = time.time()
     for e in range(epoch_num):  # one generator update per "epoch" iteration
         z = torch.randn((self.batchsize, self.z_dim),
                         device=self.device)  # sample latent noise on the target device
         fake_x = self.G(z)
         d_fake = self.D(fake_x)
         # Non-saturating generator loss: train G so that D labels fakes as real.
         G_loss = self.criterion(
             d_fake, torch.ones(d_fake.shape, device=self.device))
         g_optimizer.zero_grad()
         # D is frozen in this loop; clear its stale grads so the norm logged
         # below reflects only this backward pass.
         zero_grad(self.D.parameters())
         G_loss.backward()
         g_optimizer.step()
         # Global L2 norms of the D and G gradients, for monitoring.
         gd = torch.norm(torch.cat(
             [p.grad.contiguous().view(-1) for p in self.D.parameters()]),
                         p=2)
         gg = torch.norm(torch.cat(
             [p.grad.contiguous().view(-1) for p in self.G.parameters()]),
                         p=2)
         self.plot_param(G_loss=G_loss)
         self.plot_grad(gd=gd, gg=gg)
         if self.count % self.show_iter == 0:
             # Only G is trained here, so its loss is reported via show_info's
             # D_loss slot.
             self.show_info(timer=time.time() - timer, D_loss=G_loss)
             timer = time.time()
         self.count += 1
         if self.count % 5000 == 0:
             # Periodic checkpoint; the 'fixD' prefix records that D stayed frozen.
             self.save_checkpoint('fixD_%s-%.5f_%d.pth' %
                                  (mode, self.lr, self.count),
                                  dataset=dataname)
     self.writer.close()

 def zero_grad(self):
     # Clear gradients on both players; max_params and min_params are assumed
     # to be iterables of parameters (e.g. lists built from model.parameters()).
     zero_grad(self.max_params)
     zero_grad(self.min_params)
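
For context, a hypothetical driver, assuming a GANTrainer constructor that wires up G, D, the criterion, and the logging fields used above (all names and values here are illustrative, not from the source):

if __name__ == '__main__':
    # Hypothetical usage; the constructor signature and values are assumptions.
    trainer = GANTrainer(lr=2e-4, weight_decay=1e-4, batchsize=64, z_dim=96,
                         show_iter=500, device='cuda:0')
    trainer.train_g(epoch_num=10000, mode='Adam',
                    dataname='MNIST', logname='MNIST')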