def _open_if_path(source):
    """Open *source* with PIL, or return it unchanged if PIL cannot read it.

    ``PIL.Image.open`` raises ``AttributeError`` when it tries to ``seek`` /
    ``read`` an object that is neither a path nor a file object (for example
    an image that is already loaded in memory). In that case *source* is
    assumed to already be an image and is returned as-is. Genuine I/O errors
    (missing file, unreadable format) propagate to the caller instead of being
    silently swallowed.
    """
    try:
        return PIL.Image.open(source)
    except (AttributeError, TypeError):
        return source


def test_image(gt_path, noisy_path, model_path, transform, show=True, psnr=True, ssim=True):
    """Denoise one image with a RIDNET checkpoint and optionally score/display it.

    Args:
        gt_path: path (or already-loaded image) of the clean reference image,
            or ``None`` when no reference is available.
        noisy_path: path (or already-loaded image) of the noisy input.
        model_path: checkpoint passed to ``torch.load`` and loaded into a
            ``ridnet.RIDNET(args)`` model.
        transform: callable applied to the noisy image and, if present, to the
            reference. The torch RNG state is restored between the two calls,
            so torch-driven random transforms (e.g. a random crop) act
            identically on both images.
        show: if true, display the noisy, predicted and (if any) reference
            images with OpenCV and block until a key is pressed.
        psnr: compute PSNR against the reference (only when ``gt_path`` is
            not ``None``).
        ssim: compute SSIM against the reference (only when ``gt_path`` is
            not ``None``).

    Returns:
        ``(psnr_val, ssim_val)``; each value is 0 when it was not computed.
    """
    # Use identity checks throughout: gt_path may be an ndarray, for which
    # ``!= None`` is element-wise and has no single truth value.
    has_gt = gt_path is not None

    model = ridnet.RIDNET(args)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Open each input independently so a problem with one cannot silently
    # replace the other with its raw path.
    noisy_image = _open_if_path(noisy_path)
    gt_image = _open_if_path(gt_path) if has_gt else None

    # Replay the same RNG state for the reference so random transforms
    # (crops, flips) select the same region in both images.
    state = torch.get_rng_state()
    transformed_noisy = transform(noisy_image)
    torch.set_rng_state(state)
    if has_gt:
        gt_image = transform(gt_image).permute(1, 2, 0).numpy()

    # Add a batch dimension for the network; the original code notes that
    # pass_though_net returns a numpy array.
    predicted_image = pass_though_net(model, transformed_noisy.unsqueeze(0))

    psnr_val = 0
    if psnr and has_gt:
        psnr_val = PSNR(predicted_image, gt_image)
    ssim_val = 0
    if ssim and has_gt:
        ssim_val = calc_ssim(predicted_image, gt_image)

    if show:
        # NOTE(review): cv2.imshow interprets channels as BGR; if these arrays
        # are RGB the colours will appear swapped — confirm.
        cv2.imshow("noisy image", transformed_noisy.permute(1, 2, 0).numpy())
        cv2.imshow("predicted image", predicted_image)
        if has_gt:
            cv2.imshow("gt image", gt_image)
        cv2.waitKey(0)
    return psnr_val, ssim_val
def _load_eval_model(model_path):
    """Return a ``ridnet.RIDNET(args)`` with weights from *model_path*, switched to eval mode."""
    model = ridnet.RIDNET(args)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    return model


def main():
    """Compare several RIDNET checkpoints on a fixed set of Nam images.

    Loads three checkpoints (in the order listed below). It then builds a
    ``ToTensor`` + ``RandomCrop(1000)`` transform. For Nam images 4, 11 and 24
    it loads the ground-truth / noisy pair with ``load_nam_im`` and passes both
    images, the models, their display names and the transform to
    ``compare_save_images``, with ``"../nam_images/results/"`` as the
    destination argument.

    Returns:
        0
    """
    # (display name, checkpoint path), in the order the models are compared.
    checkpoints = [
        ("Lab_syn_fullset", '../models/LabL1_syn_11522_fullset_and_halfset.pt'),
        ("LabL1", '../models/LabLoss_13921_l1.pt'),
        ("Lab_syn", '../models/LabL1_syn_07522_final.pt'),
    ]
    models = [_load_eval_model(path) for _, path in checkpoints]
    models_names = [name for name, _ in checkpoints]

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(1000),
    ])

    for image_id in ("4", "11", "24"):
        gt_im, noisy_im = load_nam_im("../nam_images/", image_id, transform, True)
        compare_save_images(gt_im, noisy_im, "../nam_images/results/", models, models_names, image_id, transform)
    return 0
# cv2.waitKey(0) # print(os.getcwd()) # dog_path = '../../../PycharmProjects/Data/presentation/dog_rni15_2.png' # dog2_path = '../../../PycharmProjects/Data/presentation/pred/presentation_dog_rni15_2.png' train_dataloders = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2) validation_dataloders = torch.utils.data.DataLoader(validation_dataset, batch_size=16, shuffle=True, num_workers=2) # cv2.imshow("loaders",train_dataloders) model = ridnet.RIDNET(args) # num_trainable_params = sum([p.numel() for p in model.parameters() if p.requires_grad]) # print("the number of trainable weights is: ", num_trainable_params) # model.load_state_dict(torch.load(model_path+'Y_L1_241021_final.pt')) # dog_im = Image.open(dog_path) # dog_im = transform(dog_im) # dog2_im = Image.open(dog2_path) # dog2_im = transform(dog2_im) # dog2_im = dog2_im.unsqueeze(0) # dog_im = dog_im.unsqueeze(0) # dogs = torch.cat((dog_im,dog2_im),dim=0) # lab_dog = utility.rgb2xyz(dogs) # print(lab_dog.shape) # # print(lab_dog)