def train(config):
    """Train dehaze_net with MSE loss, saving periodic and per-epoch snapshots.

    config supplies dataset paths, hyper-parameters (lr, weight_decay,
    grad_clip_norm, batch sizes) and logging/snapshot intervals.

    FIX over the original: the validation loop now runs under
    torch.no_grad() — previously every validation forward pass built an
    autograd graph whose buffers were never used, wasting GPU memory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    dehaze_net = net.dehaze_net().to(device)
    dehaze_net.apply(weights_init)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               config.hazy_images_path)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                             config.hazy_images_path,
                                             mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.val_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True)

    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(dehaze_net.parameters(),
                                 lr=config.lr,
                                 weight_decay=config.weight_decay)

    dehaze_net.train()

    for epoch in range(config.num_epochs):
        # Keeps the post-validation checkpoint name well-defined even if
        # train_loader yields no batches.
        iteration = -1
        for iteration, (img_orig, img_haze) in enumerate(train_loader):
            img_orig = img_orig.to(device)
            img_haze = img_haze.to(device)

            clean_image = dehaze_net(img_haze)
            loss = criterion(clean_image, img_orig)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                           config.grad_clip_norm)
            optimizer.step()

            if ((iteration + 1) % config.display_iter) == 0:
                print("epoch", epoch, " Loss at iteration", iteration + 1,
                      ":", loss.item())
            if ((iteration + 1) % config.snapshot_iter) == 0:
                torch.save(dehaze_net.state_dict(),
                           config.snapshots_folder + "Epoch" + str(epoch) +
                           'iteration' + str(iteration + 1) + '.pth')

        # Validation Stage — inference only, so no gradients are needed.
        with torch.no_grad():
            for iter_val, (img_orig, img_haze) in enumerate(val_loader):
                img_orig = img_orig.to(device)
                img_haze = img_haze.to(device)
                clean_image = dehaze_net(img_haze)
                # hazy / dehazed / ground-truth stacked into one sample file
                torchvision.utils.save_image(
                    torch.cat((img_haze, clean_image, img_orig), 0),
                    config.sample_output_folder + str(iter_val + 1) + ".jpg")

        # Per-epoch checkpoint (named after the last training iteration).
        torch.save(dehaze_net.state_dict(),
                   config.snapshots_folder + "Epoch" + str(epoch) +
                   'iteration' + str(iteration + 1) + '.pth')

    torch.save(dehaze_net.state_dict(),
               config.snapshots_folder + "dehazer.pth")
def dehaze_image(image_path):
    """Dehaze one image file; saves hazy+dehazed pair to results/<basename>.

    FIXES over the original:
      * works on CPU-only hosts — the original called .cuda()
        unconditionally and torch.load() without map_location, both of
        which fail without a GPU;
      * inference runs under torch.no_grad(), so no autograd graph is kept;
      * input is forced to RGB so grayscale/RGBA files don't break the
        3-channel permute.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_hazy = Image.open(image_path).convert('RGB')
    data_hazy = np.asarray(data_hazy) / 255.0
    data_hazy = torch.from_numpy(data_hazy).float()
    data_hazy = data_hazy.permute(2, 0, 1)          # HWC -> CHW
    data_hazy = data_hazy.to(device).unsqueeze(0)   # add batch dimension

    dehaze_net = net.dehaze_net().to(device)
    dehaze_net.load_state_dict(
        torch.load('snapshots/dehazer.pth', map_location=device))

    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)

    torchvision.utils.save_image(torch.cat((data_hazy, clean_image), 0),
                                 "results/" + image_path.split("/")[-1])
def dehaze_image(image_path):
    """Dehaze one image via a DataParallel net; saves to my_results/<basename>.

    FIXES over the original:
      * device_ids was hard-coded to [0, 1, 2, 3] and crashed on machines
        with fewer than four GPUs — it now uses every GPU actually present;
      * inference runs under torch.no_grad().
    """
    data_hazy = Image.open(image_path)
    data_hazy = np.asarray(data_hazy) / 255.0
    data_hazy = torch.from_numpy(data_hazy).float()
    data_hazy = data_hazy.permute(2, 0, 1)     # HWC -> CHW
    data_hazy = data_hazy.cuda().unsqueeze(0)  # add batch dimension

    dehaze_net = net.dehaze_net().cuda()
    # NOTE(review): wrapping happens *before* load_state_dict, so the
    # checkpoint in result/ presumably carries 'module.'-prefixed keys
    # (i.e. was saved from a DataParallel model) — verify against the
    # training script that produced it.
    dehaze_net = nn.DataParallel(
        dehaze_net, device_ids=list(range(torch.cuda.device_count())))
    dehaze_net.load_state_dict(torch.load('result/dehazer.pth'))

    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)

    # torchvision.utils.save_image(torch.cat((data_hazy, clean_image),0), "results/" + image_path.split("/")[-1])
    torchvision.utils.save_image(clean_image,
                                 "my_results/" + image_path.split("/")[-1])
def dehaze_image(data_hazy):
    """Dehaze a raw uint8 H×W×C numpy image; returns the uint8 result.

    FIX over the original: works on CPU-only hosts — .cuda() was called
    unconditionally and torch.load() lacked map_location, both of which
    fail without CUDA.
    """
    with torch.no_grad():
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        data_hazy = data_hazy / 255
        data_hazy = torch.from_numpy(data_hazy).float()
        # NOTE(review): permute(2, 1, 0) yields C×W×H (width/height
        # swapped, not the usual CHW); the transpose((2, 1, 0)) below
        # undoes it, so end-to-end it is self-consistent — confirm the
        # network tolerates the swapped axes.
        data_hazy = data_hazy.permute(2, 1, 0)
        data_hazy = data_hazy.to(device).unsqueeze(0)

        dehaze_net = net.dehaze_net().to(device)
        # Resolve the checkpoint relative to this source file, not the CWD.
        current_path = os.path.abspath(__file__)
        father_path = os.path.dirname(current_path)
        model_path = os.path.join(father_path, 'snapshots/dehazer.pth')
        print(model_path)
        dehaze_net.load_state_dict(torch.load(model_path,
                                              map_location=device))

        clean_image = dehaze_net(data_hazy)
        clean_image = clean_image.cpu().numpy()
        print(type(clean_image))
        clean_image = np.squeeze(clean_image, axis=0)   # drop batch dim
        clean_image = clean_image.transpose((2, 1, 0)) * 255
        clean_image = np.clip(clean_image, 0, 255)
        return clean_image.astype(np.uint8)
def clrImg(data_hazy):
    """Dehaze a 0-255 float H×W×C numpy image for the GIMP plugin; returns uint8.

    FIX over the original: on the CUDA path the output tensor lives on the
    GPU, so clean_image.detach().numpy() raised
    "TypeError: can't convert cuda ... tensor to numpy". The tensor is now
    moved to the CPU first (.cpu() is a no-op when it is already there).
    """
    data_hazy = (data_hazy / 255.0)
    data_hazy = torch.from_numpy(data_hazy).float()
    data_hazy = data_hazy.permute(2, 0, 1)   # HWC -> CHW

    dehaze_net = net.dehaze_net()
    if torch.cuda.is_available():
        dehaze_net = dehaze_net.cuda()
        dehaze_net.load_state_dict(
            torch.load(baseLoc + 'weights/deepdehaze/dehazer.pth'))
        data_hazy = data_hazy.cuda()
    else:
        # map_location lets a GPU-trained checkpoint load on CPU.
        dehaze_net.load_state_dict(
            torch.load(baseLoc + 'weights/deepdehaze/dehazer.pth',
                       map_location=torch.device("cpu")))

    gimp.progress_update(float(0.005))
    gimp.displays_flush()

    data_hazy = data_hazy.unsqueeze(0)       # add batch dimension
    clean_image = dehaze_net(data_hazy)
    out = clean_image.detach().cpu().numpy()[0, :, :, :] * 255
    out = np.clip(np.transpose(out, (1, 2, 0)), 0, 255).astype(np.uint8)
    return out
import torch
import torchvision
import net
import numpy as np
from PIL import Image
import glob
import os

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    test_list = glob.glob("test_images/*")

    dehaze_net = net.dehaze_net().to(device)
    # map_location keeps a GPU-trained checkpoint loadable on CPU hosts.
    if torch.cuda.is_available():
        dehaze_net.load_state_dict(torch.load('snapshots/dehazer.pth'))
    else:
        dehaze_net.load_state_dict(
            torch.load('snapshots/dehazer.pth',
                       map_location=lambda storage, loc: storage))

    # Inference only — no gradients needed.
    with torch.no_grad():
        for image in test_list:
            data_hazy = Image.open(image)
            data_hazy = (np.asarray(data_hazy) / 255.0)
            data_hazy = torch.from_numpy(data_hazy).float()
            data_hazy = data_hazy.permute(2, 0, 1)          # HWC -> CHW
            data_hazy = data_hazy.to(device).unsqueeze(0)   # add batch dim
            clean_image = dehaze_net(data_hazy)
            # FIX: the original split paths on "\\" and wrote to
            # "results\\...", which only worked on Windows (glob returns
            # "/"-separated paths on POSIX); os.path handles both.
            torchvision.utils.save_image(
                torch.cat((data_hazy, clean_image), 0),
                os.path.join("results", os.path.basename(image)))
            print(image, "done!")
def train(config):
    """Train dehaze_net on per-image blocks, then stitch validation blocks
    back into whole images and report PSNR/SSIM (via piq).

    The loaders yield, per item: lists of YUV444 (img_orig) / YUV420
    (img_haze) / RGB block tensors plus the per-image block-grid size
    (bl_num_width x bl_num_height) and the source path.

    NOTE(review): the original text was collapsed onto single lines; the
    indentation here (notably the placement of the validation loop and the
    final save inside the epoch loop) was reconstructed — verify against
    project history.
    """
    use_gpu = config.use_gpu
    bk_width = config.block_width     # block width in pixels
    bk_height = config.block_height   # block height in pixels
    resize = config.resize
    bTest = config.bTest              # debug flag forwarded to the loader

    if use_gpu:
        dehaze_net = net.dehaze_net().cuda()
    else:
        dehaze_net = net.dehaze_net()

    # Resume from a snapshot when configured, otherwise fresh weight init.
    if config.snap_train_data:
        dehaze_net.load_state_dict(
            torch.load(config.snapshots_train_folder + config.snap_train_data))
    else:
        dehaze_net.apply(weights_init)
    print(dehaze_net)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               'train', resize, bk_width,
                                               bk_height, bTest)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, "val",
                                             resize, bk_width, bk_height,
                                             bTest)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.val_batch_size,
                                             shuffle=True,
                                             num_workers=config.num_workers,
                                             pin_memory=True)

    if use_gpu:
        criterion = nn.MSELoss().cuda()
    else:
        criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(dehaze_net.parameters(),
                                 lr=config.lr,
                                 weight_decay=config.weight_decay)

    dehaze_net.train()

    # Run the same training data for num_epochs passes.
    save_counter = 0
    for epoch in range(config.num_epochs):
        # Each loader item carries several images' blocks at once:
        # img_orig / img_haze are lists of block tensors, trained in one go
        # — like tiling the images side by side, though actually stacked
        # along different dimensions.
        if config.do_valid == 0:
            for iteration, (img_orig, img_haze, rgb, bl_num_width,
                            bl_num_height,
                            data_path) in enumerate(train_loader):
                # One-time type/shape dump on the very first batch.
                if save_counter == 0:
                    print("img_orig.size:")
                    print(len(img_orig))
                    print("bl_num_width.type:")
                    print(bl_num_width.type)
                    print("shape:")
                    print(bl_num_width.shape)
                # train stage
                num_width = int(bl_num_width[0].item())    # blocks per row
                num_height = int(bl_num_height[0].item())  # blocks per col
                full_bk_num = num_width * num_height       # blocks per image
                display_block_iter = full_bk_num / config.display_block_iter
                for index in range(len(img_orig)):
                    unit_img_orig = img_orig[index]
                    unit_img_haze = img_haze[index]
                    if save_counter == 0:
                        print("unit_img_orig type:")
                        print(unit_img_orig.type())
                        print("size:")
                        print(unit_img_orig.size())
                        print("shape:")
                        print(unit_img_orig.shape)
                    '''
                    if bTest == 1:
                        if save_counter ==0:
                            numpy_ori = unit_img_orig.numpy().copy()
                            print("data path:")
                            print(data_path)
                            print("index:"+str(index))
                            for i in range(3):
                                for j in range(32):
                                    print("before:")
                                    print(numpy_ori[index][i][j])
                                    print("after:")
                                    print(numpy_ori[index][i][j]*255)
                    '''
                    if use_gpu:
                        unit_img_orig = unit_img_orig.cuda()
                        unit_img_haze = unit_img_haze.cuda()

                    clean_image = dehaze_net(unit_img_haze)
                    loss = criterion(clean_image, unit_img_orig)
                    # Debug dump when the input or prediction blew up.
                    if torch.isnan(unit_img_haze).any() or torch.isinf(
                            clean_image).any():
                        print("unit_img_haze:")
                        print(unit_img_haze.shape)
                        print(unit_img_haze)
                        print("clean_image:")
                        print(clean_image.shape)
                        print(clean_image)

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                                   config.grad_clip_norm)
                    optimizer.step()

                    # show loss every config.display_block_iter
                    if ((index + 1) % display_block_iter) == 0:
                        print("Loss at Epoch:" + str(epoch) + "_index:" +
                              str(index + 1) + "/" + str(len(img_orig)) +
                              "_iter:" + str(iteration + 1) +
                              "_Loss value:" + str(loss.item()))
                    # save snapshot every save_counter times
                    if ((save_counter + 1) % config.snapshot_iter) == 0:
                        save_name = "Epoch:" + str(
                            epoch) + "_TrainTimes:" + str(save_counter +
                                                          1) + ".pth"
                        torch.save(dehaze_net.state_dict(),
                                   config.snapshots_folder + save_name)
                        # torch.save(dehaze_net.state_dict(),
                        # config.snapshots_folder , "Epoch:", str(epoch), "
                        # _TrainTimes:", str(save_counter+1), ".pth")
                    save_counter = save_counter + 1

        # Validation Stage
        # img_orig -> yuv444
        # img_haze -> yuv420
        for iter_val, (img_orig, img_haze, rgb, bl_num_width, bl_num_height,
                       data_path) in enumerate(val_loader):
            sub_image_list = []  # after deep_learning image (yuv420)
            sub_image_list_no_deep = []  # yuv420
            ori_sub_image_list = []  # yuv444 image
            rgb_image_list = []  # block ori image (rgb)
            rgb_list_from_sub = []  # rgb from clean image (yuv420)
            rgb_list_from_ori = []  # rgb from haze image (yuv420)
            for index in range(len(img_orig)):
                unit_img_orig = img_orig[index]
                unit_img_haze = img_haze[index]
                unit_img_rgb = rgb[index]
                # TODO: yuv444 ??? color is strange ...
                '''
                if bTest == 1 and index == 0:
                    numpy_ori = unit_img_orig.numpy().copy()
                    print("data path:")
                    print(data_path)
                    print("index:" + str(index))
                    for i in range(3):
                        for j in range(32):
                            print(numpy_ori[index][i][j])
                    bTest = 0
                '''
                if use_gpu:
                    unit_img_orig = unit_img_orig.cuda()
                    unit_img_haze = unit_img_haze.cuda()
                    unit_img_rgb = unit_img_rgb.cuda()
                clean_image = dehaze_net(unit_img_haze)
                sub_image_list.append(clean_image)
                sub_image_list_no_deep.append(unit_img_haze)
                ori_sub_image_list.append(unit_img_orig)
                rgb_image_list.append(unit_img_rgb)
                rgb_list_from_sub.append(yuv2rgb(clean_image))
                rgb_list_from_ori.append(yuv2rgb(unit_img_haze))

            print(data_path)
            temp_data_path = data_path[0]
            print('temp_data_path:')
            print(temp_data_path)
            # Base name (no directory, no extension) for the output files.
            orimage_name = temp_data_path.split("/")[-1]
            print(orimage_name)
            orimage_name = orimage_name.split(".")[0]
            print(orimage_name)
            num_width = int(bl_num_width[0].item())
            num_height = int(bl_num_height[0].item())
            full_bk_num = num_width * num_height

            # Stitching pattern used below: concatenate the first row's
            # blocks along width (dim 3), then append each further row
            # along height (dim 2).

            # YUV420 & after deep learning
            # ------------------------------------------------------------------#
            image_all = torch.cat((sub_image_list[:num_width]), 3)
            for i in range(num_width, full_bk_num, num_width):
                image_row = torch.cat(sub_image_list[i:i + num_width], 3)
                image_all = torch.cat([image_all, image_row], 2)
            image_name = config.sample_output_folder + str(
                iter_val + 1) + "_yuv420_deep_learning.bmp"
            print(image_name)
            torchvision.utils.save_image(
                image_all, config.sample_output_folder + "Epoch:" +
                str(epoch) + "_Index:" + str(iter_val + 1) + "_" +
                orimage_name + "_yuv420_deep.bmp")
            # ------------------------------------------------------------------#

            # YUV420 & without deep learning
            # ------------------------------------------------------------------#
            image_all_ori_no_deep = torch.cat(
                (sub_image_list_no_deep[:num_width]), 3)
            for i in range(num_width, full_bk_num, num_width):
                image_row = torch.cat(sub_image_list_no_deep[i:i + num_width],
                                      3)
                image_all_ori_no_deep = torch.cat(
                    [image_all_ori_no_deep, image_row], 2)
            image_name = config.sample_output_folder + str(
                iter_val + 1) + "_yuv420_ori.bmp"
            print(image_name)
            torchvision.utils.save_image(
                image_all_ori_no_deep, config.sample_output_folder +
                "Epoch:" + str(epoch) + "_Index:" + str(iter_val + 1) + "_" +
                orimage_name + "_yuv420_ori.bmp")
            # ------------------------------------------------------------------#

            # YUV444
            # ------------------------------------------------------------------#
            image_all_ori = torch.cat(ori_sub_image_list[:num_width], 3)
            for i in range(num_width, full_bk_num, num_width):
                image_row = torch.cat(ori_sub_image_list[i:i + num_width], 3)
                image_all_ori = torch.cat([image_all_ori, image_row], 2)
            image_name = config.sample_output_folder + str(iter_val +
                                                           1) + "_yuv444.bmp"
            print(image_name)
            # torchvision.utils.save_image(image_all_ori, image_name)
            torchvision.utils.save_image(
                image_all_ori, config.sample_output_folder + "Epoch:" +
                str(epoch) + "_Index:" + str(iter_val + 1) + "_" +
                orimage_name + "_yuv444.bmp")
            # ------------------------------------------------------------------#

            # block rgb (test)
            # ------------------------------------------------------------------#
            rgb_image_all = torch.cat(rgb_image_list[:num_width], 3)
            for i in range(num_width, full_bk_num, num_width):
                image_row = torch.cat(rgb_image_list[i:i + num_width], 3)
                '''
                image_row = torch.cat((ori_sub_image_list[i],ori_sub_image_list[i+1]), 1)
                for j in range(i+2, num_width):
                    image_row = torch.cat((image_row, ori_sub_image_list[j]), 1)
                '''
                rgb_image_all = torch.cat([rgb_image_all, image_row], 2)
            image_name = config.sample_output_folder + str(iter_val +
                                                           1) + "_rgb.bmp"
            print(image_name)
            torchvision.utils.save_image(
                rgb_image_all, config.sample_output_folder + "Epoch:" +
                str(epoch) + "_Index:" + str(iter_val + 1) + "_" +
                orimage_name + "_rgb.bmp")
            # ------------------------------------------------------------------#

            # ------------------------------------------------------------------#
            rgb_from_420_image_all_clear = torch.cat(
                rgb_list_from_sub[:num_width], 3)
            for i in range(num_width, full_bk_num, num_width):
                image_row = torch.cat(rgb_list_from_sub[i:i + num_width], 3)
                rgb_from_420_image_all_clear = torch.cat(
                    [rgb_from_420_image_all_clear, image_row], 2)
            image_name = config.sample_output_folder + str(
                iter_val + 1) + "_rgb_from_clean_420.bmp"
            print(image_name)
            torchvision.utils.save_image(
                rgb_from_420_image_all_clear, config.sample_output_folder +
                "Epoch:" + str(epoch) + "_Index:" + str(iter_val + 1) + "_" +
                orimage_name + "_rgb_from_clean_420.bmp")
            # ------------------------------------------------------------------#

            # ------------------------------------------------------------------#
            rgb_from_420_image_all_haze = torch.cat(
                rgb_list_from_ori[:num_width], 3)
            for i in range(num_width, full_bk_num, num_width):
                image_row = torch.cat(rgb_list_from_ori[i:i + num_width], 3)
                rgb_from_420_image_all_haze = torch.cat(
                    [rgb_from_420_image_all_haze, image_row], 2)
            image_name = config.sample_output_folder + str(
                iter_val + 1) + "_rgb_from_haze_420.bmp"
            print(image_name)
            torchvision.utils.save_image(
                rgb_from_420_image_all_haze, config.sample_output_folder +
                "Epoch:" + str(epoch) + "_Index:" + str(iter_val + 1) + "_" +
                orimage_name + "__rgb_from_haze_420.bmp")
            # ------------------------------------------------------------------#

            # To compute PSNR as a measure, use lower case function from the library.
            # ------------------------------------------------------------------#
            # rgb_from_420_image_all_haze  rgb_image_all
            # rgb_from_420_image_all_clear rgb_image_all
            psnr_index = piq.psnr(rgb_from_420_image_all_haze,
                                  rgb_image_all,
                                  data_range=1.,
                                  reduction='none')
            print(f"PSNR haze: {psnr_index.item():0.4f}")
            psnr_index = piq.psnr(rgb_from_420_image_all_clear,
                                  rgb_image_all,
                                  data_range=1.,
                                  reduction='none')
            print(f"PSNR clear: {psnr_index.item():0.4f}")
            # ------------------------------------------------------------------#

            # To compute SSIM as a measure, use lower case function from the library.
            # ------------------------------------------------------------------#
            ssim_index = piq.ssim(rgb_from_420_image_all_haze,
                                  rgb_image_all,
                                  data_range=1.)
            ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(
                rgb_from_420_image_all_haze, rgb_image_all)
            print(
                f"SSIM haze index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}"
            )
            ssim_index = piq.ssim(rgb_from_420_image_all_clear,
                                  rgb_image_all,
                                  data_range=1.)
            ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(
                rgb_from_420_image_all_clear, rgb_image_all)
            print(
                f"SSIM clear index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}"
            )
            # ------------------------------------------------------------------#

        # Rolling "latest" checkpoint, refreshed every epoch.
        torch.save(dehaze_net.state_dict(),
                   config.snapshots_folder + "dehazer.pth")