def train_g(self, epoch_num, mode='Adam', dataname='MNIST', logname='MNIST'):
    """Train the generator against a fixed discriminator."""
    print(mode)
    if mode == 'SGD':
        g_optimizer = optim.SGD(self.G.parameters(), lr=self.lr,
                                weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='SGD-%.3f_%.5f' % (self.lr, self.weight_decay))
    elif mode == 'Adam':
        g_optimizer = optim.Adam(self.G.parameters(), lr=self.lr,
                                 weight_decay=self.weight_decay, betas=(0.5, 0.999))
        self.writer_init(logname=logname,
                         comments='ADAM-%.3f_%.5f' % (self.lr, self.weight_decay))
    elif mode == 'RMSProp':
        g_optimizer = RMSprop(self.G.parameters(), lr=self.lr,
                              weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='RMSProp-%.3f_%.5f' % (self.lr, self.weight_decay))
    timer = time.time()
    for e in range(epoch_num):
        z = torch.randn((self.batchsize, self.z_dim), device=self.device)  ## changed
        fake_x = self.G(z)
        d_fake = self.D(fake_x)
        # G_loss = g_loss(d_fake)
        # Non-saturating generator loss: push D(G(z)) toward the "real" label.
        G_loss = self.criterion(
            d_fake, torch.ones(d_fake.shape, device=self.device))
        g_optimizer.zero_grad()
        # Clear stale gradients on the frozen discriminator so the norm
        # logged below reflects only this step's backward pass.
        zero_grad(self.D.parameters())
        G_loss.backward()
        g_optimizer.step()
        # L2 norms of the flattened gradients of each network, for logging.
        gd = torch.norm(torch.cat(
            [p.grad.contiguous().view(-1) for p in self.D.parameters()]), p=2)
        gg = torch.norm(torch.cat(
            [p.grad.contiguous().view(-1) for p in self.G.parameters()]), p=2)
        self.plot_param(G_loss=G_loss)
        self.plot_grad(gd=gd, gg=gg)
        if self.count % self.show_iter == 0:
            # Only G is trained here; its loss is reported through
            # show_info's D_loss slot.
            self.show_info(timer=time.time() - timer, D_loss=G_loss)
            timer = time.time()
        self.count += 1
        if self.count % 5000 == 0:
            self.save_checkpoint('fixD_%s-%.5f_%d.pth' % (mode, self.lr, self.count),
                                 dataset=dataname)
    self.writer.close()
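# The two torch.norm(torch.cat([...])) expressions in train_g (and again in
# train_d below) compute the L2 norm of a network's flattened gradient. A
# small helper capturing that pattern, as a sketch only: the name grad_norm
# is mine and does not appear in the source, and it assumes torch is already
# imported at the top of this file and that .backward() has populated .grad.
def grad_norm(params):
    """L2 norm of the concatenated gradients of an iterable of parameters."""
    return torch.norm(
        torch.cat([p.grad.contiguous().view(-1) for p in params]), p=2)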
def zero_grad(self):
    zero_grad(self.max_params)
    zero_grad(self.min_params)
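# The zero_grad method above, as well as train_g and train_d, call a
# module-level zero_grad(params) helper that is not part of this excerpt.
# A minimal sketch of such a helper, assuming (from call sites like
# zero_grad(self.D.parameters())) that it only needs to clear the .grad
# buffers of an iterable of parameters:
def zero_grad(params):
    """Zero the gradient buffer of every parameter in `params`."""
    for p in params:
        if p.grad is not None:
            p.grad.detach_()  # drop any lingering autograd graph reference
            p.grad.zero_()    # reset the accumulated gradient in place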
def train_d(self, epoch_num, mode='Adam', dataname='MNIST', logname='MNIST',
            overtrain_path=None, compare_weight=None, his_flag=False,
            info_time=100, optim_state=None):
    """Train the discriminator against a fixed generator."""
    path = None
    if overtrain_path is not None:
        path = overtrain_path
    if compare_weight is not None:
        path = compare_weight
    if path is not None:
        # Load a reference discriminator and flatten its weights into a
        # single vector for the projection / distance plots below.
        discriminator = dc_D().to(self.device)
        model_weight = torch.load(path)
        discriminator.load_state_dict(model_weight['D'])
        model_vec = torch.cat(
            [p.contiguous().view(-1) for p in discriminator.parameters()])
        print('Load discriminator from %s' % path)
    print(mode)
    if mode == 'SGD':
        d_optimizer = optim.SGD(self.D.parameters(), lr=self.lr_d,
                                weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='SGD-%.3f_%.5f' % (self.lr_d, self.weight_decay))
    elif mode == 'Adam':
        d_optimizer = Adam(self.D.parameters(), lr=self.lr_d,
                           weight_decay=self.weight_decay, betas=(0.5, 0.999))
        self.writer_init(logname=logname,
                         comments='ADAM-%.3f_%.5f' % (self.lr_d, self.weight_decay))
    elif mode == 'RMSProp':
        d_optimizer = RMSprop(self.D.parameters(), lr=self.lr_d,
                              weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='RMSProp-%.3f_%.5f' % (self.lr_d, self.weight_decay))
    if optim_state is not None:
        chk = torch.load(optim_state)
        d_optimizer.load_state_dict(chk['D_optim'])
        print('load optimizer state')
    timer = time.time()
    d_losses = []
    g_losses = []
    flag = False
    for e in range(epoch_num):
        tol_correct = 0
        tol_loss = 0
        tol_gloss = 0
        for real_x in self.dataloader:
            real_x = real_x[0].to(self.device)
            d_real = self.D(real_x)
            z = torch.randn((real_x.shape[0], self.z_dim), device=self.device)  ## changed (shape)
            fake_x = self.G(z)
            d_fake = self.D(fake_x)
            # D_loss = gan_loss(d_real, d_fake)
            D_loss = self.criterion(d_real, torch.ones(d_real.shape, device=self.device)) + \
                     self.criterion(d_fake, torch.zeros(d_fake.shape, device=self.device))
            tol_loss += D_loss.item() * real_x.shape[0]
            # The generator loss is tracked for logging only, hence detached.
            G_loss = self.criterion(
                d_fake, torch.ones(d_fake.shape, device=self.device)).detach_()
            tol_gloss += G_loss.item() * fake_x.shape[0]
            if self.d_penalty != 0:
                D_loss += self.l2penalty()
            if overtrain_path is not None and self.count % 2 == 0:
                self.plot_proj(epoch=self.count, model_vec=model_vec, loss=D_loss)
            if compare_weight is not None and self.count % info_time == 0:
                self.plot_diff(model_vec=model_vec)
            self.plot_param(D_loss=D_loss, G_loss=G_loss)
            d_optimizer.zero_grad()
            zero_grad(self.G.parameters())
            D_loss.backward()
            # flag = True if self.count % info_time == 0 else False
            if e != 0:
                d_optimizer.step()
                # d_steps, d_updates = d_optimizer.step(info=flag)
                # if flag:
                #     self.plot_optim(d_steps=d_steps, d_updates=d_updates,
                #                     his=his_flag)
            # A sample counts as correct when the raw logit has the right
            # sign: positive for real inputs, negative for fake ones.
            tol_correct += (d_real > 0).sum().item() + (d_fake < 0).sum().item()
            gd = torch.norm(torch.cat(
                [p.grad.contiguous().view(-1) for p in self.D.parameters()]), p=2)
            gg = torch.norm(torch.cat(
                [p.grad.contiguous().view(-1) for p in self.G.parameters()]), p=2)
            self.plot_grad(gd=gd, gg=gg)
            self.plot_d(d_real, d_fake)
            if self.count % self.show_iter == 0:
                self.show_info(timer=time.time() - timer, D_loss=D_loss,
                               G_loss=G_loss, logdir=logname)
                timer = time.time()
            self.count += 1
        tol_loss /= len(self.dataset)
        tol_gloss /= len(self.dataset)
        d_losses.append(tol_loss)
        g_losses.append(tol_gloss)
        # Correct predictions are counted over both real and fake samples,
        # i.e. over 2 * len(self.dataset) predictions, hence the factor 50.
        acc = 50.0 * tol_correct / len(self.dataset)
        self.writer.add_scalar('Train/D_Loss', tol_loss, global_step=e)
        self.writer.add_scalar('Train/G_Loss', tol_gloss, global_step=e)
        self.writer.add_scalar('Train/Accuracy', acc, global_step=e)
        print('Epoch: {}/{}, Acc: {}/{}: {:.3f}%, '
              'D Loss mean: {:.4f}, G Loss mean: {:.4f}'.format(
                  e, epoch_num, tol_correct, 2 * len(self.dataset),
                  acc, tol_loss, tol_gloss))
        self.save_checkpoint('fixG_%s-%.5f_%d.pth' % (mode, self.lr_d, e),
                             dataset=dataname)
    self.writer.close()
    return d_losses, g_losses
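# Example usage, as a sketch only: the trainer class these methods belong to
# is not shown in this excerpt, so the constructor name (Trainer) and its
# arguments below are assumptions inferred from the attributes the methods
# read (self.G, self.D, self.lr, self.lr_d, self.weight_decay,
# self.batchsize, self.z_dim, self.device, self.dataloader, ...).
#
#     trainer = Trainer(G=generator, D=discriminator,
#                       lr=1e-4, lr_d=1e-4, weight_decay=0.0,
#                       batchsize=128, z_dim=96, device='cuda')
#     # Train only the discriminator against a fixed generator:
#     d_losses, g_losses = trainer.train_d(epoch_num=20, mode='Adam',
#                                          dataname='MNIST', logname='MNIST')
#     # Or train only the generator against a fixed discriminator:
#     trainer.train_g(epoch_num=10000, mode='SGD',
#                     dataname='MNIST', logname='MNIST')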