Ejemplo n.º 1
0
def _load_if_path(image_or_path):
    """Open *image_or_path* with PIL if it is a path or file-like object; otherwise return it as-is.

    This lets ``test_image`` accept either filesystem paths / open binary streams or
    already-loaded images (anything the caller's ``transform`` can consume).
    """
    import os

    if isinstance(image_or_path, (str, bytes, os.PathLike)) or hasattr(image_or_path, "read"):
        return PIL.Image.open(image_or_path)
    return image_or_path


def test_image(gt_path, noisy_path, model_path, transform, show = True, psnr = True, ssim = True):
    """Denoise one image with a RIDNet checkpoint and optionally score it against ground truth.

    Args:
        gt_path: Path to the ground-truth image, an already-loaded image, or ``None``
            when no ground truth is available (PSNR/SSIM are then returned as 0).
        noisy_path: Path to the noisy input image, or an already-loaded image.
        model_path: Checkpoint whose state dict is loaded into ``ridnet.RIDNET(args)``.
        transform: Callable applied to both images. Its output must be a 3-D tensor
            (presumably CHW — it is ``unsqueeze(0)``-ed for the network and
            ``permute(1, 2, 0)``-ed for scoring/display).
        show: If true, show the noisy, predicted and (if any) ground-truth images
            with OpenCV and block until a key is pressed.
        psnr: If true and ground truth is given, compute PSNR.
        ssim: If true and ground truth is given, compute SSIM.

    Returns:
        ``(psnr_val, ssim_val)``; each is 0 when it was not computed.
    """
    # model load
    model = ridnet.RIDNET(args)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    has_gt = gt_path is not None

    # Each input is opened independently, so a failure on one of them surfaces as
    # its own error instead of silently corrupting the other.
    noisy_image = _load_if_path(noisy_path)
    gt_image = _load_if_path(gt_path) if has_gt else None

    # Snapshot the torch RNG and restore it before transforming the ground truth,
    # so any torch-RNG-driven randomness in `transform` (e.g. a RandomCrop) is
    # replayed identically for both images.
    state = torch.get_rng_state()
    transformed_noisy = transform(noisy_image)
    torch.set_rng_state(state)
    if has_gt:
        gt_image = transform(gt_image).permute(1, 2, 0).numpy()

    # Add a batch dimension for the network.
    transformed_noisy = transformed_noisy.unsqueeze(0)

    # The original author notes the prediction is a numpy array at this point.
    predicted_image = pass_though_net(model, transformed_noisy)

    psnr_val = 0
    if psnr and has_gt:
        psnr_val = PSNR(predicted_image, gt_image)
    ssim_val = 0
    if ssim and has_gt:
        ssim_val = calc_ssim(predicted_image, gt_image)

    if show:
        cv2.imshow("noisy image", transformed_noisy.squeeze(0).permute(1, 2, 0).numpy())
        cv2.imshow("predicted image", predicted_image)
        if has_gt:
            cv2.imshow("gt image", gt_image)
        cv2.waitKey(0)

    return psnr_val, ssim_val
Ejemplo n.º 2
0
def main():
    """Run several RIDNet checkpoints on a fixed set of Nam benchmark images.

    Each checkpoint is loaded into a fresh ``ridnet.RIDNET(args)`` (in the listed
    order) and put in eval mode. Then, for every image id, the (ground-truth, noisy)
    pair is loaded with ``load_nam_im`` and passed to ``compare_save_images`` together
    with all models, their display names and ``../nam_images/results/`` as the output
    directory.

    Returns:
        0.
    """
    # (display name, checkpoint path) pairs, in the order the models are built.
    checkpoints = (
        ("Lab_syn_fullset", '../models/LabL1_syn_11522_fullset_and_halfset.pt'),
        ("LabL1", '../models/LabLoss_13921_l1.pt'),
        ("Lab_syn", '../models/LabL1_syn_07522_final.pt'),
    )

    models = []
    models_names = []
    for name, checkpoint_path in checkpoints:
        model = ridnet.RIDNET(args)
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()
        models.append(model)
        models_names.append(name)

    # Tensor conversion followed by a random 1000x1000 crop; the same transform is
    # handed to both the loader and compare_save_images.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(1000)])

    data_dir = "../nam_images/"
    results_dir = "../nam_images/results/"
    # Images are processed one at a time (load, then compare) in this order.
    for image_id in ("4", "11", "24"):
        gt_im, noisy_im = load_nam_im(data_dir, image_id, transform, True)
        compare_save_images(gt_im, noisy_im, results_dir, models, models_names, image_id, transform)
    return 0
Ejemplo n.º 3
0
# cv2.waitKey(0)
# print(os.getcwd())
# dog_path = '../../../PycharmProjects/Data/presentation/dog_rni15_2.png'
# dog2_path = '../../../PycharmProjects/Data/presentation/pred/presentation_dog_rni15_2.png'

# Training and validation loaders share identical settings: mini-batches of 16,
# reshuffled every epoch, loaded by 2 worker processes.
# NOTE(review): shuffling the validation set is usually unnecessary — confirm that
# nothing downstream depends on it before switching it off.
train_dataloders, validation_dataloders = (
    torch.utils.data.DataLoader(ds, batch_size=16, shuffle=True, num_workers=2)
    for ds in (train_dataset, validation_dataset)
)

# Newly constructed RIDNet from the global `args`; no checkpoint is loaded here.
model = ridnet.RIDNET(args)
# num_trainable_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
# print("the number of trainable weights is: ", num_trainable_params)
# model.load_state_dict(torch.load(model_path+'Y_L1_241021_final.pt'))
# dog_im = Image.open(dog_path)
# dog_im = transform(dog_im)
# dog2_im = Image.open(dog2_path)
# dog2_im = transform(dog2_im)
# dog2_im = dog2_im.unsqueeze(0)
# dog_im = dog_im.unsqueeze(0)
# dogs = torch.cat((dog_im,dog2_im),dim=0)
# lab_dog = utility.rgb2xyz(dogs)

# print(lab_dog.shape)
#
# print(lab_dog)