Exemplo n.º 1
0
def main():
    """Deblur a test image with the plano PSF grid and display the result.

    Loads and normalises the PSF grid, restores the 'iter800' UABC checkpoint
    together with its positional HQS (alpha/beta) parameters, and runs the
    whole image through ``model.forward_patchwise`` in one call. The input
    path is currently fixed, so every loop iteration re-processes the same
    image. Press 'q' in a preview window to stop.

    Raises:
        FileNotFoundError: if the input image cannot be read by OpenCV.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/Heide_PSF_plano_small.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            # Normalise each local kernel by its sum over its first two axes.
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    model_code = 'iter800'
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/plano/uabcnet_{}.pth'.format(
            model_code))
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    print(gx, gy)
    # The whole PSF grid is processed at once.
    num_patch = [gx, gy]
    # Positional alpha-beta parameters for HQS, one set per PSF grid cell.
    # They do not depend on the input image, so they are loaded only once.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/plano/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    img_path = '/home/xiu/databag/deblur/ICCV2021/MPI_data/drain/blurry.jpg'
    for _ in range(1, 237):
        img_L = cv2.imread(img_path)
        if img_L is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(img_path)
        img_L = img_L.astype(np.float32)
        # Fixed border crop, then downscale by 2 along both axes.
        img_L = img_L[38:-39, 74:-74]
        img_L = cv2.resize(img_L, dsize=None, fx=0.5, fy=0.5)

        # NOTE: W is the row count and H the column count (numpy order).
        W, H = img_L.shape[:2]

        t0 = time.time()

        # Start at grid cell (0, 0); with num_patch == [gx, gy] this covers
        # the full grid.
        px_start = 0
        py_start = 0

        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H //
                        gy, :]

        p_W, p_H = patch_L.shape[:2]
        expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        block_expand = expand
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        # Move the trailing block_expand columns/rows to the front so the
        # original content starts at offset block_expand on both axes.
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.uint2single(patch_L_wrap)
        x = util.single2tensor4(x)

        # Kernels and HQS parameters share one ordering: first grid index
        # varies fastest.
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)

        [x, k] = [el.to(device) for el in [x, k]]

        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        x_E = model.forward_patchwise(x, k, cd, num_patch, [W // gx, H // gy])
        # Remove the padding added before inference.
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)

        t1 = time.time()
        print(t1 - t0)

        xk = patch_E.astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', patch_L.astype(np.uint8))
        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
Exemplo n.º 2
0
def main():
    """Deblur 100 video frames region-by-region with the ZEMAX model.

    Each frame is processed as non-overlapping 256x256 regions, each covering
    2x2 PSF grid cells of 128 px, stepping over a 6x8 cell layout. Each region
    is padded with ``util_deblur.wrap_boundary_liu`` and deblurred with
    ``model.forward_patchwise``. The result is accumulated into per-stage
    frame buffers, and the last denoised buffer is displayed.

    Raises:
        FileNotFoundError: if a frame cannot be read by OpenCV.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")

    loaded_state = torch.load('./usrnet_ZEMAX.pth')
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    patch_size = 2 * 128               # side of one processed region (px)
    num_patch = 2                      # PSF grid cells per region per axis
    p_size = patch_size // num_patch   # pixels per PSF grid cell
    expand = PSF_grid.shape[2]
    block_expand = expand

    # Positional HQS parameters for the 6x8 grid. They are identical for
    # every frame, so they are loaded once rather than once per frame.
    ab_numpy = np.loadtxt('ab_ZEMAX.txt').astype(np.float32).reshape(
        6, 8, 16, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    for img_id in range(100):
        img_path = '/home/xiu/workspace/UABC/ICCV2021/video/{:08d}.bmp'.format(
            img_id)
        img_L = cv2.imread(img_path)
        if img_L is None:
            raise FileNotFoundError(img_path)
        img_L = img_L.astype(np.float32)

        # Per-stage output accumulators (previously used without ever being
        # created) and the number of regions contributing to each pixel.
        img_E_deconv = [np.zeros_like(img_L) for _ in range(stage)]
        img_E_denoise = [np.zeros_like(img_L) for _ in range(stage)]
        weight_E = np.zeros_like(img_L)

        t0 = time.time()
        for px_start in range(0, 6 - 2 + 1, 2):
            for py_start in range(0, 8 - 2 + 1, 2):
                rows = slice(px_start * p_size,
                             (px_start + num_patch) * p_size)
                cols = slice(py_start * p_size,
                             py_start * p_size + num_patch * p_size)

                PSF_patch = PSF_grid[px_start:px_start + num_patch,
                                     py_start:py_start + num_patch]
                patch_L = img_L[rows, cols, :]

                if block_expand > 0:
                    patch_L_wrap = util_deblur.wrap_boundary_liu(
                        patch_L, (patch_size + block_expand * 2,
                                  patch_size + block_expand * 2))
                    # Move the trailing padding to the front so the region
                    # starts at offset block_expand on both axes.
                    patch_L_wrap = np.hstack(
                        (patch_L_wrap[:, -block_expand:, :],
                         patch_L_wrap[:, :patch_size + block_expand, :]))
                    patch_L_wrap = np.vstack(
                        (patch_L_wrap[-block_expand:, :, :],
                         patch_L_wrap[:patch_size + block_expand, :, :]))
                else:
                    patch_L_wrap = patch_L
                x = util.single2tensor4(util.uint2single(patch_L_wrap))

                # Kernels ordered with the first grid index varying fastest.
                k_all = []
                for h_ in range(num_patch):
                    for w_ in range(num_patch):
                        k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
                k = torch.cat(k_all, dim=0)

                [x, k] = [el.to(device) for el in [x, k]]

                # HQS parameters must follow the same cell order as k_all.
                # A row-major .view() would list them with the first index
                # varying slowest, pairing off-diagonal cells with the wrong
                # kernels.
                ab_patch = F.softplus(ab[px_start:px_start + num_patch,
                                         py_start:py_start + num_patch])
                cd = []
                for h_ in range(num_patch):
                    for w_ in range(num_patch):
                        cd.append(ab_patch[w_:w_ + 1, h_])
                cd = torch.cat(cd, dim=0)

                x_E = model.forward_patchwise(
                    x, k, cd, [num_patch, num_patch],
                    [patch_size // num_patch, patch_size // num_patch])

                # x_E is iterated as a sequence of intermediate outputs here
                # (presumably per-stage deconv/denoise results) -- confirm.
                patch_E_all = [util.tensor2uint(pp) for pp in x_E]
                deconv = patch_E_all[-2][expand:-expand, expand:-expand]
                denoise = patch_E_all[-1][expand:-expand, expand:-expand]

                # NOTE(review): every per-stage buffer receives the same final
                # outputs; only img_E_denoise[-1] is displayed below.
                for i in range(stage):
                    img_E_deconv[i][rows, cols, :] += deconv
                    img_E_denoise[i][rows, cols, :] += denoise
                weight_E[rows, cols, :] += 1.0

        t1 = time.time()

        print(t1 - t0)
        # Pixels outside every region keep weight 0; clamp to avoid 0/0.
        weight = np.maximum(weight_E, 1.0)
        img_E_deconv = [pp / weight for pp in img_E_deconv]
        img_E_denoise = [pp / weight for pp in img_E_denoise]

        xk = img_E_denoise[-1].astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', img_L.astype(np.uint8))
        cv2.waitKey(-1)
Exemplo n.º 3
0
def main():
    """Deblur every ZEMAX test capture and show input and result.

    Loads and normalises the ZEMAX PSF grid, restores the 'iter17000' UABC
    checkpoint and its positional HQS parameters, then deblurs each image
    matching the capture glob in sorted order. Press 'q' to stop early.
    Images OpenCV cannot read are reported and skipped.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            # Normalise each local kernel by its sum over its first two axes.
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    model_code = 'iter17000'
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/ZEMAX/uabcnet_{}.pth'.format(
            model_code))
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # The whole PSF grid is processed at once, so the tile layout is the grid
    # shape. This is the same (gx, gy) used to reshape ab below.
    num_patch = [gx, gy]
    # Positional alpha-beta parameters for HQS; image independent, load once.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/ZEMAX/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    img_names = glob.glob(
        '/home/xiu/databag/deblur/ICCV2021/suo_image/*/AC254-075-A-ML-Zemax(ZMX).bmp'
    )
    img_names.sort()
    for img_id, img_name in enumerate(img_names):
        img_L = cv2.imread(img_name)
        if img_L is None:
            # cv2.imread returns None for unreadable files.
            print('skip unreadable image: {}'.format(img_name))
            continue
        img_L = img_L.astype(np.float32)
        # NOTE: W is the row count and H the column count (numpy order).
        W, H = img_L.shape[:2]

        t0 = time.time()

        px_start = 0
        py_start = 0

        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H //
                        gy, :]

        p_W, p_H = patch_L.shape[:2]
        expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        block_expand = expand
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        # Move the trailing padding to the front so the original content
        # starts at offset block_expand on both axes.
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.uint2single(patch_L_wrap)
        x = util.single2tensor4(x)

        # Kernels and HQS parameters share one ordering: first grid index
        # varies fastest.
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)

        [x, k] = [el.to(device) for el in [x, k]]

        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        x_E = model.forward_patchwise(x, k, cd, num_patch, [W // gx, H // gy])
        # Remove the padding added before inference.
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)

        t1 = time.time()

        # 1-based counter so the final image reads [N/N].
        print('[{}/{}]: {} s per frame'.format(img_id + 1, len(img_names),
                                               t1 - t0))

        xk = patch_E.astype(np.uint8)

        cv2.imshow('res', xk)
        cv2.imshow('input', patch_L.astype(np.uint8))

        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
Exemplo n.º 4
0
def main():
    """Train the UABC network and positional HQS parameters on the ZEMAX grid.

    Each iteration works as follows:

    - Pick a random DIV2K image and a ``patch_num`` window of the PSF grid.
      The window is biased towards the grid borders in 4 of 5 draws.
    - Synthesise a blurry crop with ``util_deblur.blockConv2d``.
    - Take one Adam step on an L1 loss. Prediction and ground truth are both
      divided by a per-kernel weight map.

    Weights and the learned ``ab`` values are written to ./ZEMAX_model every
    100 iterations and once more when training stops ('q' in the preview).
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]

    # Normalise every local kernel and stack all of them into k_tensor with
    # the first grid index varying fastest: cell (xx, yy) is entry xx + yy*gx.
    k_tensor = []
    for yy in range(gy):
        for xx in range(gx):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))
            k_tensor.append(util.single2tensor4(PSF_grid[xx, yy]))

    k_tensor = torch.cat(k_tensor, dim=0)
    # Indexed below with the same per-kernel index as k_tensor (presumably one
    # entry per kernel and channel -- TODO confirm in get_inv_spatial_weight).
    inv_weight = util_deblur.get_inv_spatial_weight(k_tensor)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.proj.load_state_dict(torch.load('./data/usrnet_pretrain.pth'),
                               strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # ----------------------------------------
    # load training data
    # ----------------------------------------
    imgs = glob.glob('./DIV2K_train/*.png', recursive=True)
    imgs.sort()

    # ----------------------------------------
    # positional lambda\mu for HQS
    # ----------------------------------------
    stage = 8
    ab_buffer = np.ones((gx, gy, 2 * stage, 3), dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # ----------------------------------------
    # build optimizer
    # ----------------------------------------
    params = []
    params += [{"params": [ab], "lr": 0.0005}]
    for key, value in model.named_parameters():
        params += [{"params": [value], "lr": 0.0001}]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    patch_size = [128, 128]
    expand = PSF_grid.shape[2] // 2
    patch_num = [2, 2]
    # Sharp crop = all tiles plus a kernel-radius margin on each side; the
    # network reconstructs the tiles only.
    out_w = patch_size[0] * patch_num[0]
    out_h = patch_size[1] * patch_num[1]
    crop_w = out_w + expand * 2
    crop_h = out_h + expand * 2

    global_iter = 0

    while True:
        img_idx = np.random.randint(len(imgs))
        img = imgs[img_idx]
        img_H = cv2.imread(img)
        if img_H is None:
            # Unreadable file: draw another image instead of crashing.
            continue
        w, h = img_H.shape[:2]
        if w < crop_w or h < crop_h:
            # Too small for a training crop (randint would raise).
            continue

        # Focus on the edges: modes 0-3 pin the PSF window to one border of
        # the grid; mode 4 keeps the random position.
        mode = np.random.randint(5)
        px_start = np.random.randint(0, gx - patch_num[0] + 1)
        py_start = np.random.randint(0, gy - patch_num[1] + 1)
        if mode == 0:
            px_start = 0
        if mode == 1:
            px_start = gx - patch_num[0]
        if mode == 2:
            py_start = 0
        if mode == 3:
            py_start = gy - patch_num[1]

        x_start = np.random.randint(0, w - crop_w + 1)
        y_start = np.random.randint(0, h - crop_h + 1)
        PSF_patch = PSF_grid[px_start:px_start + patch_num[0],
                             py_start:py_start + patch_num[1]]

        patch_H = img_H[x_start:x_start + crop_w, y_start:y_start + crop_h]
        patch_L = util_deblur.blockConv2d(patch_H, PSF_patch, expand)

        block_expand = max(patch_size[0] // 8, expand)

        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (out_w + block_expand * 2, out_h + block_expand * 2))
        # Move the trailing padding to the front so the content starts at
        # offset block_expand on both axes.
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :out_h + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:out_w + block_expand, :, :]))
        x = util.single2tensor4(util.uint2single(patch_L_wrap))

        x_gt = util.single2tensor4(
            util.uint2single(patch_H[expand:-expand, expand:-expand]))
        inv_weight_patch = torch.ones_like(x_gt)

        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                # The tile was blurred with PSF_grid[px_start + w_,
                # py_start + h_], so the matching kernel/weight live at that
                # cell's GLOBAL index in k_tensor / inv_weight. The local
                # index w_ + h_ * patch_num[0] points at a different cell.
                gidx = (px_start + w_) + (py_start + h_) * gx
                rows = slice(w_ * patch_size[0], (w_ + 1) * patch_size[0])
                cols = slice(h_ * patch_size[1], (h_ + 1) * patch_size[1])
                for c in range(3):
                    inv_weight_patch[0, c, rows, cols] = inv_weight[gidx, c]
                k_local.append(k_tensor[gidx:gidx + 1])

        k = torch.cat(k_local, dim=0)
        [x, x_gt, k, inv_weight_patch
         ] = [el.to(device) for el in [x, x_gt, k, inv_weight_patch]]
        ab_patch = F.softplus(ab[px_start:px_start + patch_num[0],
                                 py_start:py_start + patch_num[1]])
        cd = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)
        x_E = model.forward_patchwise(x, k, cd, patch_num, patch_size)

        predict = x_E[..., block_expand:block_expand + out_w,
                      block_expand:block_expand + out_h]
        loss = F.l1_loss(predict.div(inv_weight_patch),
                         x_gt.div(inv_weight_patch))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print('iter:{},loss {}'.format(global_iter + 1, loss.item()))

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)[block_expand:-block_expand,
                                        block_expand:-block_expand]

        show = np.hstack((patch_H[expand:-expand, expand:-expand],
                          patch_L[block_expand:-block_expand,
                                  block_expand:-block_expand], patch_E))

        cv2.imshow('HL', show)
        key = cv2.waitKey(1)

        global_iter += 1

        # change the save period here
        if global_iter % 100 == 0:
            ab_numpy = ab.detach().cpu().numpy().flatten()
            torch.save(
                model.state_dict(),
                './ZEMAX_model/usrnet_ZEMAX_iter{}.pth'.format(global_iter))
            np.savetxt('./ZEMAX_model/ab_ZEMAX_iter{}.txt'.format(global_iter),
                       ab_numpy)
        if key == ord('q'):
            break
    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './ZEMAX_model/usrnet_ZEMAX.pth')
    np.savetxt('./ZEMAX_model/ab_ZEMAX.txt', ab_numpy)
Exemplo n.º 5
0
def main():
    """Train the UABC super-resolution network with randomly drawn PSF grids.

    Each iteration draws a random DIV2K image and a random PSF grid
    (``draw_random_kernel``). It builds an LR/HR pair with
    ``draw_training_pair`` and takes one Adam step on an L1 loss. The step
    updates both the network weights and the positional HQS parameters
    ``ab``. Press 's' in the preview window to checkpoint, 'q' to stop. A
    checkpoint (weights to ./logs/uabcnet.pth, ``ab`` to ./logs/ab.txt) is
    always written when training ends.
    """
    import os

    #0. global config
    sf = 4  # scale factor
    stage = 8  # number of unrolled HQS iterations
    patch_size = [32, 32]
    patch_num = [3, 3]

    #1. local PSF
    #shape: gx,gy,kw,kw,3
    all_PSFs = load_kernels('./data')

    #2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # n_iter must equal stage: ab below holds 2 * stage entries per cell.
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # Positional lambda, mu for HQS, set as free trainable parameters here.
    ab_buffer = np.ones(
        (patch_num[0], patch_num[1], 2 * stage, 3), dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    params = []
    params += [{"params": [ab], "lr": 0.0005}]
    for key, value in model.named_parameters():
        params += [{"params": [value], "lr": 0.0001}]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    #3. load training data
    imgs_H = glob.glob('/home/xiu/databag/deblur/images/DIV2K_train/*.png',
                       recursive=True)
    imgs_H.sort()

    # torch.save / np.savetxt fail if the target directory is missing.
    os.makedirs('./logs', exist_ok=True)

    def save_checkpoint():
        """Write network weights and the learned ab parameters to ./logs."""
        torch.save(model.state_dict(), './logs/uabcnet.pth')
        # ab is trained jointly with the network; without it the
        # checkpoint cannot reproduce the trained model.
        np.savetxt('./logs/ab.txt', ab.detach().cpu().numpy().flatten())

    global_iter = 0
    N_maxiter = 200000

    for i in range(N_maxiter):

        t0 = time.time()
        # draw random image
        img_idx = np.random.randint(len(imgs_H))
        img_H = cv2.imread(imgs_H[img_idx])
        if img_H is None:
            # Unreadable file: skip this draw instead of crashing.
            continue

        # draw random kernel grid, then a random LR/HR patch pair from it
        PSF_grid = draw_random_kernel(all_PSFs, patch_num)

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.uint2single(patch_L)
        x = util.single2tensor4(x)
        x_gt = util.uint2single(patch_H)
        x_gt = util.single2tensor4(x_gt)

        # Kernels and HQS parameters share one ordering: first grid index
        # varies fastest.
        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                k_local.append(util.single2tensor4(patch_psf[w_, h_]))
        k = torch.cat(k_local, dim=0)
        [x, x_gt, k] = [el.to(device) for el in [x, x_gt, k]]

        ab_patch = F.softplus(ab)
        ab_patch_v = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                ab_patch_v.append(ab_patch[w_:w_ + 1, h_])
        ab_patch_v = torch.cat(ab_patch_v, dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        loss = F.l1_loss(x_E, x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        t_iter = time.time() - t0 - t_data

        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(global_iter + 1, loss.item(), t_data, t_iter))

        # Upsample the LR input (nearest neighbour) so it can be shown next
        # to the HR target and the estimate.
        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint((x_E))
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        key = cv2.waitKey(1)
        global_iter += 1

        if key == ord('q'):
            break
        if key == ord('s'):
            save_checkpoint()

    save_checkpoint()
Exemplo n.º 6
0
def main():
    """Fine-tune a pretrained UABC SR network with an added LSGAN loss.

    A single PSF grid is drawn once and used for every training pair. Each
    iteration runs two steps:

    - Generator step: network weights plus positional HQS parameters ``ab``,
      trained on L1 + adversarial loss.
    - Discriminator step: ``gan.PatchDiscriminator``, trained on real HR
      patches versus pooled generated patches.

    Press 'q' in the preview window to stop. The fine-tuned weights and
    ``ab`` are saved under ./data at the end.
    """
    #0. global config
    sf = 4  # scale factor
    stage = 8  # number of unrolled HQS iterations
    patch_size = [32, 32]
    patch_num = [2, 2]

    #1. local PSF
    all_PSFs = load_kernels('./data')

    #2. load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # n_iter must equal stage: ab below holds 2 * stage entries per cell.
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.load_state_dict(torch.load('./data/uabcnet_final.pth'), strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    #3. set up discriminator
    model_D = gan.PatchDiscriminator(5)
    model_D = model_D.to(device)

    gan_loss = gan.GANLoss(mode='lsgan')
    gan_loss = gan_loss.to(device)
    fake_images = ImagePool(16)

    #4. positional lambda, mu for HQS (pre-softplus values): even entries of
    # the stage axis start at 0.01, odd entries at 0.1.
    ab_buffer = np.zeros((patch_num[0], patch_num[1], 2 * stage, 3))
    ab_buffer[:, :, ::2, :] = 0.01
    ab_buffer[:, :, 1::2, :] = 0.1
    ab = torch.tensor(ab_buffer,
                      dtype=torch.float32,
                      device=device,
                      requires_grad=True)
    params = []
    params += [{"params": [ab], "lr": 5e-4}]
    for key, value in model.named_parameters():
        params += [{"params": [value], "lr": 1e-5}]

    params_D = list(model_D.parameters())

    optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.999))
    optimizer_D = torch.optim.Adam(params_D, lr=1e-4, betas=(0.9, 0.999))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    #5. load training data
    imgs_H = glob.glob('/home/xiu/databag/deblur/images/DIV2K_train/*.png',
                       recursive=True)
    imgs_H.sort()

    global_iter = 0
    N_maxiter = 200000

    PSF_grid = draw_random_kernel(all_PSFs)

    for i in range(N_maxiter):

        t0 = time.time()
        img_idx = np.random.randint(len(imgs_H))
        img_H = cv2.imread(imgs_H[img_idx])
        if img_H is None:
            # Unreadable file: skip this draw instead of crashing.
            continue

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.uint2single(patch_L)
        x = util.single2tensor4(x)
        x_gt = util.uint2single(patch_H)
        x_gt = util.single2tensor4(x_gt)

        # Kernels and HQS parameters share one ordering: first grid index
        # varies fastest.
        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                k_local.append(util.single2tensor4(patch_psf[w_, h_]))
        k = torch.cat(k_local, dim=0)
        [x, x_gt, k] = [el.to(device) for el in [x, x_gt, k]]

        ab_patch = F.softplus(ab)
        ab_patch_v = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                ab_patch_v.append(ab_patch[w_:w_ + 1, h_])
        ab_patch_v = torch.cat(ab_patch_v, dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        # Generator step. The discriminator is frozen so backward() skips its
        # parameter gradients. Those gradients would be discarded by
        # optimizer_D.zero_grad() below anyway. Gradients still flow through
        # the discriminator to x_E.
        for p in params_D:
            p.requires_grad_(False)
        loss_l1 = F.l1_loss(x_E, x_gt)
        loss_gan = gan_loss(model_D(x_E), True)
        loss = loss_l1 + loss_gan
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for p in params_D:
            p.requires_grad_(True)

        # Discriminator step: real HR patches vs. pooled generated patches.
        # Detach before pooling so the pool never holds the autograd graph.
        pred_real = model_D(x_gt)
        loss_D_real = gan_loss(pred_real, True)
        fake = fake_images.query(x_E.detach())
        pred_fake = model_D(fake.detach())
        loss_D_fake = gan_loss(pred_fake, False)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        scheduler.step()

        t_iter = time.time() - t0 - t_data

        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(global_iter + 1, loss.item(), t_data, t_iter))

        # Upsample the LR input (nearest neighbour) for side-by-side display.
        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint((x_E))
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        key = cv2.waitKey(1)
        global_iter += 1

        if key == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
Exemplo n.º 7
0
def main():
    """Fine-tune the pretrained UABC net jointly with per-patch HQS weights.

    0. Global config: scale factor, number of unrolled stages, patch layout.
    1. Load the bank of local PSFs from ``./data``.
    2. Load the network from ``./logs/uabcnet_final.pth``. Initialise one
       (lambda, mu) table per patch from ``./logs/ab_pretrain.txt``; the
       table is trained together with the network.
    3. For ``N_maxiter`` iterations:
       - draw a random DIV2K image;
       - build a training pair with ``draw_training_pair``;
       - take one Adam step on the L1 loss between prediction and ground
         truth.

    The fine-tuned weights are written to ``./data/uabcnet_finetune.pth``.
    The raw (pre-softplus) ``ab`` table goes to ``./data/ab_finetune.txt``.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5  # number of unrolled iterations; also sizes the ab table
    patch_size = [32, 32]
    patch_num = [2, 2]
    out_dir = os.path.join('./result', 'test')
    # cv2.imwrite returns False (no exception) when the folder is missing.
    os.makedirs(out_dir, exist_ok=True)

    # 1. local PSF
    # shape: gx,gy,kw,kw,3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # n_iter must agree with `stage`, which determines the ab table shape.
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on GPU be loaded on a CPU box.
    model.load_state_dict(torch.load('./logs/uabcnet_final.pth',
                                     map_location=device),
                          strict=True)
    model.train()
    for v in model.parameters():
        v.requires_grad = True
    model = model.to(device)

    # Positional lambda, mu for HQS, set as free trainable parameters here.
    # Every patch starts from the same pretrained (1, 1, 2*stage, 3) table.
    ab_pretrain = np.loadtxt('./logs/ab_pretrain.txt').reshape(
        (1, 1, 2 * stage, 3)).astype(np.float32)
    ab_buffer = np.tile(ab_pretrain, (patch_num[0], patch_num[1], 1, 1))
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # ab gets its own learning rate; the network weights are only nudged.
    params = [{"params": [ab], "lr": 0.0001}]
    params += [{"params": [value], "lr": 1e-6}
               for value in model.parameters()]

    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    # 3. load training data
    imgs_H = sorted(glob.glob('./DIV2K_train/*.png', recursive=True))
    if not imgs_H:
        raise FileNotFoundError('no training images found in ./DIV2K_train')

    N_maxiter = 4000
    PSF_grid = draw_random_kernel(all_PSFs, patch_num)

    for i in range(N_maxiter):
        t0 = time.time()
        # Draw a random image.
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # Per-patch kernels, stacked with h_ outer / w_ inner; ab_patch_v
        # below uses the same order so kernels and weights line up.
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])],
                      dim=0)
        x, x_gt, k = (el.to(device) for el in (x, x_gt, k))

        # softplus keeps the HQS weights positive while ab is unconstrained.
        ab_patch = F.softplus(ab)
        ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                                for h_ in range(patch_num[1])
                                for w_ in range(patch_num[0])],
                               dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        loss = F.l1_loss(x_E, x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        t_iter = time.time() - t0 - t_data
        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(i + 1, loss.item(), t_data, t_iter))

        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        cv2.imshow('H,L,E', np.hstack((patch_H, patch_L, patch_E)))
        cv2.waitKey(1)  # lets the HighGUI window refresh

        if i % 1000 == 0:
            cv2.imwrite(
                os.path.join(out_dir, 'resultE-{:04d}.png'.format(i + 1)),
                patch_E)
            cv2.imwrite(
                os.path.join(out_dir, 'resultL-{:04d}.png'.format(i + 1)),
                patch_L)
            cv2.imwrite(
                os.path.join(out_dir, 'resultH-{:04d}.png'.format(i + 1)),
                patch_H)

    # ab is saved raw (before softplus); consumers must apply softplus once.
    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
Exemplo n.º 8
0
def main():
    """Evaluate the fine-tuned UABC net on whole DIV2K images, tile by tile.

    Loads ``./data/uabcnet_finetune.pth`` and the raw per-patch HQS table
    ``./data/ab_finetune.txt``. For each of ``N_maxiter`` random images it:

    - pads the image by 12 px;
    - splits it into 280x280 tiles with ``imgpatch``;
    - restores every tile;
    - merges the tiles back and records PSNR/SSIM against the padded input.

    Images and ``psnr.txt`` are written to ``./result/<out_folder>/``.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5  # number of unrolled iterations; also sizes the ab table
    patch_size = [32, 32]
    patch_num = [2, 2]
    out_folder = 'finetune'
    out_dir = os.path.join('./result', out_folder)
    # cv2.imwrite returns False (no exception) when the folder is missing.
    os.makedirs(out_dir, exist_ok=True)

    # 1. local PSF
    # shape: gx,gy,kw,kw,3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # n_iter must agree with `stage`, which determines the ab table shape.
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on GPU be loaded on a CPU box.
    loaded_state_dict = torch.load('./data/uabcnet_finetune.pth',
                                   map_location=device)
    model.load_state_dict(loaded_state_dict, strict=True)
    model.eval()
    for v in model.parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional lambda, mu for HQS, ab[2x2, 2*stage, 3].
    # The fine-tuning script saves ab *before* softplus, so softplus is
    # applied exactly once here. The original code applied it twice: once
    # at load time and again per tile.
    ab_buffer = np.loadtxt('./data/ab_finetune.txt').reshape(
        (patch_num[0], patch_num[1], 2 * stage, 3)).astype(np.float32)
    ab = torch.tensor(ab_buffer, device=device, requires_grad=False)
    ab_patch = F.softplus(ab)
    # Identical for every tile: flatten once, h_ outer / w_ inner, matching
    # the kernel stacking order below.
    ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                            for h_ in range(patch_num[1])
                            for w_ in range(patch_num[0])],
                           dim=0)

    # 3. load test data
    imgs_H = sorted(glob.glob('./DIV2K_train/*.png', recursive=True))
    if not imgs_H:
        raise FileNotFoundError('no images found in ./DIV2K_train')

    N_maxiter = 1000
    PSF_grid = using_AC254_lens(all_PSFs, patch_num)
    all_PSNR = []

    for i in range(N_maxiter):
        # Draw a random image.
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])
        img_H = np.pad(img_H, [(12, 12), (12, 12), (0, 0)])
        croppatch = imgpatch(img_H, 280, 280, 24)
        patches = croppatch.crop(img_H, 1)

        patch_E_list = []
        for piece, patch in enumerate(patches):
            patch_L, _, patch_psf = draw_testing_pair(
                patch, PSF_grid, sf, patch_num, patch_size)

            x = util.single2tensor4(util.uint2single(patch_L)).to(device)
            k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                           for h_ in range(patch_num[1])
                           for w_ in range(patch_num[0])],
                          dim=0).to(device)

            with torch.no_grad():  # pure inference: no autograd graph
                x_E = model.forward_patchwise_SR(
                    x, k, ab_patch_v, patch_num,
                    [patch_size[0], patch_size[1]], sf)

            patch_E_list.append(util.tensor2uint(x_E)[np.newaxis, :])
            print(piece)

        img_E = croppatch.merge(patch_E_list)

        psnr, ssim = cal_psnrssim(img_E, img_H, 255)
        print(psnr)
        print(ssim)

        cv2.imwrite(
            os.path.join(out_dir, 'resultE-{:04d}.png'.format(i + 1)), img_E)
        cv2.imwrite(
            os.path.join(out_dir, 'resultH-{:04d}.png'.format(i + 1)), img_H)

        all_PSNR.append(psnr)

    np.savetxt(os.path.join(out_dir, 'psnr.txt'), all_PSNR)
Exemplo n.º 9
0
def main():
    """Evaluate the pretrained UABC net on random synthetic patch pairs.

    Loads ``./logs/uabcnet_final.pth``. The pretrained HQS table
    ``./logs/ab_pretrain.txt`` is broadcast to every patch. For each of
    ``N_maxiter`` random DIV2K images it:

    - builds a degraded / clean pair with ``draw_training_pair``;
    - restores the degraded patch;
    - records ``cv2.PSNR``.

    Every 250 iterations, and once at the end, an H | L | E comparison
    strip is saved. ``psnr.txt`` is written to ``./result/finetune/``.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5  # number of unrolled iterations; also sizes the ab table
    patch_size = [32, 32]
    patch_num = [2, 2]
    out_dir = os.path.join('./result', 'finetune')
    # cv2.imwrite returns False (no exception) when the folder is missing.
    os.makedirs(out_dir, exist_ok=True)

    # 1. local PSF
    # shape: gx,gy,kw,kw,3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # n_iter must agree with `stage`, which determines the ab table shape.
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                sf=sf,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on GPU be loaded on a CPU box.
    loaded_state_dict = torch.load('./logs/uabcnet_final.pth',
                                   map_location=device)
    model.load_state_dict(loaded_state_dict, strict=True)
    model.eval()
    for v in model.parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional lambda, mu for HQS: the same pretrained (raw, pre-softplus)
    # table is copied to every patch.
    ab_pretrain = np.loadtxt('./logs/ab_pretrain.txt').reshape(
        (1, 1, 2 * stage, 3)).astype(np.float32)
    ab_buffer = np.tile(ab_pretrain, (patch_num[0], patch_num[1], 1, 1))
    ab = torch.tensor(ab_buffer, device=device, requires_grad=False)

    # ab is constant during evaluation, so softplus and flattening are done
    # once. Order is h_ outer / w_ inner, matching the kernel stacking below.
    ab_patch = F.softplus(ab)
    ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                            for h_ in range(patch_num[1])
                            for w_ in range(patch_num[0])],
                           dim=0)

    # 3. load test data
    imgs_H = sorted(glob.glob('./DIV2K_train/*.png', recursive=True))
    if not imgs_H:
        raise FileNotFoundError('no images found in ./DIV2K_train')

    N_maxiter = 1000
    PSF_grid = using_AC254_lens(all_PSFs, patch_num)

    all_PSNR = []
    show = None  # last comparison strip; stays None if the loop never runs

    for i in range(N_maxiter):
        # Draw a random image.
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])

        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)

        x = util.single2tensor4(util.uint2single(patch_L)).to(device)
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(patch_num[1])
                       for w_ in range(patch_num[0])],
                      dim=0).to(device)

        with torch.no_grad():  # pure inference: no autograd graph
            x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                             [patch_size[0], patch_size[1]],
                                             sf)

        patch_L = cv2.resize(patch_L,
                             dsize=None,
                             fx=sf,
                             fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)

        all_PSNR.append(cv2.PSNR(patch_E, patch_H))

        show = np.hstack((patch_H, patch_L, patch_E))
        if i % 250 == 0:
            cv2.imwrite(
                os.path.join(out_dir, 'result-{:04d}.png'.format(i + 1)),
                show)

    # Save the final strip. The original indexed the loop variables
    # directly, which raised NameError when N_maxiter == 0.
    if show is not None:
        cv2.imwrite(
            os.path.join(out_dir, 'result-{:04d}.png'.format(N_maxiter)),
            show)
    np.savetxt(os.path.join(out_dir, 'psnr.txt'), all_PSNR)