def get_loss():
    if args.loss == LossType.L1:
        return nn.L1Loss()
    if args.loss == LossType.SmoothL1:
        return nn.SmoothL1Loss(beta=0.01)
    if args.loss == LossType.L2:
        return nn.MSELoss()
    if args.loss == LossType.SSIM:
        return PIQLoss(piq.SSIMLoss())
    if args.loss == LossType.VIF:
        return PIQLoss(piq.VIFLoss())
    if args.loss == LossType.LPIPS:
        return PIQLoss(piq.LPIPS())
    if args.loss == LossType.DISTS:
        return PIQLoss(piq.DISTS())
    raise ValueError("Unknown loss")
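# `PIQLoss` is not defined in this excerpt. piq's loss modules expect inputs
# inside [0, data_range], while a network's raw output is unbounded, so a thin
# clamping adapter is a plausible reading. A minimal sketch under that
# assumption; the class body below is illustrative, not the original code.
import torch
import torch.nn as nn


class PIQLoss(nn.Module):
    """Hypothetical adapter: clamp both tensors into the range piq expects,
    then delegate to the wrapped piq loss module."""

    def __init__(self, loss_fn: nn.Module, data_range: float = 1.0):
        super().__init__()
        self.loss_fn = loss_fn
        self.data_range = data_range

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        output = output.clamp(0.0, self.data_range)
        target = target.clamp(0.0, self.data_range)
        return self.loss_fn(output, target)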
def train(config):
    use_gpu = config.use_gpu
    bk_width = config.block_width
    bk_height = config.block_height
    resize = config.resize
    bTest = config.bTest

    dehaze_net = net.dehaze_net().cuda() if use_gpu else net.dehaze_net()

    if config.snap_train_data:
        dehaze_net.load_state_dict(
            torch.load(config.snapshots_train_folder + config.snap_train_data))
    else:
        dehaze_net.apply(weights_init)
    print(dehaze_net)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path, 'train',
                                               resize, bk_width, bk_height, bTest)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, 'val',
                                             resize, bk_width, bk_height, bTest)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.train_batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.val_batch_size,
                                             shuffle=True,
                                             num_workers=config.num_workers,
                                             pin_memory=True)

    criterion = nn.MSELoss().cuda() if use_gpu else nn.MSELoss()
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr,
                                 weight_decay=config.weight_decay)

    dehaze_net.train()

    save_counter = 0
    # Run the same training data for config.num_epochs passes.
    for epoch in range(config.num_epochs):
        # Each loader item yields the blocks cut from one image: img_orig and
        # img_haze are lists of block tensors, so one iteration trains on every
        # block of that image. Conceptually the blocks tile the image side by
        # side, though they actually live in separate tensor dimensions.
        if config.do_valid == 0:
            for iteration, (img_orig, img_haze, rgb, bl_num_width,
                            bl_num_height, data_path) in enumerate(train_loader):
                if save_counter == 0:
                    print("img_orig length:", len(img_orig))
                    print("bl_num_width shape:", bl_num_width.shape)

                # Train stage
                num_width = int(bl_num_width[0].item())
                num_height = int(bl_num_height[0].item())
                full_bk_num = num_width * num_height
                # Integer division: the result is used as a modulus below.
                display_block_iter = full_bk_num // config.display_block_iter

                for index in range(len(img_orig)):
                    unit_img_orig = img_orig[index]
                    unit_img_haze = img_haze[index]
                    if save_counter == 0:
                        print("unit_img_orig type:", unit_img_orig.type())
                        print("unit_img_orig shape:", unit_img_orig.shape)

                    if use_gpu:
                        unit_img_orig = unit_img_orig.cuda()
                        unit_img_haze = unit_img_haze.cuda()

                    clean_image = dehaze_net(unit_img_haze)
                    loss = criterion(clean_image, unit_img_orig)

                    if torch.isnan(unit_img_haze).any() or torch.isinf(clean_image).any():
                        print("unit_img_haze:", unit_img_haze.shape)
                        print(unit_img_haze)
                        print("clean_image:", clean_image.shape)
                        print(clean_image)

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                                   config.grad_clip_norm)
                    optimizer.step()

                    # Report the loss every display_block_iter blocks.
                    if (index + 1) % display_block_iter == 0:
                        print("Loss at Epoch:" + str(epoch) +
                              "_index:" + str(index + 1) + "/" + str(len(img_orig)) +
                              "_iter:" + str(iteration + 1) +
                              "_Loss value:" + str(loss.item()))

                    # Save a snapshot every config.snapshot_iter training steps.
                    if (save_counter + 1) % config.snapshot_iter == 0:
                        save_name = ("Epoch:" + str(epoch) +
                                     "_TrainTimes:" + str(save_counter + 1) + ".pth")
                        torch.save(dehaze_net.state_dict(),
                                   config.snapshots_folder + save_name)
                    save_counter += 1

        # Validation stage.
        # img_orig -> YUV444, img_haze -> YUV420
        for iter_val, (img_orig, img_haze, rgb, bl_num_width, bl_num_height,
                       data_path) in enumerate(val_loader):
            sub_image_list = []          # network output blocks (YUV420)
            sub_image_list_no_deep = []  # hazy input blocks (YUV420)
            ori_sub_image_list = []      # reference blocks (YUV444)
            rgb_image_list = []          # original RGB blocks
            rgb_list_from_sub = []       # RGB from the cleaned YUV420 output
            rgb_list_from_ori = []       # RGB from the hazy YUV420 input

            for index in range(len(img_orig)):
                unit_img_orig = img_orig[index]
                unit_img_haze = img_haze[index]
                unit_img_rgb = rgb[index]
                # TODO: YUV444 ??? color is strange...
                if use_gpu:
                    unit_img_orig = unit_img_orig.cuda()
                    unit_img_haze = unit_img_haze.cuda()
                    unit_img_rgb = unit_img_rgb.cuda()

                clean_image = dehaze_net(unit_img_haze)
                sub_image_list.append(clean_image)
                sub_image_list_no_deep.append(unit_img_haze)
                ori_sub_image_list.append(unit_img_orig)
                rgb_image_list.append(unit_img_rgb)
                rgb_list_from_sub.append(yuv2rgb(clean_image))
                rgb_list_from_ori.append(yuv2rgb(unit_img_haze))

            temp_data_path = data_path[0]
            print('temp_data_path:', temp_data_path)
            orimage_name = temp_data_path.split("/")[-1].split(".")[0]
            num_width = int(bl_num_width[0].item())
            num_height = int(bl_num_height[0].item())
            full_bk_num = num_width * num_height

            def stitch(blocks):
                # Reassemble row-major blocks into the full image: width grows
                # along dim 3, rows are stacked along dim 2.
                image_all = torch.cat(blocks[:num_width], 3)
                for i in range(num_width, full_bk_num, num_width):
                    image_row = torch.cat(blocks[i:i + num_width], 3)
                    image_all = torch.cat([image_all, image_row], 2)
                return image_all

            name_prefix = (config.sample_output_folder + "Epoch:" + str(epoch) +
                           "_Index:" + str(iter_val + 1) + "_" + orimage_name)

            # YUV420, after the network
            image_all = stitch(sub_image_list)
            torchvision.utils.save_image(image_all, name_prefix + "_yuv420_deep.bmp")

            # YUV420, without the network
            image_all_ori_no_deep = stitch(sub_image_list_no_deep)
            torchvision.utils.save_image(image_all_ori_no_deep,
                                         name_prefix + "_yuv420_ori.bmp")

            # YUV444 reference
            image_all_ori = stitch(ori_sub_image_list)
            torchvision.utils.save_image(image_all_ori, name_prefix + "_yuv444.bmp")

            # Original RGB blocks (test)
            rgb_image_all = stitch(rgb_image_list)
            torchvision.utils.save_image(rgb_image_all, name_prefix + "_rgb.bmp")

            # RGB converted from the cleaned YUV420 output
            rgb_from_420_image_all_clear = stitch(rgb_list_from_sub)
            torchvision.utils.save_image(rgb_from_420_image_all_clear,
                                         name_prefix + "_rgb_from_clean_420.bmp")

            # RGB converted from the hazy YUV420 input
            rgb_from_420_image_all_haze = stitch(rgb_list_from_ori)
            torchvision.utils.save_image(rgb_from_420_image_all_haze,
                                         name_prefix + "_rgb_from_haze_420.bmp")

            # To compute PSNR as a measure, use the lower-case function from the library.
            psnr_index = piq.psnr(rgb_from_420_image_all_haze, rgb_image_all,
                                  data_range=1., reduction='none')
            print(f"PSNR haze: {psnr_index.item():0.4f}")
            psnr_index = piq.psnr(rgb_from_420_image_all_clear, rgb_image_all,
                                  data_range=1., reduction='none')
            print(f"PSNR clear: {psnr_index.item():0.4f}")

            # To compute SSIM as a measure, use the lower-case function from the library.
            ssim_index = piq.ssim(rgb_from_420_image_all_haze, rgb_image_all,
                                  data_range=1.)
            ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(
                rgb_from_420_image_all_haze, rgb_image_all)
            print(f"SSIM haze index: {ssim_index.item():0.4f}, "
                  f"loss: {ssim_loss.item():0.4f}")
            ssim_index = piq.ssim(rgb_from_420_image_all_clear, rgb_image_all,
                                  data_range=1.)
            ssim_loss = piq.SSIMLoss(data_range=1.)(
                rgb_from_420_image_all_clear, rgb_image_all)
            print(f"SSIM clear index: {ssim_index.item():0.4f}, "
                  f"loss: {ssim_loss.item():0.4f}")

    torch.save(dehaze_net.state_dict(), config.snapshots_folder + "dehazer.pth")
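# `yuv2rgb` is imported from elsewhere in the project and not shown here. For
# reference, a minimal sketch assuming full-range BT.601 YUV stored as an NCHW
# tensor in [0, 1]; the project's actual conversion (and its handling of 4:2:0
# chroma upsampling) may differ.
import torch


def yuv2rgb(yuv: torch.Tensor) -> torch.Tensor:
    """Hypothetical full-range BT.601 YUV -> RGB for NCHW tensors in [0, 1]."""
    y = yuv[:, 0:1]
    u = yuv[:, 1:2] - 0.5  # center chroma around zero
    v = yuv[:, 2:3] - 0.5
    r = y + 1.402 * v
    g = y - 0.344136 * u - 0.714136 * v
    b = y + 1.772 * u
    return torch.cat([r, g, b], dim=1).clamp(0.0, 1.0)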
def main():
    # Read an RGB image and its noisy version.
    x = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(2, 0, 1) / 255.
    y = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1) / 255.
    if torch.cuda.is_available():
        # Move to GPU to make computations faster.
        x = x.cuda()
        y = y.cuda()

    # To compute a BRISQUE score as a measure, use the lower-case function from the library.
    brisque_index: torch.Tensor = piq.brisque(x, data_range=1., reduction='none')
    # In order to use BRISQUE as a loss function, use the corresponding PyTorch module.
    # Note: backpropagation is not available with torch==1.5.0;
    # update the environment to the latest torch and torchvision.
    brisque_loss: torch.Tensor = piq.BRISQUELoss(data_range=1., reduction='none')(x)
    print(f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}")

    # To compute a Content score as a loss function, use the corresponding PyTorch module.
    # By default the VGG16 model is used, but any feature extractor model is supported.
    # Don't forget to adjust the layer names accordingly. Features from different layers
    # can be weighted differently via the weights parameter. See other options in the class docstring.
    content_loss = piq.ContentLoss(feature_extractor="vgg16", layers=("relu3_3",),
                                   reduction='none')(x, y)
    print(f"ContentLoss: {content_loss.item():0.4f}")

    # To compute DISTS as a loss function, use the corresponding PyTorch module.
    # By default, input images are normalized with ImageNet statistics before being
    # forwarded through the VGG16 model. If there is no need to normalize the data,
    # use mean=[0.0, 0.0, 0.0] and std=[1.0, 1.0, 1.0].
    dists_loss = piq.DISTS(reduction='none')(x, y)
    print(f"DISTS: {dists_loss.item():0.4f}")

    # To compute FSIM as a measure, use the lower-case function from the library.
    fsim_index: torch.Tensor = piq.fsim(x, y, data_range=1., reduction='none')
    # In order to use FSIM as a loss function, use the corresponding PyTorch module.
    fsim_loss = piq.FSIMLoss(data_range=1., reduction='none')(x, y)
    print(f"FSIM index: {fsim_index.item():0.4f}, loss: {fsim_loss.item():0.4f}")

    # To compute GMSD as a measure, use the lower-case function from the library.
    # This is a port of the MATLAB version from the authors of the original paper.
    # In any case it should be minimized. GMSD values usually lie in the [0, 0.35] interval.
    gmsd_index: torch.Tensor = piq.gmsd(x, y, data_range=1., reduction='none')
    # In order to use GMSD as a loss function, use the corresponding PyTorch module:
    gmsd_loss: torch.Tensor = piq.GMSDLoss(data_range=1., reduction='none')(x, y)
    print(f"GMSD index: {gmsd_index.item():0.4f}, loss: {gmsd_loss.item():0.4f}")

    # To compute HaarPSI as a measure, use the lower-case function from the library.
    # This is a port of the MATLAB version from the authors of the original paper.
    haarpsi_index: torch.Tensor = piq.haarpsi(x, y, data_range=1., reduction='none')
    # In order to use HaarPSI as a loss function, use the corresponding PyTorch module.
    haarpsi_loss: torch.Tensor = piq.HaarPSILoss(data_range=1., reduction='none')(x, y)
    print(f"HaarPSI index: {haarpsi_index.item():0.4f}, loss: {haarpsi_loss.item():0.4f}")

    # To compute LPIPS as a loss function, use the corresponding PyTorch module.
    lpips_loss: torch.Tensor = piq.LPIPS(reduction='none')(x, y)
    print(f"LPIPS: {lpips_loss.item():0.4f}")

    # To compute MDSI as a measure, use the lower-case function from the library.
    mdsi_index: torch.Tensor = piq.mdsi(x, y, data_range=1., reduction='none')
    # In order to use MDSI as a loss function, use the corresponding PyTorch module.
    mdsi_loss: torch.Tensor = piq.MDSILoss(data_range=1., reduction='none')(x, y)
    print(f"MDSI index: {mdsi_index.item():0.4f}, loss: {mdsi_loss.item():0.4f}")

    # To compute the MS-SSIM index as a measure, use the lower-case function from the library:
    ms_ssim_index: torch.Tensor = piq.multi_scale_ssim(x, y, data_range=1.)
    # In order to use MS-SSIM as a loss function, use the corresponding PyTorch module:
    ms_ssim_loss = piq.MultiScaleSSIMLoss(data_range=1., reduction='none')(x, y)
    print(f"MS-SSIM index: {ms_ssim_index.item():0.4f}, loss: {ms_ssim_loss.item():0.4f}")

    # To compute Multi-Scale GMSD as a measure, use the lower-case function from the library.
    # It can be used both as a measure and as a loss function; in any case it should be minimized.
    # By default, the scale weights are initialized with values from the paper. You can change
    # them by passing a list of 4 values to the scale_weights argument during initialization.
    # Note that input tensors should contain images with height and width of at least
    # 2 ** number_of_scales + 1.
    ms_gmsd_index: torch.Tensor = piq.multi_scale_gmsd(x, y, data_range=1.,
                                                       chromatic=True, reduction='none')
    # In order to use Multi-Scale GMSD as a loss function, use the corresponding PyTorch module.
    ms_gmsd_loss: torch.Tensor = piq.MultiScaleGMSDLoss(chromatic=True, data_range=1.,
                                                        reduction='none')(x, y)
    print(f"MS-GMSDc index: {ms_gmsd_index.item():0.4f}, loss: {ms_gmsd_loss.item():0.4f}")

    # To compute PSNR as a measure, use the lower-case function from the library.
    psnr_index = piq.psnr(x, y, data_range=1., reduction='none')
    print(f"PSNR index: {psnr_index.item():0.4f}")

    # To compute PieAPP as a loss function, use the corresponding PyTorch module:
    pieapp_loss: torch.Tensor = piq.PieAPP(reduction='none', stride=32)(x, y)
    print(f"PieAPP loss: {pieapp_loss.item():0.4f}")

    # To compute the SSIM index as a measure, use the lower-case function from the library:
    ssim_index = piq.ssim(x, y, data_range=1.)
    # In order to use SSIM as a loss function, use the corresponding PyTorch module:
    ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(x, y)
    print(f"SSIM index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")

    # To compute a Style score as a loss function, use the corresponding PyTorch module:
    # By default the VGG16 model is used, but any feature extractor model is supported.
    # Don't forget to adjust the layer names accordingly. Features from different layers
    # can be weighted differently via the weights parameter. See other options in the class docstring.
    style_loss = piq.StyleLoss(feature_extractor="vgg16", layers=("relu3_3",))(x, y)
    print(f"Style: {style_loss.item():0.4f}")

    # To compute TV as a measure, use the lower-case function from the library:
    tv_index: torch.Tensor = piq.total_variation(x)
    # In order to use TV as a loss function, use the corresponding PyTorch module:
    tv_loss: torch.Tensor = piq.TVLoss(reduction='none')(x)
    print(f"TV index: {tv_index.item():0.4f}, loss: {tv_loss.item():0.4f}")

    # To compute VIF as a measure, use the lower-case function from the library:
    vif_index: torch.Tensor = piq.vif_p(x, y, data_range=1.)
    # In order to use VIF as a loss function, use the corresponding PyTorch class:
    vif_loss: torch.Tensor = piq.VIFLoss(sigma_n_sq=2.0, data_range=1.)(x, y)
    print(f"VIFp index: {vif_index.item():0.4f}, loss: {vif_loss.item():0.4f}")

    # To compute a VSI score as a measure, use the lower-case function from the library:
    vsi_index: torch.Tensor = piq.vsi(x, y, data_range=1.)
    # In order to use VSI as a loss function, use the corresponding PyTorch module:
    vsi_loss: torch.Tensor = piq.VSILoss(data_range=1.)(x, y)
    print(f"VSI index: {vsi_index.item():0.4f}, loss: {vsi_loss.item():0.4f}")
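# The script header is not shown in this excerpt. Judging from the calls
# above, main() relies on torch, piq, and skimage.io.imread, which would
# normally be imported at the top of the file; the entry point below is an
# inferred sketch, not copied from the source.
import torch
import piq
from skimage.io import imread

if __name__ == "__main__":
    main()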
    'kornia.PSNR': kornia.PSNRLoss(max_val=1.),
    'piqa.PSNR': piqa.PSNR(),
}),
'SSIM': (2, {
    'sk.ssim': lambda x, y: sk.structural_similarity(
        x, y, win_size=11, multichannel=True, gaussian_weights=True,
    ),
    'piq.ssim': piq.ssim,
    'kornia.SSIM-halfloss': kornia.SSIM(window_size=11, reduction='mean'),
    'piq.SSIM-loss': piq.SSIMLoss(),
    'IQA.SSIM-loss': IQA.SSIM(),
    'vainf.SSIM': vainf.SSIM(data_range=1.),
    'piqa.SSIM': piqa.SSIM(),
}),
'MS-SSIM': (2, {
    'piq.ms_ssim': piq.multi_scale_ssim,
    'piq.MS_SSIM-loss': piq.MultiScaleSSIMLoss(),
    'IQA.MS_SSIM-loss': IQA.MS_SSIM(),
    'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1.),
    'piqa.MS_SSIM': piqa.MS_SSIM(),
}),
'LPIPS': (2, {
    'piq.LPIPS': piq.LPIPS(),
    'IQA.LPIPS': IQA.LPIPSvgg(),
    'piqa.LPIPS': piqa.LPIPS(network='vgg'),
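# For context: a table like the one above is usually consumed by a benchmark
# loop that feeds the same image pair to every implementation and compares the
# outputs. A minimal sketch, assuming the leading `2` in each tuple is the
# number of input images (full-reference metrics take two, no-reference
# metrics take one); `METRICS`, `x`, and `y` are stand-ins, not names taken
# from the original script.
import torch

x = torch.rand(1, 3, 256, 256)  # hypothetical prediction, NCHW in [0, 1]
y = torch.rand(1, 3, 256, 256)  # hypothetical reference

for group, (n_inputs, impls) in METRICS.items():
    print(f"== {group} ==")
    for name, fn in impls.items():
        if name.startswith('sk.'):
            continue  # scikit-image expects HWC numpy arrays, not tensors
        with torch.no_grad():
            value = fn(x, y) if n_inputs == 2 else fn(x)
        print(f"{name}: {float(value):.4f}")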