import time

import torch
import torch.optim as optim


def train_g(self, epoch_num, mode='Adam', dataname='MNIST', logname='MNIST'):
    print(mode)
    # Build the generator optimizer and a matching log writer.
    if mode == 'SGD':
        g_optimizer = optim.SGD(self.G.parameters(), lr=self.lr,
                                weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='SGD-%.3f_%.5f' % (self.lr, self.weight_decay))
    elif mode == 'Adam':
        g_optimizer = optim.Adam(self.G.parameters(), lr=self.lr,
                                 weight_decay=self.weight_decay,
                                 betas=(0.5, 0.999))
        self.writer_init(logname=logname,
                         comments='ADAM-%.3f_%.5f' % (self.lr, self.weight_decay))
    elif mode == 'RMSProp':
        g_optimizer = optim.RMSprop(self.G.parameters(), lr=self.lr,
                                    weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='RMSProp-%.3f_%.5f' % (self.lr, self.weight_decay))
    else:
        raise ValueError('Unknown optimizer mode: %s' % mode)
    timer = time.time()
    for e in range(epoch_num):
        # Sample a latent batch and score the generated images with D.
        z = torch.randn((self.batchsize, self.z_dim), device=self.device)
        fake_x = self.G(z)
        d_fake = self.D(fake_x)
        # Generator loss: push D's output on fake samples toward 1.
        # G_loss = g_loss(d_fake)
        G_loss = self.criterion(d_fake,
                                torch.ones(d_fake.shape, device=self.device))
        g_optimizer.zero_grad()
        zero_grad(self.D.parameters())
        G_loss.backward()
        g_optimizer.step()
        # L2 norms of the gradients of D and G, for logging only.
        gd = torch.norm(torch.cat([p.grad.contiguous().view(-1)
                                   for p in self.D.parameters()]), p=2)
        gg = torch.norm(torch.cat([p.grad.contiguous().view(-1)
                                   for p in self.G.parameters()]), p=2)
        self.plot_param(G_loss=G_loss)
        self.plot_grad(gd=gd, gg=gg)
        if self.count % self.show_iter == 0:
            # Only G is trained here; the value logged under D_loss is
            # the generator loss.
            self.show_info(timer=time.time() - timer, D_loss=G_loss)
            timer = time.time()
        self.count += 1
        if self.count % 5000 == 0:
            self.save_checkpoint('fixD_%s-%.5f_%d.pth' % (mode, self.lr, self.count),
                                 dataset=dataname)
    self.writer.close()
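# A hypothetical invocation of train_g; the trainer class name and its
# constructor arguments below are assumptions for illustration, not part
# of this section:
#
#     trainer = GANTrainer(lr=1e-4, weight_decay=0.0, batchsize=64,
#                          z_dim=96, device='cuda', show_iter=500)
#     trainer.train_g(epoch_num=10000, mode='Adam',
#                     dataname='MNIST', logname='MNIST')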
def zero_grad(self):
    # Variant for optimizers whose max_params / min_params are raw
    # parameter iterables.
    zero_grad(self.max_params)
    zero_grad(self.min_params)


def zero_grad(self):
    # Variant for optimizers whose max_params / min_params are nn.Module
    # instances, so the parameters must be extracted first.
    zero_grad(self.max_params.parameters())
    zero_grad(self.min_params.parameters())
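# Both methods above delegate to a module-level zero_grad helper that is
# not defined in this section. A minimal sketch of what such a helper
# presumably does (an assumption, mirroring the classic
# torch.optim.Optimizer.zero_grad behavior of detaching and zeroing each
# gradient in place):
def zero_grad(params):
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()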