def train_gan():
    """Set up the depth/color GAN components and launch the training loop.

    Loads the three networks and their optimizers in resume mode, builds the
    generator loss (moved to GPU when configured), and hands everything to
    ``train_system``.
    """
    # load_networks(True) returns the six pieces in this fixed order.
    (depth_net, color_net, d_net,
     depth_optimizer, color_optimizer, d_optimizer) = load_networks(True)

    generator_criterion = GeneratorLoss()
    if param.useGPU:
        generator_criterion.cuda()

    train_system(depth_net, color_net, d_net,
                 depth_optimizer, color_optimizer, d_optimizer,
                 generator_criterion)
batch_size=1, shuffle=False) netG = Generator(UPSCALE_FACTOR) print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) generator_criterion = GeneratorLoss() if torch.cuda.is_available(): netG.cuda() netD.cuda() generator_criterion.cuda() optimizerG = optim.Adam(netG.parameters()) optimizerD = optim.Adam(netD.parameters()) results = { 'd_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': [] } for epoch in range(1, NUM_EPOCHS + 1): train_bar = tqdm(train_loader)
num_workers=0, batch_size=batch_size, shuffle=True) netG = Generator(upscale_factor) netD = Discriminator() gen_criterion = GeneratorLoss() optimizerG = optim.Adam(netG.parameters()) optimizerD = optim.Adam(netD.parameters()) if torch.cuda.is_available(): netG.cuda() netD.cuda() gen_criterion.cuda() for epoch in range(1, num_epochs + 1): netG.train() netD.train() count = 1 for lr_img, hr_img in train_loader: lr_img = Variable(lr_img) hr_img = Variable(hr_img) if torch.cuda.is_available(): lr_img = lr_img.cuda() hr_img = hr_img.cuda() sr_img = netG(lr_img)
class trainer(object):
    """Adversarial trainer for the label-correction GAN.

    Wires together two U-Net generators (one that recovers from the old label
    map, one conditioned on the image), a patch discriminator, their Adam
    optimizers with a polynomial LR decay, dataloaders, metrics, logging and
    checkpointing.  ``train()`` runs the full train/validate/checkpoint loop.

    NOTE(review): relies on a module-level visdom handle ``vis`` and helpers
    (``U_Net``, ``Discriminator``, ``label2onehot``, ``AverageTracker``,
    ``runningScore``, ``gen_color_map``, ``logger``) defined elsewhere in the
    project — confirm they are in scope at import time.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        # Generator 1: denoises/recovers the old label map (one-hot in, logits out).
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        # Generator 2: predicts the corrected label map from the RGB image,
        # conditioned on features from generator 1.
        self.Image_generator = U_Net(in_ch=3, out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        # Discriminator sees image (3ch) concatenated with a one-hot label map.
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)
        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        # Output directories for checkpoints and validation images.
        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)

        # RESUME is an epoch index (>= 0 means resume from that checkpoint).
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        self.optimizer_G = torch.optim.Adam(
            [{'params': self.OldLabel_generator.parameters()},
             {'params': self.Image_generator.parameters()}],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)
        self.optimizer_D = torch.optim.Adam(
            [{'params': self.discriminator.parameters(),
              'initial_lr': cfg.OPTIMIZER.D_LR}],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Polynomial LR decay over the total number of optimizer steps.
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=lambda_poly)
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D, lr_lambda=lambda_poly)

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # Load the resume checkpoint ONCE and reuse it for every component
            # (the original re-read the same file five times from disk).
            ckpt_path = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(self.start_epoch))
            checkpoint = torch.load(ckpt_path)
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)

        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        """Run the full training loop: alternate D/G updates, periodically
        plot losses to visdom, validate each epoch and save a checkpoint."""
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # ----- train -----
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()
                # Detach the fake sample so D's backward does not reach G.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # BUGFIX: was `cfg.DATASET.N_CLASS` (bare `cfg` is not in
                # scope inside this method) — must be `self.cfg`.
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)),
                    1)
                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)
                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()
                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))
                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)
                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                          datetime.now(), epoch_i, i, num_batches, batch_time.avg, total_loss.item(),
                          corr_loss.item(), recover_loss.item(), change_loss.item(),
                          gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                # Every 10 iterations: push running averages to visdom, then
                # reset the per-window trackers.
                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()
                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # ----- eval -----
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)
                    # Dump a color-coded preview of the first validation batch.
                    # NOTE: assumes batch size >= 2 (uses preds[0] and preds[1]).
                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()
            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')
            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            # Checkpoint everything needed to resume from this epoch.
            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
num_workers=4, batch_size=1, shuffle=False) netG = Generator(upscale) print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) generator_loss = GeneratorLoss() if torch.cuda.is_available(): netG.cuda() netD.cuda() generator_loss.cuda() optimizerG = optim.Adam(netG.parameters()) optimizerD = optim.Adam(netD.parameters()) results = { 'd_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': [] } for epoch in range(1, epoch + 1): trainning_results = {
def main(args):
    """SRGAN training entry point with mixed-precision (AMP) updates.

    Builds (or restores from 'data/dataset.pt') the train/val datasets, trains
    the Generator/Discriminator pair with two GradScalers, periodically
    validates (PSNR/SSIM), saves per-epoch weights and CSV statistics.
    """
    if (not os.path.exists('data/dataset.pt')):
        # First run: build datasets from the pickled file list and cache them.
        # Make sure the bit depth is 24, 8 = Gray scale
        df = pd.read_pickle('data/dataset_files.gzip')
        # Keep only images larger than 100x100.
        df = df[(df['width'] > 100) & (df['height'] > 100)]
        train_df, val_df = train_test_split(df,
                                            test_size=0.2,
                                            random_state=42,
                                            shuffle=True)
        _, val_similar = dataframe_find_similar_images(
            val_df, batch_size=args.batch_size)
        # Create the train dataset
        train_filenames = train_df['filename'].tolist()
        train_set = TrainDatasetFromList(train_filenames,
                                         crop_size=args.crop_size,
                                         upscale_factor=args.upscale_factor)
        val_sets = list()
        for val_df in val_similar:
            val_filenames = val_df['filename'].tolist()
            val_set = ValDatasetFromList(val_filenames,
                                         upscale_factor=args.upscale_factor)
            val_sets.append(val_set)
        train_sampler = torch.utils.data.RandomSampler(train_set)
        # NOTE(review): val_set here is whatever the loop above left behind
        # (the LAST validation split only); val_sampler is also never used by
        # the val loaders below (they pass shuffle=False) — confirm intent.
        val_sampler = torch.utils.data.SequentialSampler(val_set)
        data_to_save = {
            'train_dataset': train_set,
            "val_datasets": val_sets,
            'train_sampler': train_sampler,
            'val_sampler': val_sampler
        }
        torch.save(data_to_save, 'data/dataset.pt')
    else:
        # Cached datasets/samplers from a previous run.
        datasets = torch.load('data/dataset.pt')
        train_set = datasets['train_dataset']
        val_sets = datasets['val_datasets']
        train_sampler = datasets['train_sampler']
        val_sampler = datasets['val_sampler']

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              sampler=train_sampler)
    # One loader per validation split.
    val_loaders = list()
    for val_set in val_sets:
        val_loaders.append(
            DataLoader(dataset=val_set,
                       batch_size=args.batch_size,
                       num_workers=args.num_workers,
                       shuffle=False))

    netG = Generator(args.upscale_factor)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss()
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())
    # Per-epoch history used for the statistics CSV.
    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }

    start_epoch = 1
    if args.resume:
        # Resume from the most recently written weight files.
        import glob
        netG_files = glob.glob(
            os.path.join(args.output_dir,
                         'netG_epoch_%d_*.pth' % (args.upscale_factor)))
        netD_files = glob.glob(
            os.path.join(args.output_dir,
                         'netD_epoch_%d_*.pth' % (args.upscale_factor)))
        if (len(netG_files) > 0):
            netG_file = max(netG_files, key=os.path.getctime)
            netD_file = max(netD_files, key=os.path.getctime)
            netG.load_state_dict(torch.load(netG_file))
            netD.load_state_dict(torch.load(netD_file))
            # NOTE(review): uses the file COUNT as the resume epoch — only
            # valid if one file was saved per epoch with no gaps; confirm.
            start_epoch = len(netG_files)

    for epoch in range(start_epoch, args.epochs + 1):
        train_bar = tqdm(train_loader)
        running_results = {
            'batch_sizes': 0,
            'd_loss': 0,
            'g_loss': 0,
            'd_score': 0,
            'g_score': 0
        }
        dscaler = torch.cuda.amp.GradScaler(
        )  # Creates once at the beginning of training #* Discriminator
        gscaler = torch.cuda.amp.GradScaler()  #* Generator
        netG.train()
        netD.train()
        for data, target in train_bar:
            with torch.cuda.amp.autocast():  # Mix precision
                batch_size = data.size(0)
                running_results['batch_sizes'] += batch_size
                ############################
                # (1) Update D network: maximize D(x)-1-D(G(z))
                ###########################
                netD.zero_grad()
                real_img = Variable(target, requires_grad=False)
                if torch.cuda.is_available():
                    real_img = real_img.cuda()
                z = Variable(data)
                if torch.cuda.is_available():
                    z = z.cuda()
                fake_img = netG(z)
                real_out = netD(real_img).mean(
                )  # Discriminator Takes in the real image and predicts whether it's real
                fake_out = netD(fake_img).mean(
                )  # Discriminator takes in the fake image and predicts if it's fake
                # Minimizing the loss would mean real_out=1 and fake_out=0,
                # so it knows the real image and it knows the fake image.
                d_loss = 1 - real_out + fake_out
                # d_loss.backward(retain_graph=True)
                # optimizerD.step()
                ############################
                # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
                ###########################
                netG.zero_grad()
                # NOTE(review): fake_out is detached here, so the adversarial
                # term contributes no gradient through D to G — confirm this
                # is intentional (both losses are backpropagated together
                # below via the two scalers, with retain_graph on d_loss).
                g_loss = generator_criterion(fake_out.detach(), fake_img,
                                             real_img.detach())
            # Scaled backward passes + optimizer steps (outside autocast, as
            # required by the AMP recipe).
            dscaler.scale(d_loss).backward(retain_graph=True)
            gscaler.scale(g_loss).backward()
            dscaler.step(optimizerD)
            dscaler.update()
            gscaler.step(optimizerG)
            gscaler.update()
            # Re-run the forward pass post-update for reporting scores.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            # loss for current batch before optimization
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size
            train_bar.set_description(
                desc=
                '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' %
                (epoch, args.epochs, running_results['d_loss'] /
                 running_results['batch_sizes'], running_results['g_loss'] /
                 running_results['batch_sizes'], running_results['d_score'] /
                 running_results['batch_sizes'], running_results['g_score'] /
                 running_results['batch_sizes']))

        # save model parameters
        torch.save(
            netG.state_dict(),
            os.path.join(args.output_dir,
                         'netG_epoch_%d_%d.pth' % (args.upscale_factor, epoch)))
        torch.save(
            netD.state_dict(),
            os.path.join(args.output_dir,
                         'netD_epoch_%d_%d.pth' % (args.upscale_factor, epoch)))

        if epoch % args.validation_epoch == 0 and epoch != 0:
            # Validation: accumulate MSE/SSIM over all val loaders, then save
            # side-by-side (restored / HR / SR) image grids.
            netG.eval()
            with torch.no_grad():
                val_results = {
                    'mse': 0,
                    'ssims': 0,
                    'psnr': 0,
                    'ssim': 0,
                    'batch_sizes': 0
                }
                val_images = []
                for i in trange(len(val_loaders), desc='Running validation'):
                    val_loader = val_loaders[i]
                    for val_lr, val_hr_restore, val_hr in val_loader:
                        batch_size = val_lr.size(0)
                        val_results['batch_sizes'] += batch_size
                        lr = val_lr
                        hr = val_hr
                        if torch.cuda.is_available():
                            lr = lr.cuda()
                            hr = hr.cuda()
                        sr = netG(lr)
                        batch_mse = ((sr - hr)**2).data.mean()
                        val_results['mse'] += batch_mse * batch_size
                        batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                        val_results['ssims'] += batch_ssim * batch_size
                        # PSNR uses hr.max()**2 as the peak value (data may not
                        # be normalized to [0, 1]).
                        val_results['psnr'] = 10 * log10(
                            (hr.max()**2) /
                            (val_results['mse'] /
                             val_results['batch_sizes']))
                        val_results['ssim'] = val_results[
                            'ssims'] / val_results['batch_sizes']
                        # val_bar.set_description(
                        #     desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        #         val_results['psnr'], val_results['ssim']))
                        # convert the validation images
                        val_hr_restore_squeeze = val_hr_restore.squeeze(0)
                        hr_squeeze = hr.data.cpu().squeeze(0)
                        sr_squeeze = sr.data.cpu().squeeze(0)
                        for b in range(batch_size):
                            val_hr = val_hr_restore_squeeze[b]
                            hr_temp = hr_squeeze[b]
                            sr_temp = sr_squeeze[b]
                            val_images.extend([
                                display_transform()(val_hr),
                                display_transform()(hr_temp),
                                display_transform()(sr_temp)
                            ])
                val_images = torch.stack(val_images)
                # 15 images per saved grid (5 rows of restored/HR/SR triples).
                val_images = torch.chunk(val_images, val_images.size(0) // 15)
                val_save_bar = tqdm(val_images,
                                    desc='[saving training results]')
                index = 1
                for image in val_save_bar:
                    image = utils.make_grid(image, nrow=3, padding=5)
                    utils.save_image(
                        image,
                        os.path.join(
                            args.output_dir,
                            'epoch_%d_upscale_%d_index_%d.png' %
                            (epoch, args.upscale_factor, index)))
                    index += 1

        # save loss\scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] /
                                 running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] /
                                 running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] /
                                  running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] /
                                  running_results['batch_sizes'])
        # NOTE(review): val_results is only defined on validation epochs —
        # this raises NameError on the first epoch if validation_epoch > 1;
        # confirm validation_epoch == 1 is the intended usage.
        results['psnr'].append(val_results['psnr'])
        results['ssim'].append(val_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(args.upscale_factor) +
                              '_train_results.csv',
                              index_label='Epoch')
def train_half_pipe():
    """Train SRGAN on the whale super-resolution dataset ("half pipe": no
    classifier in the generator loss — the classifier terms are passed as 0).

    Tracks the best PSNR/SSIM weights, saves periodic checkpoints and a CSV
    of per-epoch statistics.
    """
    if DATA_AUG:
        train_set = TrainDatasetFromFolder('../data/split_SRdataset/train',
                                           hr_size=HR_SIZE,
                                           upscale_factor=UPSCALE_FACTOR)
    else:
        # Non-augmented variant reads from an absolute, machine-specific path.
        train_set = TrainDatasetFromFolder(
            '/home/mrey/ESA/Dataset/Step2-SuperresolutionWhale/converted_jpg',
            hr_size=HR_SIZE,
            upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('../data/split_SRdataset/test',
                                   hr_size=HR_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set,
                              num_workers=4,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss()
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    # Best-so-far snapshots (by PSNR AND SSIM improving together).
    best_netG = copy.deepcopy(netG.state_dict())
    best_netD = copy.deepcopy(netD.state_dict())
    old_PSNR = 0.0
    old_SSIM = 0.0
    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }

    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {
            'batch_sizes': 0,
            'd_loss': 0,
            'g_loss': 0,
            'd_score': 0,
            'g_score': 0
        }
        netG.train()
        netD.train()
        for data, target in train_bar:
            # print('')
            # print([data_i.shape for data_i in data.data])
            # print([data_i.shape for data_i in target.data])
            g_update_first = True  # NOTE(review): never read — dead variable.
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size
            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ###########################
            real_img = Variable(target)
            if torch.cuda.is_available():
                real_img = real_img.cuda()
            z = Variable(data)
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z)
            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            # retain_graph so the generator backward below can reuse the graph.
            d_loss.backward(retain_graph=True)
            optimizerD.step()
            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            # Last two args are the classifier loss and lambda — zeroed here
            # ("half pipe": no classifier guidance).
            g_loss = generator_criterion(fake_out, fake_img, real_img, 0, 0.0)
            g_loss.backward()
            optimizerG.step()
            # Re-run forward passes after the updates to report post-step
            # losses/scores.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img, 0, 0.0)
            running_results['g_loss'] += g_loss.data * batch_size
            d_loss = 1 - real_out + fake_out
            running_results['d_loss'] += d_loss.data * batch_size
            running_results['d_score'] += real_out.data * batch_size
            running_results['g_score'] += fake_out.data * batch_size
            train_bar.set_description(
                desc=
                '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' %
                (epoch, NUM_EPOCHS, running_results['d_loss'] /
                 running_results['batch_sizes'], running_results['g_loss'] /
                 running_results['batch_sizes'], running_results['d_score'] /
                 running_results['batch_sizes'], running_results['g_score'] /
                 running_results['batch_sizes']))

        # ----- validation -----
        netG.eval()
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        val_bar = tqdm(val_loader)
        valing_results = {
            'mse': 0,
            'ssims': 0,
            'psnr': 0,
            'ssim': 0,
            'batch_sizes': 0
        }
        val_images = []
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            # NOTE(review): volatile=True is a no-op on modern PyTorch — a
            # torch.no_grad() block is the current equivalent; confirm the
            # target PyTorch version.
            lr = Variable(val_lr, volatile=True)
            hr = Variable(val_hr, volatile=True)
            if torch.cuda.is_available():
                lr = lr.cuda()
                hr = hr.cuda()
            sr = netG(lr)
            batch_mse = ((sr - hr)**2).data.mean()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).data
            valing_results['ssims'] += batch_ssim * batch_size
            # PSNR assumes pixel values normalized to [0, 1] (peak = 1).
            valing_results['psnr'] = 10 * log10(
                1 /
                (valing_results['mse'] / valing_results['batch_sizes']))
            valing_results['ssim'] = valing_results['ssims'] / valing_results[
                'batch_sizes']
            val_bar.set_description(
                desc=
                '[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f'
                % (valing_results['psnr'], valing_results['ssim']))
            val_images.extend([
                display_transform()(val_hr_restore.squeeze(0)),
                display_transform()(hr.data.cpu().squeeze(0)),
                display_transform()(sr.data.cpu().squeeze(0))
            ])
        # val_images = torch.stack(val_images)
        # val_images = torch.chunk(val_images, val_images.size(0) // 15)
        # val_save_bar = tqdm(val_images, desc='[saving training results]')
        # index = 1
        # for image in val_save_bar:
        #     image = utils.make_grid(image, nrow=3, padding=5)
        #     if DATA_AUG:
        #         utils.save_image(image, out_path + 'dataAug_epoch_%d_index_%d.png' % (epoch, index), padding=5)
        #     else:
        #         utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
        #     index += 1

        # SAVE model parameters
        # Keep the weights whenever BOTH metrics improve.
        if valing_results['psnr'] > old_PSNR and valing_results[
                'ssim'] > old_SSIM:
            old_PSNR = valing_results['psnr']
            old_SSIM = valing_results['ssim']
            best_netG = copy.deepcopy(netG.state_dict())
            best_netD = copy.deepcopy(netD.state_dict())
            out_folder = 'epochs/weights_' + str(
                UPSCALE_FACTOR) + '_dataAug_halfPipe/'
            if not os.path.exists(out_folder):
                os.makedirs(out_folder)
            torch.save(best_netG, out_folder + 'best_netG.pth')
            torch.save(best_netD, out_folder + 'best_netD.pth')

        # Periodic (every 10 epochs) raw checkpoints.
        if epoch % 10 == 0 and epoch != 0:
            if DATA_AUG:
                out_folder = 'epochs/weights_' + str(
                    UPSCALE_FACTOR) + '_dataAug_halfPipe/'
                if not os.path.exists(out_folder):
                    os.makedirs(out_folder)
                torch.save(
                    netG.state_dict(), out_folder +
                    'netG_dataAug_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))
                torch.save(
                    netD.state_dict(), out_folder +
                    'netD_dataAug_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))
            else:
                out_folder = 'epochs/weights_' + str(
                    UPSCALE_FACTOR) + '_halfPipe/'
                if not os.path.exists(out_folder):
                    os.makedirs(out_folder)
                torch.save(
                    netG.state_dict(), out_folder +
                    'netG_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))
                torch.save(
                    netD.state_dict(), out_folder +
                    'netD_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))

        # save loss\scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] /
                                 running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] /
                                 running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] /
                                  running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] /
                                  running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics_halfPipe/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            if DATA_AUG:
                data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                                  '_dataAug_train_results.csv',
                                  index_label='Epoch')
            else:
                data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                                  '_train_results.csv',
                                  index_label='Epoch')
def train_lambda_class():
    """Fine-tune a pretrained SRGAN with an additional classifier loss
    ("whole pipe"), sweeping several lambda weights for the classifier term.

    For each lambda in lambda_values, the generator/discriminator weights are
    reset to the loaded "half pipe" checkpoint and re-trained with a frozen
    ResNet-50 classifier providing an auxiliary loss to the generator.
    """
    data_dir_lr = '../data/split_dataset/'
    data_dir_hr = '../data/split_dataset_SRGAN_' + str(UPSCALE_FACTOR) + os.sep
    train_set = ImageFolderWithPaths_train(data_dir_hr + "train" + os.sep,
                                           data_dir_lr + "train", HR_SIZE,
                                           UPSCALE_FACTOR)
    val_set = ImageFolderWithPaths_val(data_dir_hr + "test" + os.sep,
                                       data_dir_lr + "test", HR_SIZE,
                                       UPSCALE_FACTOR)
    # Batch size 1 is assumed below (per-sample misclassification check).
    train_loader = DataLoader(dataset=train_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(dataset=val_set,
                            num_workers=4,
                            batch_size=1,
                            shuffle=False)
    class_names = train_set.classes
    print(train_set.classes)
    print(val_set.classes)

    # Pretrained GAN weights from the half-pipe run (epoch 50, data-augmented).
    netD_weights = "epochs/weights_halfPipe/weights_"+str(UPSCALE_FACTOR)\
        + "_dataAug/netD_dataAug_epoch_"+str(UPSCALE_FACTOR)+"_050.pth"
    netG_weights = "epochs/weights_halfPipe/weights_" + str(UPSCALE_FACTOR) \
        + "_dataAug/netG_dataAug_epoch_" + str(UPSCALE_FACTOR) + "_050.pth"
    # netD_weights = "epochs/weights_halfPipe/weights_" + str(UPSCALE_FACTOR) + "_dataAug/best_netD.pth"
    # netG_weights = "epochs/weights_halfPipe/weights_" + str(UPSCALE_FACTOR) + "_dataAug/best_netG.pth"
    netG = Generator(UPSCALE_FACTOR)
    netG.load_state_dict(torch.load(netG_weights))
    netD = Discriminator()
    netD.load_state_dict(torch.load(netD_weights))
    generator_criterion = GeneratorLoss()
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    #### CLASSIFIER
    # Frozen (eval-mode) ResNet-50 fine-tuned on the project classes; only
    # used to compute an auxiliary loss, never updated here.
    classifier = models.resnet50(pretrained=True)
    num_ftrs = classifier.fc.in_features
    classifier.fc = nn.Linear(num_ftrs, len(class_names))
    classifier.name = 'resnet50'
    classifier.cuda()
    # if UPSCALE_FACTOR == 1:
    #     weights_class_path = "/home/mrey/ESA/pruebas/multiclass_classification/weights/best_model.pth"
    # else:
    #     weights_class_path = "/home/mrey/ESA/pruebas/multiclass_classification/weights/best_model_SRGAN_"\
    #         + str(UPSCALE_FACTOR)+".pth"
    weights_class_path = "../data/multiclass_classification/weights/"+\
        classifier.name+"_best_model_kfold.pth"
    classifier.load_state_dict(torch.load(weights_class_path))
    criterion_classifier = nn.CrossEntropyLoss()
    classifier.eval()
    print(classifier.name)
    print("upscale factor %d" % UPSCALE_FACTOR)
    print("HR size %d" % HR_SIZE)
    #############

    best_netG = copy.deepcopy(netG.state_dict())
    best_netD = copy.deepcopy(netD.state_dict())
    old_PSNR = 0
    old_SSIM = 0
    # Pristine starting weights, restored before each lambda run.
    basic_netG = copy.deepcopy(netG.state_dict())
    basic_netD = copy.deepcopy(netD.state_dict())
    miss_classifications = []
    lambda_values = [2, 1, 0.1, 0.01]

    for lambda_class in lambda_values:
        results = {
            'd_loss': [],
            'g_loss': [],
            'd_score': [],
            'g_score': [],
            'psnr': [],
            'ssim': []
        }
        # Reset to the pretrained weights for a fair per-lambda comparison.
        netG.load_state_dict(basic_netG)
        netD.load_state_dict(basic_netD)
        print("CHOSE LAMBDA {}".format(lambda_class))
        for epoch in range(1, NUM_EPOCHS + 1):
            train_bar = tqdm(train_loader)
            running_results = {
                'batch_sizes': 0,
                'd_loss': 0,
                'g_loss': 0,
                'd_score': 0,
                'g_score': 0
            }
            netG.train()
            netD.train()
            for data, target, label in train_bar:
                # print('')
                # print([data_i.shape for data_i in data.data])
                # print([data_i.shape for data_i in target.data])
                g_update_first = True  # NOTE(review): never read — dead variable.
                batch_size = data.size(0)
                running_results['batch_sizes'] += batch_size
                ############################
                # (1) Update D network: maximize D(x)-1-D(G(z))
                ###########################
                real_img = Variable(target)
                if torch.cuda.is_available():
                    real_img = real_img.cuda()
                # NOTE(review): `device` is not defined in this function —
                # presumably a module-level global; confirm it exists.
                label = label.to(device)
                z = Variable(data)
                if torch.cuda.is_available():
                    z = z.cuda()
                fake_img = netG(z)
                # Classify the LR input; record misclassifications.
                # NOTE(review): `preds != label` compares tensors — only
                # well-defined as a bool with batch_size == 1.
                classifier_outputs = classifier(z)
                _, preds = torch.max(classifier_outputs, 1)
                if preds != label:
                    miss_classifications.append([preds, label])
                    # print(str(preds)+"--"+str(label))
                loss_classifier = criterion_classifier(classifier_outputs,
                                                       label)
                netD.zero_grad()
                real_out = netD(real_img).mean()
                fake_out = netD(fake_img).mean()
                d_loss = 1 - real_out + fake_out
                d_loss.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (2) Update G network: minimize 1-D(G(z)) +
                # Perception Loss + Image Loss + TV Loss
                ###########################
                netG.zero_grad()
                # Classifier loss enters the generator objective weighted by
                # lambda_class.
                g_loss = generator_criterion(fake_out, fake_img, real_img,
                                             loss_classifier,
                                             float(lambda_class))
                g_loss.backward()
                optimizerG.step()
                # Post-update forward passes for reporting.
                fake_img = netG(z)
                fake_out = netD(fake_img).mean()
                g_loss = generator_criterion(fake_out, fake_img, real_img,
                                             loss_classifier,
                                             float(lambda_class))
                running_results['g_loss'] += g_loss.data * batch_size
                d_loss = 1 - real_out + fake_out
                running_results['d_loss'] += d_loss.data * batch_size
                running_results['d_score'] += real_out.data * batch_size
                running_results['g_score'] += fake_out.data * batch_size
                train_bar.set_description(
                    desc=
                    '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f'
                    % (epoch, NUM_EPOCHS, running_results['d_loss'] /
                       running_results['batch_sizes'],
                       running_results['g_loss'] /
                       running_results['batch_sizes'],
                       running_results['d_score'] /
                       running_results['batch_sizes'],
                       running_results['g_score'] /
                       running_results['batch_sizes']))

            # ----- validation -----
            netG.eval()
            out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            val_bar = tqdm(val_loader)
            valing_results = {
                'mse': 0,
                'ssims': 0,
                'psnr': 0,
                'ssim': 0,
                'batch_sizes': 0
            }
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                # NOTE(review): volatile=True is a no-op on modern PyTorch —
                # torch.no_grad() is the current equivalent.
                lr = Variable(val_lr, volatile=True)
                hr = Variable(val_hr, volatile=True)
                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)
                batch_mse = ((sr - hr)**2).data.mean()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).data
                valing_results['ssims'] += batch_ssim * batch_size
                # PSNR assumes pixel values normalized to [0, 1] (peak = 1).
                valing_results['psnr'] = 10 * log10(
                    1 /
                    (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results[
                    'ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc=
                    '[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f'
                    % (valing_results['psnr'], valing_results['ssim']))
                val_images.extend([
                    display_transform()(val_hr_restore.squeeze(0)),
                    display_transform()(hr.data.cpu().squeeze(0)),
                    display_transform()(sr.data.cpu().squeeze(0))
                ])
            # val_images = torch.stack(val_images)
            # val_images = torch.chunk(val_images, val_images.size(0) // 15)
            # val_save_bar = tqdm(val_images, desc='[saving training results]')
            # index = 1
            # for image in val_save_bar:
            #     image = utils.make_grid(image, nrow=3, padding=5)
            #     if DATA_AUG:
            #         utils.save_image(image, out_path + 'dataAug_epoch_%d_index_%d.png' % (epoch, index), padding=5)
            #     else:
            #         utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
            #     index += 1

            # SAVE model parameters
            # Per-lambda checkpoints every 5 epochs.
            if epoch % 5 == 0 and epoch != 0:
                out_folder = 'epochs/weights_' + str(UPSCALE_FACTOR)+'_'+str(classifier.name)+'_lambda' + \
                    str(lambda_class)+'_wholePipe/'
                if not os.path.exists(out_folder):
                    os.makedirs(out_folder)
                torch.save(
                    netG.state_dict(), out_folder +
                    'netG_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))
                torch.save(
                    netD.state_dict(), out_folder +
                    'netD_epoch_%d_%03d.pth' % (UPSCALE_FACTOR, epoch))

            # save loss\scores\psnr\ssim
            results['d_loss'].append(running_results['d_loss'] /
                                     running_results['batch_sizes'])
            results['g_loss'].append(running_results['g_loss'] /
                                     running_results['batch_sizes'])
            results['d_score'].append(running_results['d_score'] /
                                      running_results['batch_sizes'])
            results['g_score'].append(running_results['g_score'] /
                                      running_results['batch_sizes'])
            results['psnr'].append(valing_results['psnr'])
            results['ssim'].append(valing_results['ssim'])

            if epoch % 5 == 0 and epoch != 0:
                out_path = 'statistics_wholePipe/'
                if not os.path.exists(out_path):
                    os.makedirs(out_path)
                data_frame = pd.DataFrame(data={
                    'Loss_D': results['d_loss'],
                    'Loss_G': results['g_loss'],
                    'Score_D': results['d_score'],
                    'Score_G': results['g_score'],
                    'PSNR': results['psnr'],
                    'SSIM': results['ssim']
                },
                                          index=range(1, epoch + 1))
                data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                                  '_' + str(classifier.name) + '_lambda' +
                                  str(lambda_class) + '_train_results.csv',
                                  index_label='Epoch')
def main(step, dataset, data_dir, data_dir_bias, model_name):
    """Train the GeneratorMM/DiscriminatorMM GAN on a ShepardMetzler dataset.

    Args:
        step: epoch to resume from; when > 0, weights are loaded from
            ``<data_dir>/model/model{G,D}_<step>.pkl``.
        dataset: unused here; kept for call-site compatibility.
        data_dir: root directory for checkpoints, test output and records.
        data_dir_bias: root of the ``torch_super`` dataset tree.
        model_name: dataset sub-folder to train on.

    Side effects: rebinds the module-level ``args`` and ``lr`` globals,
    writes test images and loss records to disk via ``test``/``SaveRecord``.
    """
    global args, model, netContent, lr
    args = parser.parse_args()
    lr = args.lr
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    netG = GeneratorMM(args.upscale_factor)
    n_parameters = sum(p.data.nelement() for p in netG.parameters())
    print(' + Number of params: {}'.format(n_parameters))
    netD = DiscriminatorMM()
    n_parameters = sum(p.data.nelement() for p in netD.parameters())
    print(' + Number of params: {}'.format(n_parameters))

    generator_criterion = GeneratorLoss()
    netG.set_multiple_gpus()
    netD.set_multiple_gpus()

    # Resume both networks from the checkpoints of the requested step.
    if step > 0:
        model_dir = data_dir + '/model/modelG_' + str(step) + '.pkl'
        netG.load_state_dict(torch.load(model_dir))
        model_dir = data_dir + '/model/modelD_' + str(step) + '.pkl'
        netD.load_state_dict(torch.load(model_dir))

    if args.cuda:
        netG = netG.cuda()
        netD = netD.cuda()
        generator_criterion = generator_criterion.cuda()
        cudnn.benchmark = True

    optimizerG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.9, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # Load the dataset.
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        ShepardMetzler(root_dir=data_dir_bias + '/torch_super/' + model_name +
                       '/train/' + '/bias_0/'),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        ShepardMetzler(root_dir=data_dir_bias + '/torch_super/' + model_name +
                       '/test/' + '/bias_0/'),
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    # Per-epoch loss histories for train and test phases.
    lRecord = []
    generator_loss_train = []
    discriminator_loss_train = []
    a_loss_train = []
    p_loss_train = []
    i_loss_train = []
    t_loss_train = []
    generator_loss_test = []
    discriminator_loss_test = []
    a_loss_test = []
    p_loss_test = []
    i_loss_test = []
    t_loss_test = []

    start = 0
    for epoch in range(step + 1, args.epochs + step + 1):
        generator_loss, a_loss, p_loss, i_loss, t_loss, discriminator_loss = train(
            train_loader, optimizerG, optimizerD, netG, netD,
            generator_criterion, epoch, lRecord)
        generator_loss_train.append(generator_loss)
        a_loss_train.append(a_loss)
        p_loss_train.append(p_loss)
        i_loss_train.append(i_loss)
        t_loss_train.append(t_loss)
        discriminator_loss_train.append(discriminator_loss)

        # Decay both optimizers' learning rates on the same schedule.
        lr = adjust_learning_rate(optimizerG, epoch - 1)
        for param_group in optimizerG.param_groups:
            param_group["lr"] = lr
        lr = adjust_learning_rate(optimizerD, epoch - 1)
        for param_group in optimizerD.param_groups:
            param_group["lr"] = lr

        if epoch % args.log_interval_test == 0:
            test_dir = data_dir + '/test/' + 'model' + str(
                epoch) + '_scene' + str(start + 1) + '/'
            # FIX: makedirs (vs mkdir) also creates any missing parents, so
            # a fresh data_dir no longer crashes here.
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)
            generator_loss, a_loss, p_loss, i_loss, t_loss, discriminator_loss = test(
                netG, netD, start, test_loader, epoch, generator_criterion,
                lRecord, test_dir)
            # Cycle through test scenes, one per evaluation.
            start = (start + 1) % len(test_loader)
            generator_loss_test.append(generator_loss)
            a_loss_test.append(a_loss)
            p_loss_test.append(p_loss)
            i_loss_test.append(i_loss)
            t_loss_test.append(t_loss)
            discriminator_loss_test.append(discriminator_loss)

        if epoch % args.log_interval_record == 0:
            SaveRecord(data_dir, epoch, netG, netD, generator_loss_train,
                       a_loss_train, p_loss_train, i_loss_train, t_loss_train,
                       discriminator_loss_train, generator_loss_test,
                       a_loss_test, p_loss_test, i_loss_test, t_loss_test,
                       discriminator_loss_test, lRecord)
def main_train(path_trn: str, path_val: str, crop_size: int,
               upscale_factor: int, num_epochs: int, num_workers: int,
               to_device: str = 'cuda:0', batch_size: int = 64):
    """Train an SRGAN (Generator/Discriminator) on folder datasets.

    Args:
        path_trn: directory of HR training images.
        path_val: directory of HR validation images.
        crop_size: HR crop size for training patches.
        upscale_factor: super-resolution scale factor.
        num_epochs: number of training epochs.
        num_workers: DataLoader worker count.
        to_device: device string resolved via ``get_device``.
        batch_size: training mini-batch size.

    Side effects: writes validation grids to ``training_results/SRF_<f>/``,
    checkpoints to ``epochs/`` and statistics CSVs to ``statistics/``.
    """
    to_device = get_device(to_device)

    train_set = TrainDatasetFromFolder(path_trn, crop_size=crop_size,
                                       upscale_factor=upscale_factor)
    val_set = ValDatasetFromFolder(path_val, upscale_factor=upscale_factor)
    train_loader = DataLoader(dataset=train_set, num_workers=num_workers,
                              batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=num_workers,
                            batch_size=1, shuffle=False)

    netG = Generator(upscale_factor)
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [],
               'psnr': [], 'ssim': []}

    for epoch in range(1, num_epochs + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0,
                           'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()
        # FIXME: seperate function for epoch training
        for data, target in train_bar:
            cur_batch = data.size(0)
            running_results['batch_sizes'] += cur_batch

            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ###########################
            z = data.to(to_device)
            real_img = target.to(to_device)
            fake_img = netG(z)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            # retain_graph: fake_img's graph is reused for the G update below.
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # Re-evaluate with the freshly updated generator for logging.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            running_results['g_loss'] += float(g_loss) * cur_batch
            d_loss = 1 - real_out + fake_out
            running_results['d_loss'] += float(d_loss) * cur_batch
            running_results['d_score'] += float(real_out) * cur_batch
            running_results['g_score'] += float(fake_out) * cur_batch

            train_bar.set_description(
                desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                    epoch, num_epochs,
                    running_results['d_loss'] / running_results['batch_sizes'],
                    running_results['g_loss'] / running_results['batch_sizes'],
                    running_results['d_score'] / running_results['batch_sizes'],
                    running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        # FIXME: seperate function for epoch validation
        with torch.no_grad():
            out_path = 'training_results/SRF_' + str(upscale_factor) + '/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            val_bar = tqdm(val_loader)
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0,
                              'batch_sizes': 0}
            val_images = []
            for val_lr, val_hr_restore, val_hr in val_bar:
                cur_batch = val_lr.size(0)
                valing_results['batch_sizes'] += cur_batch
                lr = val_lr.to(to_device)
                hr = val_hr.to(to_device)
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).mean()
                valing_results['mse'] += float(batch_mse) * cur_batch
                batch_ssim = float(pytorch_ssim.ssim(sr, hr))
                valing_results['ssims'] += batch_ssim * cur_batch
                valing_results['psnr'] = 10 * log10(
                    1 / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = (valing_results['ssims'] /
                                          valing_results['batch_sizes'])
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)),
                     display_transform()(hr.data.cpu().squeeze(0)),
                     display_transform()(sr.data.cpu().squeeze(0))])

            val_images = torch.stack(val_images)
            # FIX: torch.chunk raises on a chunk count of 0, which happened
            # whenever the validation set yielded fewer than 15 images.
            val_images = torch.chunk(val_images,
                                     max(1, val_images.size(0) // 15))
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            index = 1
            for image in val_save_bar:
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(
                    image,
                    out_path + 'epoch_%d_index_%d.png' % (epoch, index),
                    padding=5)
                index += 1

        # save model parameters
        # FIX: torch.save does not create directories; make sure it exists.
        os.makedirs('epochs', exist_ok=True)
        torch.save(netG.state_dict(),
                   'epochs/netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
        torch.save(netD.state_dict(),
                   'epochs/netD_epoch_%d_%d.pth' % (upscale_factor, epoch))

        # save loss\scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] /
                                 running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] /
                                 running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] /
                                  running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] /
                                  running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        if epoch % 10 == 0 and epoch != 0:
            out_path = 'statistics/'
            # FIX: DataFrame.to_csv does not create the target directory.
            os.makedirs(out_path, exist_ok=True)
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'],
                      'Loss_G': results['g_loss'],
                      'Score_D': results['d_score'],
                      'Score_G': results['g_score'],
                      'PSNR': results['psnr'],
                      'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(
                out_path + 'srf_' + str(upscale_factor) + '_train_results.csv',
                index_label='Epoch')
def main():
    """Train SRGAN on DIV2K across 4 GPUs via nn.DataParallel.

    Delegates one epoch of training to ``train`` and validation to ``test``,
    then appends their metrics to ``results`` and dumps a CSV every 10
    epochs under ``statistics/``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3'

    train_set = TrainDatasetFromFolder('data/DIV2K_train_HR',
                                       crop_size=CROP_SIZE,
                                       upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/DIV2K_valid_HR',
                                   upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4,
                              batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4,
                            batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    # Total number of generator parameters.
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    # Total number of discriminator parameters.
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()  # loss function
    # Wrap both nets for multi-GPU execution and move them to CUDA.
    netG = nn.DataParallel(netG).cuda()
    netD = nn.DataParallel(netD).cuda()
    generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [],
               'psnr': [], 'ssim': []}

    for epoch in range(1, NUM_EPOCHS + 1):
        # model train
        d_loss, g_loss, d_score, g_score = train(netG, netD,
                                                 generator_criterion,
                                                 optimizerG, optimizerD,
                                                 train_loader, epoch)
        # validation data acc
        psnr, ssim = test(netG, netD, val_loader, epoch)

        # save loss\scores\psnr\ssim
        results['d_loss'].append(d_loss)
        results['g_loss'].append(g_loss)
        results['d_score'].append(d_score)
        results['g_score'].append(g_score)
        results['psnr'].append(psnr)
        results['ssim'].append(ssim)

        # save results every 10 epochs (epoch starts at 1, so the old
        # `epoch != 0` guard was redundant)
        if epoch % 10 == 0:
            out_path = 'statistics/'
            # FIX: DataFrame.to_csv does not create the target directory.
            os.makedirs(out_path, exist_ok=True)
            data_frame = pd.DataFrame(data={
                'Loss_D': results['d_loss'],
                'Loss_G': results['g_loss'],
                'Score_D': results['d_score'],
                'Score_G': results['g_score'],
                'PSNR': results['psnr'],
                'SSIM': results['ssim']
            },
                                      index=range(1, epoch + 1))
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) +
                              '_train_results.csv',
                              index_label='Epoch')
    print()