Exemplo n.º 1
0
def train(config):
    """Train ``net.dehaze_net`` with an MSE loss on (clean, hazy) image pairs.

    After every epoch each validation batch is dehazed and the concatenation
    of the hazy, dehazed and clean batches is written to
    ``config.sample_output_folder``; a snapshot is saved per epoch and a final
    ``dehazer.pth`` is written when training finishes.

    Args:
        config: options object. Reads ``orig_images_path``,
            ``hazy_images_path``, ``train_batch_size``, ``val_batch_size``,
            ``num_workers``, ``lr``, ``weight_decay``, ``num_epochs``,
            ``grad_clip_norm``, ``display_iter``, ``snapshot_iter``,
            ``snapshots_folder`` and ``sample_output_folder``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dehaze_net = net.dehaze_net().to(device)
    dehaze_net.apply(weights_init)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               config.hazy_images_path)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                             config.hazy_images_path, mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)

    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr,
                                 weight_decay=config.weight_decay)

    dehaze_net.train()
    # Keeps the per-epoch snapshot name defined even if train_loader is empty
    # (previously a NameError at the end of the first epoch).
    iteration = -1

    for epoch in range(config.num_epochs):
        for iteration, (img_orig, img_haze) in enumerate(train_loader):
            img_orig = img_orig.to(device)
            img_haze = img_haze.to(device)

            clean_image = dehaze_net(img_haze)
            loss = criterion(clean_image, img_orig)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                           config.grad_clip_norm)
            optimizer.step()

            if ((iteration + 1) % config.display_iter) == 0:
                print("epoch", epoch, " Loss at iteration", iteration + 1, ":",
                      loss.item())
            if ((iteration + 1) % config.snapshot_iter) == 0:
                torch.save(dehaze_net.state_dict(),
                           config.snapshots_folder + "Epoch" + str(epoch) +
                           'iteration' + str(iteration + 1) + '.pth')

        # Validation stage: inference only. no_grad avoids building an
        # autograd graph per batch, and eval() switches any train-only layers
        # off; train() restores training mode for the next epoch.
        dehaze_net.eval()
        with torch.no_grad():
            for iter_val, (img_orig, img_haze) in enumerate(val_loader):
                img_orig = img_orig.to(device)
                img_haze = img_haze.to(device)

                clean_image = dehaze_net(img_haze)

                torchvision.utils.save_image(
                    torch.cat((img_haze, clean_image, img_orig), 0),
                    config.sample_output_folder + str(iter_val + 1) + ".jpg")
        dehaze_net.train()

        torch.save(dehaze_net.state_dict(),
                   config.snapshots_folder + "Epoch" + str(epoch) +
                   'iteration' + str(iteration + 1) + '.pth')
    torch.save(dehaze_net.state_dict(), config.snapshots_folder + "dehazer.pth")
Exemplo n.º 2
0
def dehaze_image(image_path):
    """Dehaze one image file and save a [hazy, dehazed] comparison image.

    Loads weights from ``snapshots/dehazer.pth`` and writes the result to
    ``results/<basename of image_path>``. Runs on CUDA when available,
    otherwise on CPU.

    Args:
        image_path: path of the input image; any PIL-readable mode is
            accepted and converted to 3-channel RGB.
    """
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # convert('RGB') guarantees 3 channels (grayscale/palette/RGBA inputs
    # would otherwise break the HWC -> CHW permute or the first conv);
    # the context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        data_hazy = np.asarray(img.convert('RGB')) / 255.0

    data_hazy = torch.from_numpy(data_hazy).float().permute(2, 0, 1)
    data_hazy = data_hazy.to(device).unsqueeze(0)

    dehaze_net = net.dehaze_net().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    dehaze_net.load_state_dict(
        torch.load('snapshots/dehazer.pth', map_location=device))
    dehaze_net.eval()

    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)
    torchvision.utils.save_image(
        torch.cat((data_hazy, clean_image), 0),
        os.path.join("results", os.path.basename(image_path)))
Exemplo n.º 3
0
def dehaze_image(image_path):
    """Dehaze one image file and save only the dehazed result.

    Loads weights from ``result/dehazer.pth`` and writes the output to
    ``my_results/<basename of image_path>``.

    The network is wrapped in ``nn.DataParallel`` before loading, presumably
    because the checkpoint was saved from a DataParallel model (parameter
    keys prefixed with ``module.``) — confirm against the training script.

    Args:
        image_path: path of the input image; converted to 3-channel RGB.
    """
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with Image.open(image_path) as img:
        data_hazy = np.asarray(img.convert('RGB')) / 255.0

    data_hazy = torch.from_numpy(data_hazy).float().permute(2, 0, 1)
    data_hazy = data_hazy.to(device).unsqueeze(0)

    dehaze_net = net.dehaze_net().to(device)
    # Default device_ids uses every visible GPU (or none on CPU); the previous
    # hard-coded [0, 1, 2, 3] failed on machines with fewer than four GPUs.
    dehaze_net = nn.DataParallel(dehaze_net)
    dehaze_net.load_state_dict(
        torch.load('result/dehazer.pth', map_location=device))
    dehaze_net.eval()

    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)
    torchvision.utils.save_image(
        clean_image, os.path.join("my_results", os.path.basename(image_path)))
Exemplo n.º 4
0
def dehaze_image(data_hazy):
    """Dehaze an in-memory image array and return the result as uint8.

    Weights are read from ``snapshots/dehazer.pth`` next to this source file.

    Args:
        data_hazy: numpy image array, presumably (H, W, 3) with values in
            [0, 255] (TODO confirm channel order, e.g. RGB vs. OpenCV BGR,
            against callers).

    Returns:
        np.ndarray of dtype uint8 in the same (H, W, 3) layout, clipped to
        [0, 255].
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        data_hazy = torch.from_numpy(data_hazy / 255).float()
        # HWC -> CHW, the layout PyTorch convolutions expect. The previous
        # permute(2, 1, 0) fed the network a spatially transposed (C, W, H)
        # image, so its learned kernels were applied with x/y swapped.
        data_hazy = data_hazy.permute(2, 0, 1).to(device).unsqueeze(0)

        dehaze_net = net.dehaze_net().to(device)
        father_path = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(father_path, 'snapshots/dehazer.pth')
        print(model_path)
        dehaze_net.load_state_dict(torch.load(model_path, map_location=device))
        dehaze_net.eval()

        clean_image = dehaze_net(data_hazy).cpu().numpy()
        clean_image = np.squeeze(clean_image, axis=0)
        # CHW -> HWC, the inverse of the permute above.
        clean_image = clean_image.transpose((1, 2, 0)) * 255
        clean_image = np.clip(clean_image, 0, 255)
    return clean_image.astype(np.uint8)
Exemplo n.º 5
0
def clrImg(data_hazy):
    """Dehaze an image array using the deep-dehaze weights (GIMP plug-in).

    Weights are read from ``baseLoc + 'weights/deepdehaze/dehazer.pth'``;
    inference runs on CUDA when available, otherwise on CPU. The GIMP
    progress bar is nudged before the forward pass.

    Args:
        data_hazy: numpy image array, presumably (H, W, 3) with values in
            [0, 255] — TODO confirm against the plug-in caller.

    Returns:
        np.ndarray of dtype uint8, (H, W, 3), clipped to [0, 255].
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_hazy = torch.from_numpy(data_hazy / 255.0).float()
    data_hazy = data_hazy.permute(2, 0, 1).to(device)

    dehaze_net = net.dehaze_net().to(device)
    # map_location replaces the duplicated CUDA/CPU load branches.
    dehaze_net.load_state_dict(
        torch.load(baseLoc + 'weights/deepdehaze/dehazer.pth',
                   map_location=device))
    dehaze_net.eval()

    gimp.progress_update(float(0.005))
    gimp.displays_flush()

    with torch.no_grad():
        clean_image = dehaze_net(data_hazy.unsqueeze(0))
    # .cpu() is required: .numpy() raises on a CUDA tensor, so the GPU path
    # previously crashed here.
    out = clean_image.cpu().numpy()[0] * 255
    return np.clip(np.transpose(out, (1, 2, 0)), 0, 255).astype(np.uint8)
Exemplo n.º 6
0
import torch
import torchvision
import net
import numpy as np
from PIL import Image
import glob

if __name__ == '__main__':
    # Dehaze every file in test_images/ and save [hazy, dehazed] pairs to
    # results/ under the same basename.
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_list = glob.glob("test_images/*")

    dehaze_net = net.dehaze_net().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    dehaze_net.load_state_dict(
        torch.load('snapshots/dehazer.pth', map_location=device))
    dehaze_net.eval()

    os.makedirs("results", exist_ok=True)

    with torch.no_grad():
        for image in test_list:
            # convert('RGB') guarantees 3 channels; the context manager
            # closes the file handle.
            with Image.open(image) as img:
                data_hazy = np.asarray(img.convert('RGB')) / 255.0
            data_hazy = torch.from_numpy(data_hazy).float().permute(2, 0, 1)
            data_hazy = data_hazy.to(device).unsqueeze(0)

            clean_image = dehaze_net(data_hazy)
            # os.path keeps the output path valid on both POSIX and Windows;
            # the previous hard-coded "\\" split broke on POSIX.
            torchvision.utils.save_image(
                torch.cat((data_hazy, clean_image), 0),
                os.path.join("results", os.path.basename(image)))
            print(image, "done!")
Exemplo n.º 7
0
def _stitch_blocks(blocks, num_width, full_bk_num):
    """Reassemble a row-major list of block tensors into one image tensor.

    Blocks of one row are joined along dim 3 and the rows are stacked along
    dim 2 (presumably NCHW width/height — confirm against the dataloader).
    """
    rows = [torch.cat(blocks[start:start + num_width], 3)
            for start in range(0, full_bk_num, num_width)]
    return torch.cat(rows, 2)


def _save_stitched(blocks, num_width, full_bk_num, path):
    """Stitch ``blocks``, save the result to ``path`` and return it."""
    image = _stitch_blocks(blocks, num_width, full_bk_num)
    print(path)
    torchvision.utils.save_image(image, path)
    return image


def _train_epoch(dehaze_net, train_loader, criterion, optimizer, device,
                 epoch, save_counter, config):
    """Run one training pass, taking one optimizer step per image block.

    Each loader item carries a sequence of blocks (``img_orig[index]``), and
    every block is fed to the network as its own step. ``save_counter``
    counts steps across epochs and drives snapshot saving.

    Returns:
        The updated ``save_counter``.
    """
    for iteration, (img_orig, img_haze, rgb, bl_num_width, bl_num_height,
                    data_path) in enumerate(train_loader):
        if save_counter == 0:
            # One-time dump of the first batch's layout for debugging.
            print("img_orig.size:")
            print(len(img_orig))
            print("bl_num_width.type:")
            print(bl_num_width.type)
            print("shape:")
            print(bl_num_width.shape)

        num_width = int(bl_num_width[0].item())
        num_height = int(bl_num_height[0].item())
        full_bk_num = num_width * num_height
        # Print about config.display_block_iter times per image. The interval
        # must be an integer >= 1: the former float division (e.g. 3.33) made
        # "(index + 1) % interval == 0" practically never true.
        display_interval = max(1, full_bk_num // config.display_block_iter)

        for index in range(len(img_orig)):
            unit_img_orig = img_orig[index]
            unit_img_haze = img_haze[index]
            if save_counter == 0:
                print("unit_img_orig type:")
                print(unit_img_orig.type())
                print("size:")
                print(unit_img_orig.size())
                print("shape:")
                print(unit_img_orig.shape)

            unit_img_orig = unit_img_orig.to(device)
            unit_img_haze = unit_img_haze.to(device)

            clean_image = dehaze_net(unit_img_haze)
            loss = criterion(clean_image, unit_img_orig)

            # Diagnostic dump whenever the input or output holds NaN/Inf
            # (the old check only looked for NaN in the input, Inf in output).
            if (not torch.isfinite(unit_img_haze).all()
                    or not torch.isfinite(clean_image).all()):
                print("unit_img_haze:")
                print(unit_img_haze.shape)
                print(unit_img_haze)
                print("clean_image:")
                print(clean_image.shape)
                print(clean_image)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                           config.grad_clip_norm)
            optimizer.step()

            if ((index + 1) % display_interval) == 0:
                print("Loss at Epoch:" + str(epoch) + "_index:" +
                      str(index + 1) + "/" + str(len(img_orig)) +
                      "_iter:" + str(iteration + 1) + "_Loss value:" +
                      str(loss.item()))

            if ((save_counter + 1) % config.snapshot_iter) == 0:
                # NOTE(review): ':' is not a legal filename character on
                # Windows; kept to preserve existing snapshot names.
                save_name = ("Epoch:" + str(epoch) + "_TrainTimes:" +
                             str(save_counter + 1) + ".pth")
                torch.save(dehaze_net.state_dict(),
                           config.snapshots_folder + save_name)

            save_counter = save_counter + 1
    return save_counter


def _validate_epoch(dehaze_net, val_loader, device, epoch, config):
    """Dehaze the validation set, save stitched images and print PSNR/SSIM.

    Per the original comments, ``img_orig`` holds YUV444 blocks and
    ``img_haze`` YUV420 blocks; ``rgb`` holds the RGB reference blocks.
    Runs under ``torch.no_grad()`` in eval mode and restores train mode
    afterwards.
    """
    folder = config.sample_output_folder
    dehaze_net.eval()
    try:
        with torch.no_grad():
            for iter_val, (img_orig, img_haze, rgb, bl_num_width,
                           bl_num_height, data_path) in enumerate(val_loader):
                clean_blocks = []     # network output from YUV420 input
                haze_blocks = []      # raw YUV420 input
                orig_blocks = []      # YUV444 reference
                rgb_blocks = []       # RGB reference
                rgb_from_clean = []   # yuv2rgb(network output)
                rgb_from_haze = []    # yuv2rgb(raw YUV420 input)

                for index in range(len(img_orig)):
                    unit_img_orig = img_orig[index].to(device)
                    unit_img_haze = img_haze[index].to(device)
                    unit_img_rgb = rgb[index].to(device)

                    clean_image = dehaze_net(unit_img_haze)

                    clean_blocks.append(clean_image)
                    haze_blocks.append(unit_img_haze)
                    orig_blocks.append(unit_img_orig)
                    rgb_blocks.append(unit_img_rgb)
                    rgb_from_clean.append(yuv2rgb(clean_image))
                    rgb_from_haze.append(yuv2rgb(unit_img_haze))

                print(data_path)
                temp_data_path = data_path[0]
                print('temp_data_path:')
                print(temp_data_path)
                orimage_name = temp_data_path.split("/")[-1]
                print(orimage_name)
                orimage_name = orimage_name.split(".")[0]
                print(orimage_name)

                num_width = int(bl_num_width[0].item())
                num_height = int(bl_num_height[0].item())
                full_bk_num = num_width * num_height

                prefix = (folder + "Epoch:" + str(epoch) + "_Index:" +
                          str(iter_val + 1) + "_" + orimage_name)
                _save_stitched(clean_blocks, num_width, full_bk_num,
                               prefix + "_yuv420_deep.bmp")
                _save_stitched(haze_blocks, num_width, full_bk_num,
                               prefix + "_yuv420_ori.bmp")
                _save_stitched(orig_blocks, num_width, full_bk_num,
                               prefix + "_yuv444.bmp")
                rgb_ref = _save_stitched(rgb_blocks, num_width, full_bk_num,
                                         prefix + "_rgb.bmp")
                rgb_clear = _save_stitched(rgb_from_clean, num_width,
                                           full_bk_num,
                                           prefix + "_rgb_from_clean_420.bmp")
                # NOTE(review): the double underscore differs from the other
                # suffixes; kept so existing output names do not change.
                rgb_haze = _save_stitched(rgb_from_haze, num_width,
                                          full_bk_num,
                                          prefix + "__rgb_from_haze_420.bmp")

                # Quality of the raw YUV420 round-trip vs. the network output,
                # both measured against the RGB reference.
                pairs = (("haze", rgb_haze), ("clear", rgb_clear))
                for tag, pred in pairs:
                    psnr_index = piq.psnr(pred, rgb_ref, data_range=1.,
                                          reduction='none')
                    print(f"PSNR {tag}: {psnr_index.item():0.4f}")
                for tag, pred in pairs:
                    ssim_index = piq.ssim(pred, rgb_ref, data_range=1.)
                    ssim_loss = piq.SSIMLoss(data_range=1.)(pred, rgb_ref)
                    print(
                        f"SSIM {tag} index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}"
                    )
    finally:
        dehaze_net.train()


def train(config):
    """Train the block-based dehazing network on YUV420 -> YUV444 blocks.

    Every epoch trains over the same data (unless ``config.do_valid`` is
    non-zero, in which case training is skipped), runs validation, and
    overwrites ``config.snapshots_folder + "dehazer.pth"``.

    Args:
        config: options object. Reads ``use_gpu``, ``block_width``,
            ``block_height``, ``resize``, ``bTest``, ``snap_train_data``,
            ``snapshots_train_folder``, ``orig_images_path``,
            ``train_batch_size``, ``val_batch_size``, ``num_workers``,
            ``lr``, ``weight_decay``, ``num_epochs``, ``do_valid``,
            ``display_block_iter``, ``grad_clip_norm``, ``snapshot_iter``,
            ``snapshots_folder`` and ``sample_output_folder``.
    """
    use_gpu = config.use_gpu
    bk_width = config.block_width
    bk_height = config.block_height
    resize = config.resize
    bTest = config.bTest
    device = torch.device('cuda' if use_gpu else 'cpu')

    dehaze_net = net.dehaze_net().to(device)
    if config.snap_train_data:
        # map_location lets a GPU-saved snapshot resume on a CPU-only run.
        dehaze_net.load_state_dict(
            torch.load(config.snapshots_train_folder + config.snap_train_data,
                       map_location=device))
    else:
        dehaze_net.apply(weights_init)
    print(dehaze_net)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               'train', resize, bk_width,
                                               bk_height, bTest)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, "val",
                                             resize, bk_width, bk_height,
                                             bTest)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)

    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(dehaze_net.parameters(),
                                 lr=config.lr,
                                 weight_decay=config.weight_decay)
    dehaze_net.train()

    # Global step counter across epochs; drives snapshot saving.
    save_counter = 0
    for epoch in range(config.num_epochs):
        if config.do_valid == 0:
            save_counter = _train_epoch(dehaze_net, train_loader, criterion,
                                        optimizer, device, epoch,
                                        save_counter, config)

        _validate_epoch(dehaze_net, val_loader, device, epoch, config)

        torch.save(dehaze_net.state_dict(),
                   config.snapshots_folder + "dehazer.pth")