def optimize_strategy(self, img, gt):
    """Run one supervised optimization step.

    Forwards `img` through the model, measures MSE against `gt`,
    backpropagates and applies an optimizer update.

    Returns:
        (generated images, scalar MSE loss tensor)
    """
    outputs = self.model(img)
    criterion = get_mse_loss_function()
    step_loss = criterion(outputs, gt)
    # Backward pass and parameter update.
    self.optimizer.zero_grad()
    step_loss.backward()
    self.optimizer.step()
    return outputs, step_loss
def optimize_strategy(self, img, gt):
    """Supervised MSE optimization step with gradient-norm clipping.

    Identical to the plain strategy except that the global L2 gradient
    norm is clipped before the optimizer step.

    Returns:
        (generated images, scalar MSE loss tensor)
    """
    outputs = self.model(img)
    criterion = get_mse_loss_function()
    step_loss = criterion(outputs, gt)
    # Backward pass.
    self.optimizer.zero_grad()
    step_loss.backward()
    # Clip the global gradient norm — the loss was observed to blow up
    # to NaN without it (translated from the original comment).
    nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=20, norm_type=2)
    self.optimizer.step()
    return outputs, step_loss
def optimize_multi_strategy(self, img_0, gt_0, img_1, gt_1, img_2, gt_2):
    """Optimize the model on three output scales in one step.

    The model returns predictions ordered coarse-to-fine
    (index 0 -> 64, index 1 -> 128, index 2 -> 256, per the original
    comments); each prediction is matched against the ground truth of
    its own resolution and the three MSE terms are summed before a
    single backward/update pass.

    Returns:
        (tuple/list of generated images, summed scalar loss tensor)
    """
    gen_imgs = self.model(img_0, img_1, img_2)
    criterion = get_mse_loss_function()
    total_loss = (criterion(gen_imgs[2], gt_0)     # 256
                  + criterion(gen_imgs[1], gt_1)   # 128
                  + criterion(gen_imgs[0], gt_2))  # 64
    self.optimizer.zero_grad()
    total_loss.backward()
    self.optimizer.step()
    return gen_imgs, total_loss
def train(self, args, train_dataloader, test_dataloader, start_epoch, end_epoch):
    """CycleGAN training loop between domain A (`img`) and domain B (`gt`).

    gen_AB maps A->B and gen_BA maps B->A; dis_A judges domain-A samples,
    dis_B judges domain-B samples.  Per batch the generators are updated on
    GAN + identity + cycle-consistency losses, then each discriminator is
    updated on a real/fake pair drawn through a replay buffer.  "Best"
    checkpoints and sample images are saved whenever loss_G improves;
    periodic checkpoints are saved every `args.interval` epochs.  Scalar
    losses go to TensorBoard and a plain-text log file.

    Args:
        args: namespace providing img_height/img_width/n_D_layers,
            lambda_id/lambda_cyc loss weights, interval, model.arch and
            the output/log directory paths.
        train_dataloader: yields dicts with 'X' (domain A) and 'Y' (domain B).
        test_dataloader: unused here (kept for interface compatibility).
        start_epoch, end_epoch: epoch range to train over.
    """
    # Spatial size of the PatchGAN discriminator output.
    patch = (1, args.img_height // (2**args.n_D_layers * 4),
             args.img_width // (2**args.n_D_layers * 4))
    fake_img_buffer = ReplayBuffer()
    fake_gt_buffer = ReplayBuffer()
    writer = SummaryWriter(log_dir='{}'.format(args.log_dir), comment='train_loss')
    print("======== begin train model ========")
    print('data_size:', len(train_dataloader))
    best_loss = 3  # save "best" checkpoints only once loss_G drops below this
    os.makedirs(args.results_gen_model_dir, exist_ok=True)
    os.makedirs(args.results_dis_model_dir, exist_ok=True)
    os.makedirs(args.results_img_dir, exist_ok=True)
    os.makedirs(args.results_gt_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    use_cuda = torch.cuda.is_available()
    for epoch in range(start_epoch, end_epoch):
        for i, batch in enumerate(train_dataloader):
            img = Variable(batch['X'].type(torch.FloatTensor))
            gt = Variable(batch['Y'].type(torch.FloatTensor))
            if use_cuda:
                img = img.cuda()
                gt = gt.cuda()
            # BUGFIX: the GAN target tensors were unconditionally moved to
            # CUDA, which crashed on CPU-only machines even though the
            # inputs had a CPU branch; keep them on the inputs' device.
            ones = torch.FloatTensor(np.ones((img.size(0), *patch)))
            zeros = torch.FloatTensor(np.zeros((img.size(0), *patch)))
            if use_cuda:
                ones = ones.cuda()
                zeros = zeros.cuda()
            valid = Variable(ones, requires_grad=False)
            fake = Variable(zeros, requires_grad=False)

            ##### Train Generator #######
            self.optimizer_G.zero_grad()
            # Identity loss: each generator should leave samples of its
            # own target domain unchanged.
            identity_loss = get_l1_loss_function()
            loss_id_img = identity_loss(self.gen_BA(img), img)
            loss_id_gt = identity_loss(self.gen_AB(gt), gt)
            loss_identity = (loss_id_gt + loss_id_img) / 2
            # GAN loss.
            fake_gt = self.gen_AB(img)
            pred_fake = self.dis_B(fake_gt)
            gan_loss = get_mse_loss_function()
            loss_GAN_img_gt = gan_loss(pred_fake, valid)
            fake_img = self.gen_BA(gt)
            # BUGFIX: fake domain-A samples must be judged by dis_A; the
            # original used dis_B here, mismatching the discriminator that
            # is trained on fake_img below.
            pred_fake = self.dis_A(fake_img)
            loss_GAN_gt_img = gan_loss(pred_fake, valid)
            loss_GAN = (loss_GAN_img_gt + loss_GAN_gt_img) / 2
            # Cycle-consistency loss.
            recov_img = self.gen_BA(fake_gt)
            cycle_loss = get_l1_loss_function()
            loss_cycle_img = cycle_loss(recov_img, img)
            # BUGFIX: recovering gt from fake_img requires the A->B
            # generator; the original reused gen_BA.
            recov_gt = self.gen_AB(fake_img)
            loss_cycle_gt = cycle_loss(recov_gt, gt)
            loss_cycle = (loss_cycle_gt + loss_cycle_img) / 2
            # Total generator loss.
            loss_G = loss_GAN + args.lambda_id * loss_identity + args.lambda_cyc * loss_cycle
            loss_G.backward()
            self.optimizer_G.step()
            batches_done = epoch * len(train_dataloader) + i

            ####### Train Discriminator A #######
            self.optimizer_D_A.zero_grad()
            pred_real = self.dis_A(img)
            loss_real = gan_loss(pred_real, valid)
            # The replay buffer decorrelates consecutive fake samples.
            fake_img = fake_img_buffer.push_and_pop(fake_img)
            pred_fake = self.dis_A(fake_img.detach())
            loss_fake = gan_loss(pred_fake, fake)
            loss_D_img = (loss_real + loss_fake) / 2
            loss_D_img.backward()
            self.optimizer_D_A.step()

            ####### Train Discriminator B #######
            self.optimizer_D_B.zero_grad()
            pred_real = self.dis_B(gt)
            loss_real = gan_loss(pred_real, valid)
            fake_gt = fake_gt_buffer.push_and_pop(fake_gt)
            pred_fake = self.dis_B(fake_gt.detach())
            loss_fake = gan_loss(pred_fake, fake)
            loss_D_gt = (loss_real + loss_fake) / 2
            loss_D_gt.backward()
            self.optimizer_D_B.step()
            loss_D = (loss_D_img + loss_D_gt) / 2

            writer.add_scalars('{}_train_loss'.format(args.model.arch), {
                'loss_G': loss_G.data.cpu(),
                'loss_D': loss_D.data.cpu()
            }, batches_done)
            # BUGFIX: the log file was re-opened on every batch but closed
            # only once after training (handle leak); use a context manager
            # per write.  Also dropped a duplicated loss_identity entry.
            info = ('epoch:' + str(epoch) + ' batches_done:' + str(batches_done)
                    + ' loss_GAN:' + str(loss_GAN.data.cpu())
                    + ' loss_identity:' + str(loss_identity.data.cpu())
                    + ' loss_cycle:' + str(loss_cycle.data.cpu())
                    + ' loss_G:' + str(loss_G.data.cpu())
                    + ' loss_D_gt:' + str(loss_D_gt.data.cpu())
                    + ' loss_D_img:' + str(loss_D_img.data.cpu()))
            with open(os.path.join(args.log_dir, 'log.txt'), 'a+') as f:
                f.write(info + '\n')

            ########## save best result ##############
            if loss_G.data.cpu() < best_loss:
                best_loss = loss_G.data.cpu()
                torch.save(self.gen_AB.state_dict(),
                           args.results_gen_model_dir + '/%d-%d_gen_AB_best_model.pkl' % (epoch, batches_done))
                torch.save(self.gen_BA.state_dict(),
                           args.results_gen_model_dir + '/%d-%d_gen_BA_best_model.pkl' % (epoch, batches_done))
                torch.save(self.dis_A.state_dict(),
                           args.results_dis_model_dir + '/%d-%d_dis_A_best_model.pkl' % (epoch, batches_done))
                torch.save(self.dis_B.state_dict(),
                           args.results_dis_model_dir + '/%d-%d_dis_B_best_model.pkl' % (epoch, batches_done))
                save_image(fake_gt, '%s/%s-%s.bmp' % (args.results_gt_dir, epoch, batches_done),
                           nrow=4, normalize=True)
                save_image(fake_img, '%s/%s-%s.bmp' % (args.results_img_dir, epoch, batches_done),
                           nrow=4, normalize=True)
            if i % args.interval == 0:
                print('[epoch %d/%d] [batch %d/%d] [loss: %f]' %
                      (epoch, end_epoch, batches_done,
                       (end_epoch * len(train_dataloader)), loss_G.item()))
        # Periodic per-epoch checkpoint (condition is on `epoch`, so this
        # belongs at epoch level, not inside the batch loop).
        if epoch % args.interval == 0:
            torch.save(self.gen_AB.state_dict(),
                       args.results_gen_model_dir + '/%d-%d_gen_AB.pkl' % (epoch, batches_done))
            torch.save(self.gen_BA.state_dict(),
                       args.results_gen_model_dir + '/%d-%d_gen_BA.pkl' % (epoch, batches_done))
            torch.save(self.dis_A.state_dict(),
                       args.results_dis_model_dir + '/%d-%d_dis_A.pkl' % (epoch, batches_done))
            torch.save(self.dis_B.state_dict(),
                       args.results_dis_model_dir + '/%d-%d_dis_B.pkl' % (epoch, batches_done))
            save_image(fake_gt, '%s/%s-%s.bmp' % (args.results_gt_dir, epoch, batches_done),
                       nrow=4, normalize=True)
            save_image(fake_img, '%s/%s-%s.bmp' % (args.results_img_dir, epoch, batches_done),
                       nrow=4, normalize=True)
    writer.close()
def train(self, args, train_dataloader, test_dataloader, start_epoch, end_epoch):
    """Conditional-GAN (pix2pix-style) training loop.

    The discriminator sees the (input, output) pair concatenated on the
    channel axis and is trained with a softplus (non-saturating logistic)
    GAN loss; the generator is trained on the adversarial term plus a
    lambda-weighted L1 reconstruction term.  After every epoch the model
    is evaluated on `test_dataloader` and periodically checkpointed.

    Args:
        args: namespace with config.log_dir/out_dir/lamb, minimax and
            snapshot_interval.
        train_dataloader: yields (img, gt, name) triples.
        test_dataloader: forwarded to the module-level `test` helper.
        start_epoch, end_epoch: epoch range to train over.
    """
    logreport = LogReport(log_dir=args.config.log_dir)
    testreport = TestReport(log_dir=args.config.out_dir)
    print("======== begin train model ========")
    for epoch in range(start_epoch, end_epoch):
        for i, (img, gt, name) in enumerate(train_dataloader):
            if torch.cuda.is_available():
                img = Variable(img.cuda())
                gt = Variable(gt.cuda())
            else:
                img = Variable(img)
                gt = Variable(gt)
            fake_gt = self.gen.forward(img)

            ########################
            ####### Update D #######
            ########################
            self.D_optimizer.zero_grad()
            # Train with fake pairs (generator output detached).
            fake_img_gt = torch.cat((img, fake_gt), 1)
            pred_fake = self.dis.forward(fake_img_gt.detach())
            batchsize, _, w, h = pred_fake.size()
            # BUGFIX: D must push fake logits DOWN, i.e. softplus(pred_fake).
            # The original softplus(-pred_fake) made D minimize the same
            # objective as G on fakes, so there was no adversarial game.
            loss_d_fake = torch.sum(get_softplus(pred_fake)) / batchsize / w / h
            # Train with real pairs.
            real_img_gt = torch.cat((img, gt), 1)
            pred_real = self.dis.forward(real_img_gt)
            loss_d_real = torch.sum(get_softplus(-pred_real)) / batchsize / w / h
            # Combined discriminator loss.
            loss_d = loss_d_fake + loss_d_real
            loss_d.backward()
            # D is stepped only on every `minimax`-th epoch so it does not
            # overpower the generator.
            if epoch % args.minimax == 0:
                self.D_optimizer.step()

            ########################
            ####### Update G #######
            ########################
            self.G_optimizer.zero_grad()
            # First, G(A) should fool the discriminator.
            fake_img_gt = torch.cat((img, fake_gt), 1)
            pred_fake = self.dis.forward(fake_img_gt)
            loss_g_gan = torch.sum(get_softplus(-pred_fake)) / batchsize / w / h
            # Second, G(A) should reconstruct B (L1 term).
            # BUGFIX: get_l1_loss_function is a factory elsewhere in this
            # file — call it first, then apply the returned criterion;
            # the original passed the tensors to the factory itself.
            loss_g_l1 = get_l1_loss_function()(fake_gt, gt) * args.config.lamb
            loss_g = loss_g_gan + loss_g_l1
            loss_g.backward()
            self.G_optimizer.step()

            # Console log every 100 batches.
            if i % 100 == 0:
                print(
                    "===> Epoch[{}]({}/{}): loss_d_fake: {:.4f} loss_d_real: {:.4f} loss_g_gan: {:.4f} loss_g_l1: {:.4f}".format(
                        epoch, i, len(train_dataloader), loss_d_fake.item(),
                        loss_d_real.item(), loss_g_gan.item(), loss_g_l1.item()))
            log = {}
            log['epoch'] = epoch
            # NOTE(review): with start_epoch == 0 this counter is negative
            # during the first epoch — confirm the intended epoch base.
            log['iteration'] = len(train_dataloader) * (epoch - 1) + i
            log['gen/loss'] = loss_g.item()
            log['dis/loss'] = loss_d.item()
            logreport(log)
        # Per-epoch evaluation and checkpointing.
        with torch.no_grad():
            log_test = test(args, test_dataloader, self.gen, get_mse_loss_function(), epoch)
            testreport(log_test)
        if epoch % args.snapshot_interval == 0:
            checkpoint(args, epoch, self.gen, self.dis)
        logreport.save_lossgraph()
        testreport.save_lossgraph()