import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# crop, augmentation, mydata, Generator, Discriminator, vgg19, perceptual_loss and
# TVLoss are project-local; their exact import paths depend on this repository's layout.


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()])
    dataset = mydata(GT_path=args.GT_path, LR_path=args.LR_path,
                     in_memory=args.in_memory, transform=transform)
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=args.num_workers)

    generator = Generator(img_feat=3, n_feats=64, kernel_size=3,
                          num_block=args.res_num, scale=args.scale)
    if args.fine_tuning:
        generator.load_state_dict(torch.load(args.generator_path))
        print("pre-trained model is loaded")
        print("path : %s" % (args.generator_path))
    generator = generator.to(device)
    generator.train()

    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr=1e-4)

    pre_epoch = 0
    fine_epoch = 0

    #### Train using L2_loss
    while pre_epoch < args.pre_train_epoch:
        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1
        if pre_epoch % 2 == 0:
            print(pre_epoch)
            print(loss.item())
            print('=========')
        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(),
                       './model/pre_trained_model_%03d.pt' % pre_epoch)

    #### Train using perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()

    discriminator = Discriminator(patch_size=args.patch_size * args.scale)
    discriminator = discriminator.to(device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()

    # Fixed-size label tensors; use drop_last=True in the DataLoader if the last
    # batch can be smaller than args.batch_size.
    real_label = torch.ones((args.batch_size, 1)).to(device)
    fake_label = torch.zeros((args.batch_size, 1)).to(device)

    while fine_epoch < args.fine_train_epoch:
        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            ## Training Discriminator
            output, _ = generator(lr)
            fake_prob = discriminator(output)
            real_prob = discriminator(gt)

            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)
            d_loss = d_loss_real + d_loss_fake

            g_optim.zero_grad()
            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            ## Training Generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)

            _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0,
                                                      (output + 1.0) / 2.0,
                                                      layer=args.feat_layer)

            L2_loss = l2_loss(output, gt)
            percep_loss = args.vgg_rescale_coeff * _percep_loss
            adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)
            total_variance_loss = args.tv_loss_coeff * tv_loss(
                args.vgg_rescale_coeff * (hr_feat - sr_feat) ** 2)

            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss

            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

        # Step the generator LR schedule once per epoch, after the optimizer updates.
        scheduler.step()

        fine_epoch += 1
        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')
        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(),
                       './model/SRGAN_gene_%03d.pt' % fine_epoch)
            torch.save(discriminator.state_dict(),
                       './model/SRGAN_discrim_%03d.pt' % fine_epoch)
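# A minimal, hypothetical command-line entry point for train(); the flag names mirror
# the args.* attributes used above, but the default values are illustrative assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--GT_path', type=str, required=True)
    parser.add_argument('--LR_path', type=str, required=True)
    parser.add_argument('--in_memory', action='store_true')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--scale', type=int, default=4)
    parser.add_argument('--patch_size', type=int, default=24)
    parser.add_argument('--res_num', type=int, default=16)
    parser.add_argument('--fine_tuning', action='store_true')
    parser.add_argument('--generator_path', type=str, default=None)
    parser.add_argument('--pre_train_epoch', type=int, default=8000)
    parser.add_argument('--fine_train_epoch', type=int, default=4000)
    parser.add_argument('--feat_layer', type=str, default='relu5_4')
    parser.add_argument('--vgg_rescale_coeff', type=float, default=0.006)
    parser.add_argument('--adv_coeff', type=float, default=1e-3)
    parser.add_argument('--tv_loss_coeff', type=float, default=0.0)
    args = parser.parse_args()

    train(args)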
from math import log10

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision.models import vgg16

# Generator and Discriminator are assumed to be provided by this project's model module.


class SRGANTrainer(object):
    def __init__(self, config, training_loader, testing_loader):
        super(SRGANTrainer, self).__init__()
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
        self.netG = None
        self.netD = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.epoch_pretrain = 10
        self.criterionG = None
        self.criterionD = None
        self.optimizerG = None
        self.optimizerD = None
        self.feature_extractor = None
        self.schedulerG = None
        self.schedulerD = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader

    def build_model(self):
        self.netG = Generator(n_residual_blocks=self.num_residuals,
                              upsample_factor=self.upscale_factor,
                              base_filter=64, num_channel=1).to(self.device)
        self.netD = Discriminator(base_filter=64, num_channel=1).to(self.device)
        self.feature_extractor = vgg16(pretrained=True)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()
        torch.manual_seed(self.seed)

        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)
            self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.criterionG.cuda()
            self.criterionD.cuda()

        self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(), lr=self.lr / 100,
                                    momentum=0.9, nesterov=True)
        # Separate lr-decay schedules for the two optimizers (the original code assigned
        # both to self.scheduler, so the generator schedule was overwritten).
        self.schedulerG = optim.lr_scheduler.MultiStepLR(
            self.optimizerG, milestones=[50, 75, 100], gamma=0.5)
        self.schedulerD = optim.lr_scheduler.MultiStepLR(
            self.optimizerD, milestones=[50, 75, 100], gamma=0.5)

    @staticmethod
    def to_data(x):
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self):
        g_model_out_path = "SRGAN_Generator_model_path.pth"
        d_model_out_path = "SRGAN_Discriminator_model_path.pth"
        torch.save(self.netG, g_model_out_path)
        torch.save(self.netD, d_model_out_path)
        print("Checkpoint saved to {}".format(g_model_out_path))
        print("Checkpoint saved to {}".format(d_model_out_path))

    def pretrain(self):
        # Warm up the generator with plain MSE before adversarial training.
        self.netG.train()
        print("self.netG.train done")
        for batch_num, (data, target) in enumerate(self.training_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.netG.zero_grad()
            loss = self.criterionG(self.netG(data), target)
            loss.backward()
            self.optimizerG.step()

    def train(self):
        # models setup
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for batch_num, (data, target) in enumerate(self.training_loader):
            # Per-batch real/fake labels for the BCE adversarial loss.
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Train Discriminator
            self.optimizerD.zero_grad()
            d_real = self.netD(target)
            d_real_loss = self.criterionD(d_real, real_label)
            d_fake = self.netD(self.netG(data))
            d_fake_loss = self.criterionD(d_fake, fake_label)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train Generator
            self.optimizerG.zero_grad()
            g_real = self.netG(data)
            g_fake = self.netD(g_real)
            gan_loss = self.criterionD(g_fake, real_label)
            mse_loss = self.criterionG(g_real, target)
            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizerG.step()

        print(" Average G_Loss: {:.4f}".format(g_train_loss / len(self.training_loader)))

    def test(self):
        self.netG.eval()
        avg_psnr = 0
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.netG(data)
                mse = self.criterionG(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr
        print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))

    def run(self):
        self.build_model()
        for epoch in range(1, self.epoch_pretrain + 1):
            print("pretrain epoch {}".format(epoch))
            self.pretrain()
            print("{}/{} pretrained".format(epoch, self.epoch_pretrain))
        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            self.train()
            self.test()
            self.schedulerG.step()
            self.schedulerD.step()
            if epoch == self.nEpochs:
                self.save()
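# Minimal usage sketch (assumptions): any config object exposing lr, nEpochs, seed and
# upscale_factor will do, and the loaders are expected to yield (low_res, high_res)
# pairs; the values below are illustrative only.
#
#     from argparse import Namespace
#     config = Namespace(lr=1e-4, nEpochs=100, seed=123, upscale_factor=4)
#     trainer = SRGANTrainer(config, training_loader, testing_loader)
#     trainer.run()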
import os
from math import log10

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

# Generator and Discriminator are assumed to be provided by this project's model module.


class SRGANTrainer(object):
    def __init__(self, config, training_loader, testing_loader, class_name):
        super(SRGANTrainer, self).__init__()
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
        self.net_G = None
        self.net_D = None
        self.lr = config.lr
        self.num_epoch = config.num_epoch
        self.epoch_pretrain = 10
        self.loss_G = None
        self.loss_D = None
        self.optimizer_G = None
        self.optimizer_D = None
        self.feature_extractor = None
        self.scheduler_G = None
        self.scheduler_D = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.g_model_out_path = "SRGAN_Generator_model_" + class_name
        self.d_model_out_path = "SRGAN_Discriminator_model_" + class_name
        self.loss_set = []
        self.psnr_set = []
        self.mse_set = []
        self.class_name = class_name
        # The 'velocity' field uses three input channels; every other field uses one.
        self.num_input = 1 if self.class_name != 'velocity' else 3

    def build_model(self):
        self.net_G = Generator(num_residual=self.num_residuals,
                               upscale_factor=self.upscale_factor,
                               base_filter=128, num_input=self.num_input).to(self.device)
        self.net_D = Discriminator(base_filter=128, num_input=self.num_input).to(self.device)
        # self.feature_extractor = vgg16(pretrained=True)
        self.net_G.weight_init(mean=0.0, std=0.2)
        self.net_D.weight_init(mean=0.0, std=0.2)
        self.loss_G = nn.MSELoss()
        self.loss_D = nn.BCELoss()
        torch.manual_seed(self.seed)

        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)
            # self.feature_extractor.cuda()
            cudnn.benchmark = True
            self.loss_G.cuda()
            self.loss_D.cuda()

        self.optimizer_G = optim.Adam(self.net_G.parameters(), lr=self.lr,
                                      betas=(0.9, 0.999), weight_decay=1e-8)
        self.optimizer_D = optim.SGD(self.net_D.parameters(), lr=self.lr / 100,
                                     momentum=0.9, nesterov=True)
        # Alternative discriminator optimizer:
        # self.optimizer_D = optim.Adam(self.net_D.parameters(), lr=self.lr / 100,
        #                               betas=(0.9, 0.999), weight_decay=1e-8)

        # Separate lr-decay schedules for the two optimizers (the original code assigned
        # both to self.scheduler, so the generator schedule was overwritten).
        self.scheduler_G = optim.lr_scheduler.MultiStepLR(
            self.optimizer_G, milestones=[20, 40, 60, 80, 100], gamma=0.5)
        self.scheduler_D = optim.lr_scheduler.MultiStepLR(
            self.optimizer_D, milestones=[20, 40, 60, 80, 100], gamma=0.5)

    @staticmethod
    def to_data(x):
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self):
        torch.save(self.net_G, self.g_model_out_path)
        torch.save(self.net_D, self.d_model_out_path)
        print("Checkpoint saved to {}".format(self.g_model_out_path))
        print("Checkpoint saved to {}".format(self.d_model_out_path))

    def save_loss(self, cnt):
        np.save("loss_set_" + self.class_name + str(cnt), np.array(self.loss_set))
        np.save("psnr_set_" + self.class_name + str(cnt), np.array(self.psnr_set))
        np.save("mse_set_" + self.class_name + str(cnt), np.array(self.mse_set))

    def load(self):
        self.net_G = torch.load(self.g_model_out_path)
        self.net_D = torch.load(self.d_model_out_path)

    def pretrain(self):
        self.net_G.train()
        for batch_num, (data, target) in enumerate(self.training_loader):
            print("batch_num: ", batch_num, "/", len(self.training_loader) - 1)
            data, target = data.to(self.device), target.to(self.device)
            self.net_G.zero_grad()
            gen = self.net_G(data)
            loss = self.loss_G(gen, target.float())
            loss.backward()
            self.optimizer_G.step()
            torch.cuda.empty_cache()

    def train(self):
        # models setup
        self.net_G.train()
        self.net_D.train()
        g_train_loss = 0
        d_train_loss = 0
        torch.cuda.empty_cache()
        for batch_num, (data, target) in enumerate(self.training_loader):
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Discriminator
            self.optimizer_D.zero_grad()
            d_real = self.net_D(target.float())
            d_real_loss = self.loss_D(d_real, real_label)
            d_fake = self.net_D(self.net_G(data))
            d_fake_loss = self.loss_D(d_fake, fake_label)
            print(d_real_loss, d_fake_loss)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizer_D.step()

            # Generator
            self.optimizer_G.zero_grad()
            g_real = self.net_G(data)
            print(g_real.shape)
            print(target.shape)
            g_fake = self.net_D(g_real)
            gan_loss = self.loss_D(g_fake, real_label)
            mse_loss = self.loss_G(g_real, target.float())
            print(mse_loss, gan_loss)
            g_total = mse_loss + 0.001 * gan_loss
            # g_total = gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizer_G.step()
            # progress_bar(batch_num, len(self.training_loader), 'G_Loss: %.4f | D_Loss: %.4f'
            #              % (g_train_loss / (batch_num + 1), d_train_loss / (batch_num + 1)))

        average_loss = g_train_loss / len(self.training_loader)
        print(" Average G_Loss: {:.8f}".format(average_loss))
        return average_loss

    def test(self):
        self.net_G.eval()
        avg_psnr = 0
        avg_mse_loss = 0
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.testing_loader):
                data, target = data.to(self.device), target.to(self.device)
                prediction = self.net_G(data)
                mse = self.loss_G(prediction, target.float())
                psnr = 10 * log10(1 / mse.item())
                avg_mse_loss += mse.item()
                avg_psnr += psnr
                # progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))

        average_psnr = avg_psnr / len(self.testing_loader)
        average_mse = avg_mse_loss / len(self.testing_loader)
        print(" Average MSE Loss: {:.8f}".format(average_mse))
        print(" Average PSNR: {:.8f} dB".format(average_psnr))
        return average_psnr, average_mse

    def run(self):
        self.build_model()
        '''
        if (self.class_name != 'velocity'):
            for epoch in range(1, self.epoch_pretrain + 1):
                self.pretrain()
                print("{}/{} pretrained".format(epoch, self.epoch_pretrain))
        '''
        for epoch in range(1, self.num_epoch + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            loss = self.train()
            psnr, mse_loss = self.test()
            self.scheduler_G.step()
            self.scheduler_D.step()
            self.loss_set.append(loss)
            self.psnr_set.append(psnr)
            self.mse_set.append(mse_loss)
            if epoch % 10 == 0:
                self.save_loss(epoch)
                self.save()
            if epoch == self.num_epoch:
                self.save()

    def restore(self):
        self.build_model()
        self.load()
# G, D, train_loader, num_epochs and Variable (torch.autograd) are assumed to be
# defined above; the leading keys of `results` are inferred from the per-epoch
# running_results dict below.
results = {'d_loss': [], 'g_loss': [], 'd_score': [],
           'g_score': [], 'psnr': [], 'ssim': []}

for epoch in range(1, num_epochs + 1):
    running_results = {'batch_sizes': 0, 'd_loss': 0.0, 'g_loss': 0.0,
                       'd_score': 0.0, 'g_score': 0.0}

    G.train()
    D.train()
    for data, target in train_loader:
        g_update_first = True
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        # update D: maximize d loss
        real_img = Variable(target).cuda()
        z = Variable(data)
        if torch.cuda.is_available():
            z = z.cuda()
        fake_img = G(z)

        D.zero_grad()
        real_out = D(real_img).mean()
        fake_out = D(fake_img).mean()