def test_ssim_single_ch_identical():
    """Check the torch SSIM (structural similarity) metric on grayscale input.

    Every test picture is converted to grayscale and compared against
    itself; the metric is expected to return 1 within ``RTOL``.
    """
    metric = SSIM(data_range=255, channels=1)
    for picture in _load_test_images():
        gray = color.rgb2gray(picture)
        # Prepend two singleton axes: one for the batch, one for the channel.
        tensor = torch.Tensor(gray)[None, None]
        score = metric(tensor, tensor).numpy()
        np.testing.assert_allclose(score, 1., rtol=RTOL)
def test_ssim_single_channel():
    """Check the torch SSIM (structural similarity) metric on grayscale input.

    Each grayscale picture is compared with a corrupted copy of itself and
    the result is checked against ``_groundtruth_ssim`` within ``RTOL``.
    """
    metric = SSIM(data_range=255, channels=1)
    for picture in _load_test_images():
        gray = color.rgb2gray(picture)
        gray_corrupted = _corrupt_image(gray)

        # Prepend two singleton axes: one for the batch, one for the channel.
        clean_tensor = torch.Tensor(gray)[None, None]
        corrupted_tensor = torch.Tensor(gray_corrupted)[None, None]

        score = metric(clean_tensor, corrupted_tensor).numpy()
        expected = _groundtruth_ssim(gray, gray_corrupted, multichannel=False)
        np.testing.assert_allclose(score, expected, rtol=RTOL)
def test_ssim_multi_ch_identical():
    """Check the torch SSIM (structural similarity) metric on color input.

    Every test picture is compared against itself; the metric is expected
    to return 1 within ``RTOL``.
    """
    metric = SSIM(data_range=255, channels=4)
    for picture in _load_test_images():
        # Reorder HWC -> CHW, then add the batch axis to obtain NCHW.
        tensor = torch.Tensor(picture).permute(2, 0, 1).unsqueeze(0)
        score = metric(tensor, tensor).numpy()
        np.testing.assert_allclose(score, 1., rtol=RTOL)
def test_ssim_multi_channel():
    """Check the torch SSIM (structural similarity) metric on color input.

    Each picture is compared with a corrupted copy of itself and the result
    is checked against ``_groundtruth_ssim`` within ``RTOL``.
    """
    metric = SSIM(data_range=255, channels=4)
    for picture in _load_test_images():
        picture_corrupted = _corrupt_image(picture)

        # Reorder HWC -> CHW, then add the batch axis to obtain NCHW.
        clean_tensor = torch.Tensor(picture).permute(2, 0, 1).unsqueeze(0)
        corrupted_tensor = (
            torch.Tensor(picture_corrupted).permute(2, 0, 1).unsqueeze(0))

        score = metric(clean_tensor, corrupted_tensor).numpy()
        expected = _groundtruth_ssim(
            picture, picture_corrupted, multichannel=True)
        np.testing.assert_allclose(score, expected, rtol=RTOL)
def validate(self, val_batch, current_step):
    """Evaluate the generator on validation data; return mean PSNR and SSIM.

    For every item the generator ``self.G`` is put in eval mode, run on the
    low-resolution input without gradient tracking, and put back in train
    mode.  The first super-resolved image of the item is written to
    ``<val_image_dir>/<name>/<name>_<current_step>.png`` and scored against
    the ground truth after trimming a 4-pixel border from every edge.

    Side effects: overwrites ``self.val_lr``, ``self.val_hr`` and
    ``self.val_sr`` and writes one PNG per item.

    Args:
        val_batch: Iterable of dicts providing ``'LR'``, ``'HR'`` and
            ``'LR_path'`` entries.
        current_step: Integer training step, used in the saved file name.

    Returns:
        Tuple ``(avg_psnr, avg_ssim)`` averaged over all items, or
        ``(0.0, 0.0)`` when ``val_batch`` yields no items.
    """
    border = 4  # pixels trimmed from every edge before scoring
    psnr_sum = 0.0
    ssim_sum = 0.0
    num_items = 0

    for val_data in val_batch:
        num_items += 1
        img_name = os.path.splitext(
            os.path.basename(val_data['LR_path'][0]))[0]
        img_dir = os.path.join(
            self.opt['path']['checkpoints']['val_image_dir'], img_name)
        util.mkdir(img_dir)

        self.val_lr = val_data['LR'].to(self.device)
        self.val_hr = val_data['HR'].to(self.device)

        self.G.eval()
        try:
            with torch.no_grad():
                self.val_sr = self.G(self.val_lr)
        finally:
            # Restore training mode even if the forward pass fails.
            self.G.train()

        val_sr = self.val_sr.detach()[0].float().cpu()
        val_hr = self.val_hr.detach()[0].float().cpu()

        # NOTE(review): the original comments say tensor2img returns uint8
        # images -- confirm against util.tensor2img.
        sr_img = util.tensor2img(val_sr)
        gt_img = util.tensor2img(val_hr)

        # Save SR images for reference
        save_img_path = os.path.join(
            img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
        cv2.imwrite(save_img_path, sr_img)

        # The /255 then *255 round trip is kept as-is so the values handed
        # to PSNR/SSIM are numerically identical to the previous behavior.
        gt_img = gt_img / 255.
        sr_img = sr_img / 255.
        cropped_sr_img = sr_img[border:-border, border:-border, :]
        cropped_gt_img = gt_img[border:-border, border:-border, :]
        psnr_sum += PSNR(cropped_sr_img * 255, cropped_gt_img * 255)
        ssim_sum += SSIM(cropped_sr_img * 255, cropped_gt_img * 255)

    if num_items == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0, 0.0
    return psnr_sum / num_items, ssim_sum / num_items
def test(args, model, test_dataloader):
    """Super-resolve every sample of ``test_dataloader``; optionally score and save.

    For each sample the centre image is cropped to ``args.im_crop_H`` x
    ``args.im_crop_W``, down-sampled by average pooling with kernel size
    ``args.scale`` and passed through ``model.net_Feature_Head``,
    ``model.net_Feature_extractor`` (added back to the head features),
    ``model.net_Upscalar`` and ``model.net_G1``.  The first image of the
    output batch is clipped to [-1, 1] and written to
    ``<result_dir>/tempo/SR_<i>.png``.

    When ``args.have_gt`` is set, PSNR and SSIM against the cropped centre
    image are printed per patch together with the running averages.  When
    ``args.save_result`` is set, the HR, LR, adversarial and reference
    images are written below ``args.result_dir``.

    Tensors are moved with ``.cuda()``, so a CUDA device is required.

    Args:
        args: Options object; reads ``batch_size``, ``im_crop_H``,
            ``im_crop_W``, ``scale``, ``result_dir``, ``have_gt`` and
            ``save_result``.
        model: Object exposing the sub-networks named above.
        test_dataloader: Iterable of dicts with the keys
            ``'image_center'``, ``'image_others'``, ``'img_adv_cen'`` and
            ``'img_adv_ref'``.

    Returns:
        None.
    """
    psnr_total = []
    ssim_total = []
    print('=====> test sr begin!')

    result_dir = args.result_dir
    # Make sure every output folder exists once, before the first imsave.
    os.makedirs(os.path.join(result_dir, 'tempo'), exist_ok=True)
    if args.save_result:
        for sub_dir in ('HR', 'LR', 'REF', 'ADV_CEN', 'ADV_REF'):
            os.makedirs(os.path.join(result_dir, sub_dir), exist_ok=True)
        for j in range(args.batch_size):
            for sub_dir in ('ADV_REF', 'REF'):
                os.makedirs(
                    os.path.join(result_dir, sub_dir, '{}'.format(j)),
                    exist_ok=True)

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            img_ref = data['image_center']
            img_oth = data['image_others'].squeeze(0)
            img_adv_cen = data['img_adv_cen']
            img_adv_ref = data['img_adv_ref'].squeeze(0)

            img_ref = img_ref.expand(args.batch_size, -1, -1, -1)
            img_adv_cen = img_adv_cen.expand(args.batch_size, -1, -1, -1)

            image_others = (img_oth.cuda())[
                :, :, :args.im_crop_H, :args.im_crop_W].clone().float()
            image_ref = (img_ref.cuda())[
                :, :, :args.im_crop_H, :args.im_crop_W].clone().float()
            lr_image_ref = nn.functional.avg_pool2d(
                image_ref, kernel_size=args.scale)
            image_adv_cen = img_adv_cen.cuda().clone().float()
            image_adv_ref = img_adv_ref.cuda().clone().float()

            # Head features plus extracted content features (residual sum),
            # then upscaling and the final G1 stage.
            lr_feature_head = model.net_Feature_Head(lr_image_ref)
            lr_content_feature = model.net_Feature_extractor(lr_feature_head)
            lr_content_output = lr_feature_head + lr_content_feature
            hr_val = model.net_Upscalar(lr_content_output)
            hr_val = model.net_G1(hr_val)

            # Clip to [-1, 1] before the conversion to uint8.
            hr_val_numpy = hr_val.cpu()[0].permute(1, 2, 0).numpy()
            hr_val_numpy[hr_val_numpy > 1] = 1
            hr_val_numpy[hr_val_numpy < -1] = -1
            img_sr = skimage.img_as_ubyte(hr_val_numpy)
            skimage.io.imsave(
                os.path.join(result_dir, 'tempo', 'SR_{}.png'.format(i)),
                img_sr)

            if args.have_gt:
                PSNR_value = PSNR(hr_val.data, image_ref)
                SSIM_value = SSIM(hr_val.data, image_ref)
                psnr_total.append(PSNR_value)
                ssim_total.append(SSIM_value)
                print('PSNR: {} for patch {}'.format(PSNR_value, i))
                print('SSIM: {} for patch {}'.format(SSIM_value, i))
                # Report the number of patches seen so far, not the
                # zero-based index of the current one.
                print('Average PSNR: {} for {} patches'.format(
                    sum(psnr_total) / len(psnr_total), len(psnr_total)))
                print('Average SSIM: {} for {} patches'.format(
                    sum(ssim_total) / len(ssim_total), len(ssim_total)))

            if args.save_result:
                file_name = '{}.png'.format(i)

                # img_ref holds batch_size expanded copies of one image;
                # take the first instead of torch.squeeze, which leaves a
                # 4-D tensor (and breaks permute) when batch_size > 1.
                img_gt = skimage.img_as_ubyte(
                    img_ref[0].permute(1, 2, 0).numpy())
                skimage.io.imsave(
                    os.path.join(result_dir, 'HR', file_name), img_gt)

                img_lr = skimage.img_as_ubyte(
                    lr_image_ref.cpu()[0].permute(1, 2, 0).numpy())
                skimage.io.imsave(
                    os.path.join(result_dir, 'LR', file_name), img_lr)

                img_adv_center = skimage.img_as_ubyte(
                    image_adv_cen.cpu()[0].permute(1, 2, 0).numpy())
                skimage.io.imsave(
                    os.path.join(result_dir, 'ADV_CEN', file_name),
                    img_adv_center)

                for j in range(args.batch_size):
                    img_adv_reference = skimage.img_as_ubyte(
                        image_adv_ref.cpu()[j].permute(1, 2, 0).numpy())
                    skimage.io.imsave(
                        os.path.join(result_dir, 'ADV_REF', '{}'.format(j),
                                     file_name),
                        img_adv_reference)

                    img_reference = skimage.img_as_ubyte(
                        image_others.cpu()[j].permute(1, 2, 0).numpy())
                    skimage.io.imsave(
                        os.path.join(result_dir, 'REF', '{}'.format(j),
                                     file_name),
                        img_reference)
def test_lr(args, model, test_dataloader):
    """Super-resolve existing low-resolution images; optionally score them.

    Each sample's ``'lr_image'`` is expanded to ``args.batch_size``, moved
    to CUDA and passed through ``model.net_sr`` plus ``model.upsample_4``
    (summed), followed by ``model.net_G1``.  The first image of the output
    batch is clipped to [-1, 1] and written to
    ``<result_dir>/SR/SR_<i>.png``.

    When ``args.have_gt`` is set, the sample's ``'hr_image'`` is used as
    ground truth and PSNR / SSIM are printed per patch together with the
    running averages.

    Tensors are moved with ``.cuda()``, so a CUDA device is required.

    Args:
        args: Options object; reads ``batch_size``, ``result_dir`` and
            ``have_gt``.
        model: Object exposing ``net_sr``, ``upsample_4`` and ``net_G1``.
        test_dataloader: Iterable of dicts with the key ``'lr_image'`` and,
            when ``args.have_gt`` is set, ``'hr_image'``.

    Returns:
        None.
    """
    print('=====> test existing lr begin!')
    psnr_total = []
    ssim_total = []

    # Make sure the output folder exists before the first imsave.
    sr_dir = os.path.join(args.result_dir, 'SR')
    os.makedirs(sr_dir, exist_ok=True)

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            img_lr = data['lr_image']
            img_lr = img_lr.expand(args.batch_size, -1, -1, -1)
            img_lr = img_lr.cuda().clone().float()

            hr_val = model.net_sr(img_lr) + model.upsample_4(img_lr)
            hr_val = model.net_G1(hr_val)

            # Clip to [-1, 1] before the conversion to uint8.
            hr_val_numpy = hr_val.cpu()[0].permute(1, 2, 0).numpy()
            hr_val_numpy[hr_val_numpy > 1] = 1
            hr_val_numpy[hr_val_numpy < -1] = -1
            img_sr = skimage.img_as_ubyte(hr_val_numpy)
            skimage.io.imsave(
                os.path.join(sr_dir, 'SR_{}.png'.format(i)), img_sr)

            if args.have_gt:
                img_hr = data['hr_image']
                img_hr = img_hr.expand(args.batch_size, -1, -1, -1)
                img_hr = img_hr.cuda().clone().float()

                PSNR_value = PSNR(hr_val.data, img_hr)
                SSIM_value = SSIM(hr_val.data, img_hr)
                psnr_total.append(PSNR_value)
                ssim_total.append(SSIM_value)
                print('PSNR: {} for patch {}'.format(PSNR_value, i))
                print('SSIM: {} for patch {}'.format(SSIM_value, i))
                # Report the number of patches seen so far, not the
                # zero-based index of the current one.
                print('Average PSNR: {} for {} patches'.format(
                    sum(psnr_total) / len(psnr_total), len(psnr_total)))
                print('Average SSIM: {} for {} patches'.format(
                    sum(ssim_total) / len(ssim_total), len(ssim_total)))