def my_next_crop(self, for_g):
    """Return a random square crop of the input image as a tensor.

    The crop side is ``self.g_input_shape`` when ``for_g`` is truthy and
    ``self.d_input_shape`` otherwise. For the non-G case, Gaussian noise
    scaled by 1/255 is added in place to the crop before it is returned.
    """
    if for_g:
        side = self.g_input_shape
    else:
        side = self.d_input_shape

    cropper = torchvision.transforms.RandomCrop(side)
    patch = cropper(im2tensor(self.input_image))
    if for_g:
        return patch

    # Crops that are not for G get additive noise.
    # NOTE(review): the noise is generated with 3 channels hard-coded.
    noise = np.random.randn(side, side, 3) / 255.
    patch += im2tensor(noise)
    return patch
def __getitem__(self, idx):
    """Build one sample: a G crop, a D crop and a cubic-resized D crop.

    Returns a dict with keys ``'HR'`` (crop for G), ``'LR'`` (crop for D)
    and ``'LR_bicubic'`` (the D crop resized by
    ``int(1 / conf.scale_factor_downsampler)`` with a cubic kernel). Every
    value is passed through ``im2tensor`` and then ``squeeze()``.
    """
    hr_crop = self.next_crop(for_g=True, idx=idx)
    lr_crop = self.next_crop(for_g=False, idx=idx)

    resize_factor = int(1 / self.conf.scale_factor_downsampler)
    lr_resized = imresize(im=lr_crop, scale_factor=resize_factor,
                          kernel='cubic')

    sample = {}
    for key, crop in (('HR', hr_crop),
                      ('LR', lr_crop),
                      ('LR_bicubic', lr_resized)):
        sample[key] = im2tensor(crop).squeeze()
    return sample
def next_crop(self, for_g, idx):
    """Cut the crop selected by ``idx`` out of the input image.

    The top-left corner comes from ``self.get_top_left``; the crop side is
    ``self.g_input_shape`` for G and ``self.d_input_shape`` for D. Crops
    for D get Gaussian noise (scaled by 1/255) added. The result is
    returned via ``im2tensor``.
    """
    if for_g:
        side = self.g_input_shape
    else:
        side = self.d_input_shape

    row0, col0 = self.get_top_left(side, for_g, idx)
    rows = slice(row0, row0 + side)
    cols = slice(col0, col0 + side)
    # Copy so the noise below never touches the stored input image.
    patch = self.input_image[rows, cols, :].copy()

    if for_g:
        return im2tensor(patch)

    # Add noise to the image for D (in place, on the private copy).
    patch += np.random.randn(*patch.shape) / 255.0
    return im2tensor(patch)
def __init__(self, conf, gan):
    """Load the input image and prepare crop bookkeeping.

    Reads crop sizes from ``conf.input_crop_size`` and
    ``gan.G.output_size``, loads the image at ``conf.input_image_path``
    (divided by 255), draws two uniform-noise tensors with the image's
    shape, and builds the crop index lists for G and D.
    """
    # Crop sizes: G takes the configured crop, D takes G's output size.
    self.d_input_shape = gan.G.output_size  # shape entering D downscaled by G
    self.g_input_shape = conf.input_crop_size

    # Load the image and scale its values by 1/255.
    self.input_image = read_image(conf.input_image_path) / 255.
    image_shape = self.input_image.shape

    # Uniform noise tensors matching the image shape (drawn in this order).
    self.sr_noise = im2tensor(np.random.rand(*image_shape))
    self.lr_noise = im2tensor(np.random.rand(*image_shape))

    # self.shave_edges(scale_factor=conf.scale_factor, real_image=conf.real_image)
    n_rows, n_cols = image_shape[0:2]
    self.in_rows = n_rows
    self.in_cols = n_cols

    # Pre-compute the crop indices used for G and for D.
    indices_g, indices_d = self.make_list_of_crop_indices(conf=conf)
    self.crop_indices_for_g = indices_g
    self.crop_indices_for_d = indices_d
def finish(self, image):
    """Run ``image`` through ``self.U``, save the result and report.

    The image is converted with ``im2tensor`` and passed through
    ``self.U`` under ``torch.no_grad()``. When ``conf.X4`` is set, the
    output is round-tripped through ``tensor2im``/``im2tensor`` and passed
    through ``self.U`` a second time. The final image is written as
    ``'image sr.png'`` inside ``conf.output_dir_path`` and a completion
    banner is printed.
    """
    with torch.no_grad():
        net_input = im2tensor(image)
        sr = self.U(net_input)
        if self.conf.X4:
            # Second pass: convert back to an image and feed it in again.
            second_input = im2tensor(tensor2im(sr))
            sr = self.U(second_input)
        sr = tensor2im(sr)

    out_dir = self.conf.output_dir_path
    out_path = os.path.join(out_dir, 'image sr.png')
    Image.fromarray(np.uint8(sr)).save(out_path)

    separator = '*' * 60
    print('FINISHED RUN (see --%s-- folder)\n' % self.conf.output_dir_path
          + separator + '\n\n')
def __init__(self, conf):
    """Set up networks, losses, optimizers, input data and debug state.

    Seeds torch, builds the ``G_DN``/``D_DN``/``G_UP`` networks on the GPU,
    creates their Adam optimizers, loads the input image (and, when
    configured, the ground-truth image and kernel), and initializes the
    debug bookkeeping containers.
    """
    # Fix random seed
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True  # slightly reduces throughput

    # Acquire configuration
    self.conf = conf

    # Define the networks
    self.G_DN = networks.Generator_DN().cuda()
    self.D_DN = networks.Discriminator_DN().cuda()
    self.G_UP = networks.Generator_UP().cuda()

    # Losses
    self.criterion_gan = loss.GANLoss().cuda()
    self.criterion_cycle = torch.nn.L1Loss()
    self.criterion_interp = torch.nn.L1Loss()
    self.regularization = loss.DownsamplerRegularization(
        conf.scale_factor_downsampler, self.G_DN.G_kernel_size)

    # Initialize networks weights
    self.G_DN.apply(networks.weights_init_G_DN)
    self.D_DN.apply(networks.weights_init_D_DN)
    self.G_UP.apply(networks.weights_init_G_UP)

    # Optimizers (all three share the same Adam betas)
    adam_betas = (conf.beta1, 0.999)
    self.optimizer_G_DN = torch.optim.Adam(
        self.G_DN.parameters(), lr=conf.lr_G_DN, betas=adam_betas)
    self.optimizer_D_DN = torch.optim.Adam(
        self.D_DN.parameters(), lr=conf.lr_D_DN, betas=adam_betas)
    self.optimizer_G_UP = torch.optim.Adam(
        self.G_UP.parameters(), lr=conf.lr_G_UP, betas=adam_betas)

    # Read input image
    self.in_img = util.read_image(conf.input_image_path)
    self.in_img_t = util.im2tensor(self.in_img)

    # Drop leading entries along dims 2 and 3 so both sizes become
    # multiples of the scale factor.
    tensor_shape = self.in_img_t.shape
    skip_dim2 = tensor_shape[2] % conf.scale_factor
    skip_dim3 = tensor_shape[3] % conf.scale_factor
    self.in_img_cropped_t = self.in_img_t[..., skip_dim2:, skip_dim3:]

    # Optional ground-truth image
    if conf.gt_path is None:
        self.gt_img = None
    else:
        self.gt_img = util.read_image(conf.gt_path)

    # Optional ground-truth kernel
    if conf.kernel_path is None:
        self.gt_kernel = None
    else:
        self.gt_kernel = loadmat(conf.kernel_path)['Kernel']

    if self.gt_kernel is not None:
        padded_kernel = np.pad(self.gt_kernel, 1, 'constant')
        self.gt_kernel = util.kernel_shift(padded_kernel,
                                           sf=conf.scale_factor)
        self.gt_kernel_t = torch.FloatTensor(self.gt_kernel).cuda()

        # Reference downsampling of the cropped input with the GT kernel.
        self.gt_downsampled_img_t = util.downscale_with_kernel(
            self.in_img_cropped_t, self.gt_kernel_t)
        self.gt_downsampled_img = util.tensor2im(self.gt_downsampled_img_t)

    # Debug variables: PSNR lists exist only when the matching ground
    # truth is available.
    self.debug_steps = []
    self.UP_psnrs = None if self.gt_img is None else []
    self.DN_psnrs = None if self.gt_kernel is None else []

    if self.conf.debug:
        self.loss_GANs = []
        self.loss_cycle_forwards = []
        self.loss_cycle_backwards = []
        self.loss_interps = []
        self.loss_Discriminators = []

    self.iter = 0