def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    """Build a 2-D convolution layer with optional padding and downsampling.

    Args:
        in_f: Number of input channels passed to ``nn.Conv2d``.
        out_f: Number of output channels passed to ``nn.Conv2d``.
        kernel_size: Convolution kernel size.
        stride: Downsampling factor. With ``downsample_mode='stride'`` it is
            used as the convolution stride; otherwise the convolution runs
            with stride 1 and a separate downsampling layer is appended.
        bias: Whether the convolution has a bias term.
        pad: ``'reflection'`` puts an ``nn.ReflectionPad2d`` in front of the
            convolution; any other value uses the convolution's own padding.
        downsample_mode: ``'stride'``, ``'avg'``, ``'max'``, ``'lanczos2'`` or
            ``'lanczos3'``. Only consulted when ``stride != 1``.

    Returns:
        An ``nn.Sequential`` of ``[padder, convolver, downsampler]`` with the
        layers that are not needed left out.

    Raises:
        ValueError: If ``stride != 1`` and ``downsample_mode`` is not one of
            the supported modes.
    """
    downsampler = None
    if stride != 1 and downsample_mode != 'stride':
        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ('lanczos2', 'lanczos3'):
            downsampler = Downsampler(n_planes=out_f,
                                      factor=stride,
                                      kernel_type=downsample_mode,
                                      phase=0.5,
                                      preserve_size=True)
        else:
            # An explicit exception (rather than `assert False`) so the check
            # still fires when Python runs with -O.
            raise ValueError(
                'Unsupported downsample_mode: {!r}'.format(downsample_mode))
        # The dedicated layer does the downsampling, so the conv keeps size.
        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        # Padding is already applied by the explicit layer.
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = [layer for layer in (padder, convolver, downsampler) if layer is not None]
    return nn.Sequential(*layers)
NET_TYPE, pad, skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, upsample_mode='bilinear').type(dtype) # Losses mse = torch.nn.MSELoss().type(dtype) img_LR_var = np_to_var(imgs['LR_np']).type(dtype) downsampler = Downsampler(n_planes=3, factor=factor, kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype) def closure(): global i if reg_noise_std > 0: net_input.data = net_input_saved + \ (noise.normal_() * reg_noise_std) out_HR = net(net_input) # ([1, 3, H, W]) out_LR = downsampler(out_HR) # ([1, 3, H/factor, W/factor]) total_loss = mse(out_LR, img_LR_var) # total_loss = CharbonnierLoss(out_LR, img_LR_var)
def __init__(self, config):
    """Set up models, optimizers, losses and the downsampler from ``config``.

    Args:
        config: Mapping of settings. Keys read here include ``dist``,
            ``dgp_mode``, ``custom_mask``, ``mask_path``, ``ftr_type``,
            ``G_lr``, ``G_B1``, ``G_B2``, ``random_G``, ``use_ema``,
            ``weights_root``, ``load_weights`` and ``warm_up``; the whole
            mapping is also forwarded as keyword arguments to
            ``models.Generator`` / ``models.Discriminator``.

    Side effects:
        Moves the generator, optional discriminator, VGG feature network and
        mask to CUDA, and (unless ``random_G`` is set) loads weights via
        ``utils.load_weights``.
    """
    self.rank, self.world_size = 0, 1
    if config['dist']:
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
    self.config = config
    self.mode = config['dgp_mode']
    self.custom_mask = config['custom_mask']
    self.update_G = config['update_G']
    self.update_embed = config['update_embed']
    self.iterations = config['iterations']
    self.ftr_num = config['ftr_num']
    self.ft_num = config['ft_num']
    self.lr_ratio = config['lr_ratio']
    self.G_lrs = config['G_lrs']
    self.z_lrs = config['z_lrs']
    self.use_in = config['use_in']
    self.select_num = config['select_num']
    self.factor = 4  # Downsample factor
    self.mask_path = config['mask_path']

    # Create selective masking
    if self.custom_mask:
        # The context manager closes the image file once ToTensor has
        # decoded the pixels; the previous code left the handle open.
        with Image.open(self.mask_path) as mask_image:
            mask_tensor = transforms.ToTensor()(mask_image).unsqueeze_(0)
        t = torch.Tensor([0.9])  # threshold
        final_mask = F.interpolate(mask_tensor,
                                   size=(256, 256),
                                   mode='bilinear')
        # Binarise: 1.0 where the resized mask exceeds the threshold.
        self.mask = (final_mask > t).float() * 1
        # Keep only the first channel of the first batch entry, on the GPU.
        self.mask = self.mask[0][0].cuda()
        self.regions = self.get_regions(self.mask)

    # create model
    self.G = models.Generator(**config).cuda()
    self.D = models.Discriminator(
        **config).cuda() if config['ftr_type'] == 'Discriminator' else None
    # One parameter group per index in range(len(G.blocks) + 1).
    # presumably so each group can get its own learning rate (see G_lrs) —
    # TODO confirm against the scheduler
    self.G.optim = torch.optim.Adam(
        [{
            'params': self.G.get_params(i, self.update_embed)
        } for i in range(len(self.G.blocks) + 1)],
        lr=config['G_lr'],
        betas=(config['G_B1'], config['G_B2']),
        weight_decay=0,
        eps=1e-8)

    # load weights
    if config['random_G']:
        self.random_G()
    else:
        # With use_ema the generator is passed as G_ema instead of G.
        utils.load_weights(self.G if not (config['use_ema']) else None,
                           self.D,
                           config['weights_root'],
                           name_suffix=config['load_weights'],
                           G_ema=self.G if config['use_ema'] else None,
                           strict=False)

    self.G.eval()
    if self.D is not None:
        self.D.eval()
    # Snapshot of the generator's state dict taken right after loading.
    self.G_weight = deepcopy(self.G.state_dict())

    # prepare latent variable and optimizer
    self._prepare_latent()
    # prepare learning rate scheduler
    self.G_scheduler = utils.LRScheduler(self.G.optim, config['warm_up'])
    # NOTE(review): self.z_optim is not set in this method; presumably
    # created by _prepare_latent() — confirm.
    self.z_scheduler = utils.LRScheduler(self.z_optim, config['warm_up'])

    # loss functions
    self.mse = torch.nn.MSELoss()
    if config['ftr_type'] == 'Discriminator':
        self.ftr_net = self.D
        self.criterion = utils.DiscriminatorLoss(
            ftr_num=config['ftr_num'][0])
    else:
        vgg = torchvision.models.vgg16(pretrained=True).cuda().eval()
        self.ftr_net = models.subsequence(vgg.features, last_layer='20')
        self.criterion = utils.PerceptLoss()

    # Downsampler for producing low-resolution image
    self.downsampler = Downsampler(n_planes=3,
                                   factor=self.factor,
                                   kernel_type='lanczos2',
                                   phase=0.5,
                                   preserve_size=True).type(
                                       torch.cuda.FloatTensor)
def build_closure(writer, dtype):
    """Create the optimisation closure, configured from ``config.cfg``.

    Args:
        writer: Object exposing ``add_scalar``, ``add_histogram`` and
            ``add_image``; used for TensorBoard-style logging.
        dtype: Tensor type applied to the MSE loss and the downsampler.

    Returns:
        A zero-argument ``closure`` that runs one forward/backward pass on one
        entry of ``state.imgs``, logs metrics, increments ``state.i`` and
        returns the training loss. It does not zero gradients or step an
        optimizer.
    """
    # Read config file
    cfg = common.configparser.ConfigParser()
    cfg.sections()
    cfg.read('config.cfg')

    def _optional_flag(section, option):
        # False when the option is absent, otherwise its boolean value.
        return cfg.has_option(section, option) and \
            cfg.getboolean(section, option)

    plot_steps_low = cfg.getint('DEFAULT', 'plot_steps_low')
    plot_steps_high = cfg.getint('DEFAULT', 'plot_steps_high')
    blur = (cfg['DEFAULT']['blur'] != 'none')

    mse = torch.nn.MSELoss().type(dtype)
    downsampler = Downsampler(n_planes=cfg.getint('DEFAULT', 'n_channels'),
                              factor=4,
                              kernel_type='lanczos2',
                              phase=0.5,
                              preserve_size=True).type(dtype)
    loss_network = None

    augmented_history = _optional_flag('LOADING', 'augmented_history')
    ignore_ground_truth = _optional_flag('LOADING', 'ignore_ground_truth')
    random_swap = _optional_flag('DEFAULT', 'random_swap')
    if random_swap:
        random.seed(101)

    noise_level = (float(cfg['DEFAULT']['noise_level'])
                   if cfg.has_option('DEFAULT', 'noise_level') else 0.03)

    # Disabled perceptual-loss / synthetic-noise setup, kept for reference
    # (it was parked inside a string literal and never executed):
    #
    # if config.getboolean('DEFAULT', 'use_perceptual_loss'):
    #     vgg_model = vgg.vgg16(pretrained=True)
    #     if torch.cuda.is_available():
    #         vgg_model.cuda()
    #     loss_network = LossNetwork(vgg_model)
    #     loss_network.eval()
    #
    # # Pre-trained model expects normalized, RGB images. Here's a transform.
    # grey_to_normal_rgb = transforms.Compose([
    #     transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                          std=[0.229, 0.224, 0.225]),
    #     transforms.Lambda(lambda x: x.unsqueeze(0))
    # ])
    #
    # noise = config['DEFAULT']['added_noise'] != 'none'
    # if noise:
    #     noise_x = config.getint('DEFAULT', 'noise_x')
    #     noise_y = config.getint('DEFAULT', 'noise_y')
    #     background = sr_utils.get_background(noise_x, noise_y)

    def get_loss(out_LR, ground_truth_LR, blurred_LR):
        """Return the loss of the network's low-resolution output.

        The target is ``blurred_LR`` when blurring is configured, otherwise
        ``ground_truth_LR``.
        """
        target = blurred_LR if blur else ground_truth_LR
        if not loss_network:
            return mse(out_LR, target)
        # Perceptual branch; unreachable while loss_network stays None.
        out_LR = out_LR.repeat(3, 1, 1)
        target = target.repeat(3, 1, 1)
        return mse(
            loss_network(grey_to_normal_rgb(out_LR)).relu3_3,
            loss_network(grey_to_normal_rgb(target)).relu3_3)

    def get_images(before_HR, after_HR, bicubic_HR, blurred_HR, before_LR,
                   after_LR, blurred_LR):
        """Assemble the HR and LR comparison grids for logging."""

        def _tile(image):
            # Flip along dims 1 and 0, then clamp into [0, 1].
            return torch.clamp(torch.flip(image, [1, 0]), 0, 1)

        hr_tiles = [
            _tile(blurred_HR),
            _tile(before_HR),
            _tile(after_HR),
            _tile(bicubic_HR),
            _tile(common.makeResidual(after_HR, before_HR)),
        ]
        lr_tiles = [
            _tile(blurred_LR),
            _tile(before_LR),
            _tile(after_LR),
            _tile(common.makeResidual(after_LR, before_LR)),
        ]
        HR_grid = torchvision.utils.make_grid(hr_tiles, 5)
        LR_grid = torchvision.utils.make_grid(lr_tiles, 4)
        return HR_grid, LR_grid

    def closure():
        # Train with a different input/output pair at each iteration.
        n_imgs = len(state.imgs)
        if random_swap:
            index = random.randint(0, n_imgs - 1)
        else:
            index = state.i % n_imgs
        sample = state.imgs[index]

        net_input = sample['net_input']
        ground_truth_LR = sample['LR_torch']
        ground_truth_HR = sample['HR_torch']
        bicubic_HR = sample['HR_torch_bicubic']
        #blurred_HR = TF.to_tensor(state.imgs[index]['HR_pil_blurred']).type(state.dtype)
        #blurred_LR = TF.to_tensor(state.imgs[index]['LR_pil_blurred']).type(state.dtype)
        blurred_LR = ground_truth_LR
        blurred_HR = ground_truth_HR
        background = None
        noise = None

        # Feed through actual network, perturbing the input with fresh
        # Gaussian noise scaled by noise_level.
        perturbation = net_input.detach().clone().normal_() * noise_level
        net_input_noisy = net_input.detach().clone() + perturbation
        out_HR = state.net(net_input_noisy)
        if noise:
            # Never taken while `noise` is fixed to None above.
            with torch.no_grad():
                out_HR = out_HR + background * torch.randn(
                    out_HR.size()).type(state.dtype)
        out_LR = downsampler(out_HR)
        out_HR = out_HR.squeeze(0)
        out_LR = out_LR.squeeze(0)

        # Get loss and train
        total_loss = get_loss(out_LR, ground_truth_LR, blurred_LR)
        total_loss.backward()

        # Everything below is logging, done on CPU copies.
        out_HR = out_HR.detach().cpu()
        out_LR = out_LR.detach().cpu()
        ground_truth_LR = ground_truth_LR.cpu()
        ground_truth_HR = ground_truth_HR.cpu()
        bicubic_HR = bicubic_HR.cpu()
        blurred_LR = blurred_LR.cpu()
        blurred_HR = blurred_HR.cpu()

        def _record(history):
            """Append this iteration's metrics to ``history`` and return them."""
            lr_psnr = compare_psnr(common.torch_to_np(ground_truth_LR),
                                   common.torch_to_np(out_LR))
            if ignore_ground_truth:
                hr_psnr = 0
                hr_loss = 0
            else:
                hr_psnr = compare_psnr(common.torch_to_np(ground_truth_HR),
                                       common.torch_to_np(out_HR))
                hr_loss = sr_utils.compare_HR(
                    common.torch_to_np(ground_truth_HR),
                    common.torch_to_np(out_HR))
                history.psnr_HR.append(hr_psnr)
                history.target_loss.append(hr_loss)

            history.iteration.append(state.i)
            history.psnr_LR.append(lr_psnr)
            history.training_loss.append(total_loss.item())

            blurred_psnr = None
            downsampled_psnr = None
            if augmented_history:
                HR_torch_blurred = state.imgs[index][
                    'HR_torch_blurred'].detach().cpu()
                LR_torch_downsampled = state.imgs[index][
                    'LR_torch_downsampled'].detach().cpu()
                blurred_psnr = compare_psnr(
                    common.torch_to_np(HR_torch_blurred),
                    common.torch_to_np(out_HR))
                downsampled_psnr = compare_psnr(
                    common.torch_to_np(LR_torch_downsampled),
                    common.torch_to_np(out_LR))
                history.psnr_blurred.append(blurred_psnr)
                history.psnr_downsampled.append(downsampled_psnr)
            return lr_psnr, hr_psnr, hr_loss, blurred_psnr, downsampled_psnr

        if state.i % plot_steps_low < n_imgs:
            (psnr_LR, psnr_HR, target_loss, psnr_blurred,
             psnr_downsampled) = _record(sample['history_low'])

            if augmented_history:
                print("{} {} {} {} {} {}".format(state.i, index, psnr_LR,
                                                 psnr_HR, psnr_blurred,
                                                 psnr_downsampled))
            elif not ignore_ground_truth:
                print("{} {} {} {}".format(state.i, index, psnr_LR, psnr_HR))
            else:
                print("{} {} {} {}".format(state.i, index, psnr_LR,
                                           total_loss.item()))

            # TensorBoard History
            writer.add_scalar('PSNR LR', psnr_LR, state.i)
            writer.add_scalar('PSNR HR', psnr_HR, state.i)
            writer.add_scalar('Training Loss', total_loss.item(), state.i)
            writer.add_scalar('Target Loss', target_loss, state.i)

        if state.i % plot_steps_high < n_imgs:
            # Lower frequency capturing of large data
            _record(sample['history_high'])

            # Save parameters
            for name, param in state.net.named_parameters():
                writer.add_histogram('Parameter {}'.format(name),
                                     param.flatten(), state.i)

            # Add images
            HR_grid, LR_grid = get_images(ground_truth_HR, out_HR, bicubic_HR,
                                          blurred_HR, ground_truth_LR, out_LR,
                                          blurred_LR)
            writer.add_image('Network LR Output', LR_grid, state.i)
            writer.add_image('Network HR Output', HR_grid, state.i)

            # Save a checkpoint
            #torch.save(state.net, "./output/checkpoints/checkpoint_{}.pt".format(state.i))
            #common.saveFigure("./output/checkpoints/checkpoint_{}.png".format(state.i), common.torch_to_np(out_HR))
            sr_utils.make_progress_figure(
                ground_truth_HR, bicubic_HR, out_HR, ground_truth_LR, out_LR,
                'HR_update_frame_{}_{}'.format(index, state.i),
                'LR_update_frame_{}_{}'.format(index, state.i))

        state.i += 1
        return total_loss

    return closure
def train(args, imgs):
    """Fit a ``skip`` network to one low-resolution image and show the result.

    The network output is downsampled by ``args.factor`` with a Lanczos-2
    ``Downsampler`` and compared against ``imgs['LR_np']`` with an MSE loss.

    Args:
        args: Object providing ``factor`` (downsampling factor) and ``epoch``
            (total iteration budget).
        imgs: Mapping with at least ``'HR_np'``, ``'LR_np'`` and
            ``'bicubic_np'`` arrays.
            # assumes channel-first (C, H, W) arrays — TODO confirm

    Side effects:
        Reads and updates the module-level globals ``iteration`` and
        ``psnr_HR``; ``iteration`` must already exist when this is called.
        Writes ``sr_<iteration>.png`` files, prints progress and plots the
        final comparison grid.
    """
    use_cuda = torch.cuda.is_available()

    def _on_device(obj):
        # Move to the GPU when one is available, otherwise leave on the CPU.
        return obj.cuda() if use_cuda else obj

    # Load model
    net = skip(num_input_channels=32,
               num_output_channels=3,
               num_channels_down=[64, 64, 64, 128, 128],
               num_channels_up=[64, 64, 64, 128, 128],
               num_channels_skip=[4, 4, 4, 4, 4],
               upsample_mode='bilinear',
               need_sigmoid=True,
               need_bias=True,
               pad='reflection',
               act_fun='LeakyReLU')
    net = _on_device(net.float())

    # Compute the number of parameters
    s = sum(p.numel() for p in net.parameters())
    print('Number of parameters: ', s)

    criterion = nn.MSELoss()

    # define input
    hr_shape = np.shape(imgs['HR_np'])
    net_input = get_noise(32, 'noise',
                          (hr_shape[1], hr_shape[2])).float().detach()
    img_LR_var = np_to_var(imgs['LR_np']).float()
    downsampler = _on_device(
        Downsampler(n_planes=3,
                    factor=args.factor,
                    kernel_type='lanczos2',
                    phase=0.5,
                    preserve_size=True).float())

    net_input_saved = net_input.data.clone()
    noise = net_input.data.clone()

    # Define closure
    reg_noise_std = 0.0
    tv_weight = 0.0

    net_input = _on_device(net_input)
    img_LR_var = _on_device(img_LR_var)
    net_input_saved = _on_device(net_input_saved)
    # Keep the noise buffer on the same device as the saved input so the
    # regularisation branch also works on CPU-only machines.
    noise = _on_device(noise)

    def closure():
        global iteration, psnr_HR

        if reg_noise_std > 0:
            # Re-draw the perturbation in place each call.
            net_input.data = net_input_saved + (noise.normal_() *
                                                reg_noise_std)

        out_HR = net(net_input)
        out_LR = downsampler(out_HR)

        total_loss = criterion(out_LR, img_LR_var)
        if tv_weight > 0:
            total_loss += tv_weight * tv_loss(out_HR)
        total_loss.backward()

        # Log
        psnr_LR = compare_psnr(imgs['LR_np'], out_LR.detach().cpu().numpy()[0])
        psnr_HR = compare_psnr(imgs['HR_np'], out_HR.detach().cpu().numpy()[0])
        # .item() instead of .data[0]: the loss is a 0-dim tensor, which
        # cannot be indexed in current PyTorch.
        print('Iteration %05d    Loss %3f   PSNR_LR %.3f   PSNR_HR %.3f' %
              (iteration, total_loss.item(), psnr_LR, psnr_HR),
              '\r',
              end='')

        if iteration % 500 == 0 or iteration == args.epoch - 1:
            out_HR_np = var_to_np(out_HR)
            # plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], np.clip(out_HR_np, 0, 1)], factor=13, nrow=3)
            io.imsave(
                'sr_' + str(iteration) + '.png',
                np.transpose(np.clip(out_HR_np, 0, 1), [1, 2, 0])[:, :, :3])

        iteration += 1
        return total_loss

    # Optimize (Adam with first 100 epoch and LBFGS with rest)
    # NOTE(review): `p` is not used below; both optimizers take
    # net.parameters() directly.
    p = get_params('net', net, net_input)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # NOTE(review): always runs 100 Adam steps, even when args.epoch < 100.
    for j in range(100):
        optimizer.zero_grad()
        closure()
        optimizer.step()

    if args.epoch > 100:
        # Announce LBFGS only when it actually runs.
        print('Starting optimization with LBFGS')
        optimizer = torch.optim.LBFGS(net.parameters(),
                                      max_iter=args.epoch - 100,
                                      lr=0.01,
                                      tolerance_grad=-1,
                                      tolerance_change=-1)

        def closure2():
            # LBFGS re-evaluates the closure, so gradients are cleared here.
            optimizer.zero_grad()
            return closure()

        optimizer.step(closure2)

    # Show final result
    out_HR_np = np.clip(var_to_np(net(net_input)), 0, 1)
    result_deep_prior = put_in_center(out_HR_np, imgs['HR_np'].shape[1:])
    plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], out_HR_np],
                    factor=4,
                    nrow=1)
    # psnr_HR is the module-level value last written by closure().
    print('\nFinal PSNR: ', psnr_HR)
from models.model_sr import Model net = Model() net = net.type(dtype) net_input = get_noise( args.input_depth, args.noise_method, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach() mse = torch.nn.MSELoss().type(dtype) img_LR_var = np_to_torch(imgs['LR_np']).type(dtype) downsampler = Downsampler(n_planes=3, factor=args.factor, kernel_type='lanczos2', phase=0.5, preserve_size=True).type(dtype) psnr_gt_best = 0 i = 0 PSNR_list = [] _t = {'im_detect': Timer(), 'misc': Timer()} def closure(): global i, net_input, psnr_gt_best, PSNR_list _t['im_detect'].tic()
def __init__(self, config):
    """Build the generator, optional discriminator, losses and downsampler.

    Args:
        config: Mapping of settings. It is indexed for the keys used below
            and also forwarded as keyword arguments to ``models.Generator``
            and ``models.Discriminator``.

    All networks created here are moved to CUDA.
    """
    # Single-process defaults, replaced when running distributed.
    if config['dist']:
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
    else:
        self.rank, self.world_size = 0, 1

    self.config = config
    # (attribute name, config key) pairs copied verbatim from the config.
    copied_settings = (
        ('mode', 'dgp_mode'),
        ('update_G', 'update_G'),
        ('update_embed', 'update_embed'),
        ('iterations', 'iterations'),
        ('ftr_num', 'ftr_num'),
        ('ft_num', 'ft_num'),
        ('lr_ratio', 'lr_ratio'),
        ('G_lrs', 'G_lrs'),
        ('z_lrs', 'z_lrs'),
        ('use_in', 'use_in'),
        ('select_num', 'select_num'),
    )
    for attr_name, config_key in copied_settings:
        setattr(self, attr_name, config[config_key])

    # Downsample factor
    if self.mode == 'hybrid':
        self.factor = 2
    else:
        self.factor = 4

    # create model
    self.G = models.Generator(**config).cuda()
    if config['ftr_type'] == 'Discriminator':
        self.D = models.Discriminator(**config).cuda()
    else:
        self.D = None

    # One parameter group per index in range(len(G.blocks) + 1).
    param_groups = [{
        'params': self.G.get_params(block_idx, self.update_embed)
    } for block_idx in range(len(self.G.blocks) + 1)]
    self.G.optim = torch.optim.Adam(param_groups,
                                    lr=config['G_lr'],
                                    betas=(config['G_B1'], config['G_B2']),
                                    weight_decay=0,
                                    eps=1e-8)

    # load weights
    if config['random_G']:
        self.random_G()
    else:
        # The generator goes in the G slot, or the G_ema slot with use_ema.
        plain_G = None if config['use_ema'] else self.G
        utils.load_weights(plain_G,
                           self.D,
                           config['weights_root'],
                           name_suffix=config['load_weights'],
                           G_ema=self.G if config['use_ema'] else None,
                           strict=False)

    self.G.eval()
    if self.D is not None:
        self.D.eval()
    # Snapshot of the generator's state dict taken right after loading.
    self.G_weight = deepcopy(self.G.state_dict())

    # prepare latent variable and optimizer
    self._prepare_latent()

    # prepare learning rate scheduler
    warm_up = config['warm_up']
    self.G_scheduler = utils.LRScheduler(self.G.optim, warm_up)
    # NOTE(review): self.z_optim is not set in this method; presumably
    # created by _prepare_latent() — confirm.
    self.z_scheduler = utils.LRScheduler(self.z_optim, warm_up)

    # loss functions
    self.mse = torch.nn.MSELoss()
    if config['ftr_type'] != 'Discriminator':
        vgg16_net = torchvision.models.vgg16(pretrained=True).cuda().eval()
        self.ftr_net = models.subsequence(vgg16_net.features,
                                          last_layer='20')
        self.criterion = utils.PerceptLoss()
    else:
        self.ftr_net = self.D
        self.criterion = utils.DiscriminatorLoss(
            ftr_num=config['ftr_num'][0])

    # Downsampler for producing low-resolution image
    lanczos_downsampler = Downsampler(n_planes=3,
                                      factor=self.factor,
                                      kernel_type='lanczos2',
                                      phase=0.5,
                                      preserve_size=True)
    self.downsampler = lanczos_downsampler.type(torch.cuda.FloatTensor)
def get_h(n_ch, blur_type, use_fourier, dtype):
    """Return a callable that blurs an image with the requested kernel.

    Args:
        n_ch: Number of image channels, passed to ``Downsampler``.
        blur_type: ``'uniform_blur'`` or ``'gauss_blur'``.
        use_fourier: When true, return a function that delegates to
            ``torch_blur``; otherwise return a factor-1, size-preserving
            ``Downsampler`` built with ``blur_type`` as its kernel type.
        dtype: Tensor type applied to the ``Downsampler`` or forwarded to
            ``torch_blur``.

    Returns:
        A callable taking an image and returning the blurred image.

    Raises:
        ValueError: If ``blur_type`` is not one of the supported kernels.
    """
    supported = ('uniform_blur', 'gauss_blur')
    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if blur_type not in supported:
        raise ValueError(
            "blur_type must be 'uniform_blur' or 'gauss_blur', got {!r}".format(
                blur_type))
    if not use_fourier:
        return Downsampler(n_ch, 1, blur_type, preserve_size=True).type(dtype)
    return lambda im: torch_blur(im, blur_type, dtype)