示例#1
0
def main():
    """Deblur one test image with the trained plano-lens UABC model.

    The script loads the PSF grid, the network weights and the positional
    HQS parameters, restores the (cropped, half-size) image in a single
    patch-wise forward pass and shows the padded input and the restored
    output in two OpenCV windows.  Press ``q`` to quit; any other key runs
    the next iteration.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/Heide_PSF_plano_small.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    # The first two axes index the kernel grid; every kernel is normalised
    # by the sum over its own first two axes.
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    model_code = 'iter800'
    # map_location keeps the CPU fallback above usable even when the
    # checkpoint was saved from a CUDA device.
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/plano/uabcnet_{}.pth'.format(
            model_code),
        map_location='cpu')
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # positional alpha-beta parameters for HQS; they do not depend on the
    # image, so they are read once instead of on every iteration.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/plano/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    # The whole kernel grid is processed in one call.
    num_patch = [gx, gy]

    for img_id in range(1, 237):
        img_path = '/home/xiu/databag/deblur/ICCV2021/MPI_data/drain/blurry.jpg'
        img_L = cv2.imread(img_path)
        if img_L is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(
                'cannot read image: {}'.format(img_path))
        img_L = img_L.astype(np.float32)
        img_L = img_L[38:-39, 74:-74]
        img_L = cv2.resize(img_L, dsize=None, fx=0.5, fy=0.5)

        # NOTE: shape[:2] is (rows, cols); the W/H names are kept from the
        # original script and are used consistently as (axis 0, axis 1).
        W, H = img_L.shape[:2]

        print(gx, gy)

        t0 = time.time()

        px_start = 0
        py_start = 0

        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H //
                        gy, :]

        p_W, p_H = patch_L.shape[:2]
        expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        block_expand = expand
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        # centralize: roll the padding so that `block_expand` pixels of it
        # end up on every side of the original patch.
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.uint2single(patch_L_wrap)
        x = util.single2tensor4(x)

        # Kernels and HQS parameters are stacked in the same order:
        # second grid index outermost, first grid index innermost.
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)

        [x, k] = [el.to(device) for el in [x, k]]

        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        # Pure inference: make sure no autograd graph is recorded.
        with torch.no_grad():
            x_E = model.forward_patchwise(x, k, cd, num_patch,
                                          [W // gx, H // gy])
        # Drop the boundary padding again.
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)

        t1 = time.time()
        print(t1 - t0)

        xk = patch_E
        xk = xk.astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', patch_L.astype(np.uint8))
        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
示例#2
0
def main():
    """Deblur a sequence of video frames tile by tile and display each one.

    Every frame is processed in windows of ``num_patch`` x ``num_patch``
    grid cells.  Each window is padded, restored by the network and the
    cropped result is accumulated into full-frame buffers that are finally
    divided by the per-pixel coverage count.

    Raises:
        FileNotFoundError: if a frame cannot be read.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    # The first two axes index the kernel grid; every kernel is normalised
    # by the sum over its own first two axes.
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    n_stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=n_stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")

    # map_location keeps the CPU fallback above usable even when the
    # checkpoint was saved from a CUDA device.
    loaded_state = torch.load('./usrnet_ZEMAX.pth', map_location='cpu')
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    patch_size = 2 * 128
    num_patch = 2
    p_size = patch_size // num_patch
    expand = PSF_grid.shape[2]

    # Positional HQS parameters: frame independent, so read them once.
    # The grid size comes from the PSF file instead of a hard-coded 6 x 8.
    ab_numpy = np.loadtxt('ab_ZEMAX.txt').astype(np.float32).reshape(
        gx, gy, 2 * n_stage, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    for img_id in range(100):
        img_path = '/home/xiu/workspace/UABC/ICCV2021/video/{:08d}.bmp'.format(
            img_id)
        img_L = cv2.imread(img_path)
        if img_L is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(
                'cannot read image: {}'.format(img_path))
        img_L = img_L.astype(np.float32)

        # Full-frame accumulators.  The original script used these lists
        # without ever creating them, which raised NameError on first use.
        img_E_deconv = [np.zeros_like(img_L) for _ in range(n_stage)]
        img_E_denoise = [np.zeros_like(img_L) for _ in range(n_stage)]
        weight_E = np.zeros_like(img_L)

        t0 = time.time()
        for px_start in range(0, gx - num_patch + 1, num_patch):
            for py_start in range(0, gy - num_patch + 1, num_patch):
                rows = slice(px_start * p_size,
                             (px_start + num_patch) * p_size)
                cols = slice(py_start * p_size,
                             (py_start + num_patch) * p_size)

                PSF_patch = PSF_grid[px_start:px_start + num_patch,
                                     py_start:py_start + num_patch]
                patch_L = img_L[rows, cols, :]

                block_expand = expand
                if block_expand > 0:
                    padded = patch_size + block_expand * 2
                    patch_L_wrap = util_deblur.wrap_boundary_liu(
                        patch_L, (padded, padded))
                    # centralize: roll the padding so that `block_expand`
                    # pixels of it end up on every side of the patch.
                    patch_L_wrap = np.hstack(
                        (patch_L_wrap[:, -block_expand:, :],
                         patch_L_wrap[:, :patch_size + block_expand, :]))
                    patch_L_wrap = np.vstack(
                        (patch_L_wrap[-block_expand:, :, :],
                         patch_L_wrap[:patch_size + block_expand, :, :]))
                else:
                    patch_L_wrap = patch_L

                x = util.single2tensor4(util.uint2single(patch_L_wrap))

                # Kernels: second grid index outermost, first innermost.
                k_all = []
                for h_ in range(num_patch):
                    for w_ in range(num_patch):
                        k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
                k = torch.cat(k_all, dim=0)

                [x, k] = [el.to(device) for el in [x, k]]

                # HQS parameters must follow the same ordering as the
                # kernels.  A plain row-major `.view` (as used before)
                # swaps the two off-diagonal cells relative to `k`.
                ab_patch = F.softplus(ab[px_start:px_start + num_patch,
                                         py_start:py_start + num_patch])
                cd = torch.cat([ab_patch[w_:w_ + 1, h_]
                                for h_ in range(num_patch)
                                for w_ in range(num_patch)], dim=0)

                # Pure inference: make sure no autograd graph is recorded.
                with torch.no_grad():
                    x_E = model.forward_patchwise(
                        x, k, cd, [num_patch, num_patch],
                        [patch_size // num_patch, patch_size // num_patch])

                # NOTE(review): iterating x_E and indexing [-2] / [-1]
                # assumes forward_patchwise returns a sequence of at least
                # two outputs here - confirm against the model version.
                patch_E_all = [util.tensor2uint(pp) for pp in x_E]

                # NOTE(review): every accumulator receives the same two
                # outputs ([-2] and [-1]); per-stage indexing may have
                # been intended.  Behaviour kept as in the original.
                for i in range(n_stage):
                    img_E_deconv[i][rows, cols, :] += \
                        patch_E_all[-2][expand:-expand, expand:-expand]
                    img_E_denoise[i][rows, cols, :] += \
                        patch_E_all[-1][expand:-expand, expand:-expand]
                weight_E[rows, cols, :] += 1.0

        t1 = time.time()
        print(t1 - t0)

        # Pixels never covered by a window keep weight 0; clamp the divisor
        # so they stay 0 instead of turning into NaN.
        coverage = np.maximum(weight_E, 1.0)
        img_E_deconv = [pp / coverage for pp in img_E_deconv]
        img_E_denoise = [pp / coverage for pp in img_E_denoise]

        xk = img_E_denoise[-1]
        xk = xk.astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', img_L.astype(np.uint8))
        cv2.waitKey(-1)
示例#3
0
def main():
    """Train the UABC network together with the positional HQS parameters.

    Each iteration picks a random training image and a random window of
    ``patch_num`` grid cells, blurs a sharp crop with the corresponding
    kernels, restores it with the network and minimises a weighted L1
    loss.  A preview window is shown every iteration; press ``q`` in it to
    stop.  Checkpoints are written every 100 iterations and once more on
    exit.

    Raises:
        FileNotFoundError: if no training image is found.
    """
    import os  # only needed to create the checkpoint directory

    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]

    # k_tensor stacks all kernels with the second grid index outermost,
    # i.e. grid cell (xx, yy) lives at position xx + yy * gx.
    k_tensor = []
    for yy in range(gy):
        for xx in range(gx):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))
            k_tensor.append(util.single2tensor4(PSF_grid[xx, yy]))

    k_tensor = torch.cat(k_tensor, dim=0)
    inv_weight = util_deblur.get_inv_spatial_weight(k_tensor)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.proj.load_state_dict(torch.load('./data/usrnet_pretrain.pth',
                                          map_location='cpu'),
                               strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # ----------------------------------------
    # load training data
    # ----------------------------------------
    imgs = glob.glob('./DIV2K_train/*.png', recursive=True)
    imgs.sort()
    if not imgs:
        raise FileNotFoundError('no training images found in ./DIV2K_train')

    # ----------------------------------------
    # positional lambda\mu for HQS
    # ----------------------------------------
    stage = 8
    ab_buffer = np.ones((gx, gy, 2 * stage, 3), dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # ----------------------------------------
    # build optimizer
    # ----------------------------------------
    params = [{"params": [ab], "lr": 0.0005}]
    for _, value in model.named_parameters():
        params += [{"params": [value], "lr": 0.0001}]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    patch_size = [128, 128]
    expand = PSF_grid.shape[2] // 2
    patch_num = [2, 2]

    # Extent of the sharp crop: the window plus the blur margin.
    crop_w = patch_size[0] * patch_num[0] + expand * 2
    crop_h = patch_size[1] * patch_num[1] + expand * 2

    # torch.save / np.savetxt do not create missing directories.
    os.makedirs('./ZEMAX_model', exist_ok=True)

    global_iter = 0

    while True:
        img_idx = np.random.randint(len(imgs))
        img = imgs[img_idx]
        img_H = cv2.imread(img)
        if img_H is None:
            # cv2.imread returns None on failure; skip the broken file.
            print('skip unreadable image: {}'.format(img))
            continue
        w, h = img_H.shape[:2]
        if w < crop_w or h < crop_h:
            # Too small for one crop; randint below would raise.
            print('skip too small image: {}'.format(img))
            continue

        # focus on the edges: four out of five modes pin the window to
        # one border of the kernel grid.
        mode = np.random.randint(5)
        px_start = np.random.randint(0, gx - patch_num[0] + 1)
        py_start = np.random.randint(0, gy - patch_num[1] + 1)
        if mode == 0:
            px_start = 0
        elif mode == 1:
            px_start = gx - patch_num[0]
        elif mode == 2:
            py_start = 0
        elif mode == 3:
            py_start = gy - patch_num[1]

        x_start = np.random.randint(0, w - crop_w + 1)
        y_start = np.random.randint(0, h - crop_h + 1)
        PSF_patch = PSF_grid[px_start:px_start + patch_num[0],
                             py_start:py_start + patch_num[1]]

        patch_H = img_H[x_start:x_start + crop_w, y_start:y_start + crop_h]
        patch_L = util_deblur.blockConv2d(patch_H, PSF_patch, expand)

        block_expand = max(patch_size[0] // 8, expand)

        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (patch_size[0] * patch_num[0] + block_expand * 2,
                      patch_size[1] * patch_num[1] + block_expand * 2))
        # centralize: roll the padding so that `block_expand` pixels of it
        # end up on every side of the blurred patch.
        patch_L_wrap = np.hstack(
            (patch_L_wrap[:, -block_expand:, :],
             patch_L_wrap[:, :patch_size[1] * patch_num[1] + block_expand, :]))
        patch_L_wrap = np.vstack(
            (patch_L_wrap[-block_expand:, :, :],
             patch_L_wrap[:patch_size[0] * patch_num[0] + block_expand, :, :]))
        x = util.uint2single(patch_L_wrap)
        x = util.single2tensor4(x)

        x_gt = util.uint2single(patch_H[expand:-expand, expand:-expand])
        x_gt = util.single2tensor4(x_gt)
        inv_weight_patch = torch.ones_like(x_gt)

        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                # Global position of this cell in k_tensor / inv_weight.
                # The window offset must be included; otherwise the
                # network is always fed the first kernels of the grid
                # while the blur above used PSF_patch.
                k_idx = (px_start + w_) + (py_start + h_) * gx
                rows = slice(w_ * patch_size[0], (w_ + 1) * patch_size[0])
                cols = slice(h_ * patch_size[1], (h_ + 1) * patch_size[1])
                for c in range(3):
                    inv_weight_patch[0, c, rows, cols] = inv_weight[k_idx, c]
                k_local.append(k_tensor[k_idx:k_idx + 1])

        k = torch.cat(k_local, dim=0)
        [x, x_gt, k, inv_weight_patch
         ] = [el.to(device) for el in [x, x_gt, k, inv_weight_patch]]
        ab_patch = F.softplus(ab[px_start:px_start + patch_num[0],
                                 py_start:py_start + patch_num[1]])
        # Same cell ordering as k_local.
        cd = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)
        x_E = model.forward_patchwise(x, k, cd, patch_num, patch_size)

        predict = x_E[..., block_expand:block_expand +
                      patch_size[0] * patch_num[0],
                      block_expand:block_expand +
                      patch_size[1] * patch_num[1]]
        loss = F.l1_loss(predict.div(inv_weight_patch),
                         x_gt.div(inv_weight_patch))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print('iter:{},loss {}'.format(global_iter + 1, loss.item()))

        # Preview: ground truth | blurred input | current estimate.
        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)[block_expand:-block_expand,
                                        block_expand:-block_expand]

        show = np.hstack((patch_H[expand:-expand, expand:-expand],
                          patch_L[block_expand:-block_expand,
                                  block_expand:-block_expand], patch_E))

        cv2.imshow('HL', show)
        key = cv2.waitKey(1)

        global_iter += 1

        # change the save period
        if global_iter % 100 == 0:
            ab_numpy = ab.detach().cpu().numpy().flatten()
            torch.save(
                model.state_dict(),
                './ZEMAX_model/usrnet_ZEMAX_iter{}.pth'.format(global_iter))
            np.savetxt('./ZEMAX_model/ab_ZEMAX_iter{}.txt'.format(global_iter),
                       ab_numpy)
        if key == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './ZEMAX_model/usrnet_ZEMAX.pth')
    np.savetxt('./ZEMAX_model/ab_ZEMAX.txt', ab_numpy)
示例#4
0
文件: predictor.py 项目: leehsiu/UABC
def main():
    """Deblur every captured ZEMAX test image with the trained UABC model.

    For each image found by the glob pattern the whole frame is restored
    in one patch-wise forward pass and shown next to the padded input in
    two OpenCV windows.  Press ``q`` to stop, any other key for the next
    image.

    Raises:
        FileNotFoundError: if an image returned by the glob cannot be read.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)

    # The first two axes index the kernel grid; every kernel is normalised
    # by the sum over its own first two axes.
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=2,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")

    model_code = 'iter17000'
    # map_location keeps the CPU fallback above usable even when the
    # checkpoint was saved from a CUDA device.
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/ZEMAX/uabcnet_{}.pth'.format(
            model_code),
        map_location='cpu')
    model.load_state_dict(loaded_state, strict=True)

    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # positional alpha-beta parameters for HQS; image independent, so they
    # are read once instead of once per frame.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/ZEMAX/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    # Process the full kernel grid at once.  Taken from the PSF file
    # (previously hard-coded to [6, 8]) so it always agrees with the `ab`
    # reshape above and the W // gx, H // gy tile size below.
    num_patch = [gx, gy]

    img_names = glob.glob(
        '/home/xiu/databag/deblur/ICCV2021/suo_image/*/AC254-075-A-ML-Zemax(ZMX).bmp'
    )
    img_names.sort()
    for img_id, img_name in enumerate(img_names):
        img_L = cv2.imread(img_name)
        if img_L is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(
                'cannot read image: {}'.format(img_name))
        img_L = img_L.astype(np.float32)
        # NOTE: shape[:2] is (rows, cols); the W/H names are kept from the
        # original script and are used consistently as (axis 0, axis 1).
        W, H = img_L.shape[:2]

        t0 = time.time()

        px_start = 0
        py_start = 0

        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H //
                        gy, :]

        p_W, p_H = patch_L.shape[:2]
        expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        block_expand = expand
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        # centralize: roll the padding so that `block_expand` pixels of it
        # end up on every side of the original patch.
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.uint2single(patch_L_wrap)
        x = util.single2tensor4(x)

        # Kernels and HQS parameters are stacked in the same order:
        # second grid index outermost, first grid index innermost.
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)

        [x, k] = [el.to(device) for el in [x, k]]

        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        # Pure inference: make sure no autograd graph is recorded.
        with torch.no_grad():
            x_E = model.forward_patchwise(x, k, cd, num_patch,
                                          [W // gx, H // gy])
        # Drop the boundary padding again.
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)

        t1 = time.time()

        print('[{}/{}]: {} s per frame'.format(img_id, len(img_names),
                                               t1 - t0))

        xk = patch_E
        xk = xk.astype(np.uint8)

        cv2.imshow('res', xk)
        cv2.imshow('input', patch_L.astype(np.uint8))

        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
示例#5
0
def main():
    """Evaluate USRNet on bicubically downsampled test images.

    Every image of the test set is (optionally) degraded by bicubic
    downsampling, padded, super-resolved by the network and compared with
    its high-quality counterpart.  Per-image and average PSNR/SSIM values
    are written to the logger; low-resolution inputs and estimates are
    saved to the results folder when the corresponding flags are set.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrnet'  # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set5'  # test set,  'set5' | 'srbsd68'
    need_degradation = True  # default: True
    sf = 4  # scale factor, only from {2, 3, 4}
    show_img = False  # default: False
    save_L = True  # save LR image
    save_E = True  # save estimated image

    # load approximated bicubic kernels
    kernels = loadmat(os.path.join('kernels',
                                   'kernels_bicubicx234.mat'))['kernels']
    kernel = kernels[0, sf - 2].astype(np.float64)
    kernel = util.single2tensor4(kernel[..., np.newaxis])

    task_current = 'sr'  # fixed, 'sr' for super-resolution
    n_channels = 3  # fixed, 3 for color image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'  # fixed
    results = 'results'  # fixed
    noise_level_img = 0  # fixed: 0, noise level for LR image
    noise_level_model = noise_level_img  # fixed, noise level of model, default 0
    result_name = testset_name + '_' + model_name + '_bicubic'
    border = sf if task_current == 'sr' else 0  # shave border to calculate PSNR and SSIM
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path, H_path
    # ----------------------------------------
    L_path = os.path.join(
        testsets, testset_name)  # L_path, fixed, for Low-quality images
    H_path = L_path  # H_path, 'None' | L_path, for High-quality images
    E_path = os.path.join(results,
                          result_name)  # E_path, fixed, for Estimated images
    util.mkdir(E_path)

    # When the "low-quality" folder is the high-quality one, the inputs
    # must be degraded first.
    if H_path == L_path:
        need_degradation = True
    logger_name = result_name
    utils_logger.logger_info(logger_name,
                             log_path=os.path.join(E_path,
                                                   logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    need_H = H_path is not None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from models.network_usrnet import USRNet as net  # for pytorch version <= 1.7.1
    # from models.network_usrnet_v1 import USRNet as net  # for pytorch version >=1.8.1

    if 'tiny' in model_name:
        model = net(n_iter=6,
                    h_nc=32,
                    in_nc=4,
                    out_nc=3,
                    nc=[16, 32, 64, 64],
                    nb=2,
                    act_mode="R",
                    downsample_mode='strideconv',
                    upsample_mode="convtranspose")
    else:
        model = net(n_iter=8,
                    h_nc=64,
                    in_nc=4,
                    out_nc=3,
                    nc=[64, 128, 256, 512],
                    nb=2,
                    act_mode="R",
                    downsample_mode='strideconv',
                    upsample_mode="convtranspose")

    # map_location keeps the CPU fallback above usable even when the
    # checkpoint was saved from a CUDA device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False

    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    logger.info('model_name:{}, image sigma:{}'.format(model_name,
                                                       noise_level_img))
    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)
    H_paths = util.get_image_paths(H_path) if need_H else None

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)

        # degradation process, bicubic downsampling
        if need_degradation:
            img_L = util.modcrop(img_L, sf)
            img_L = util.imresize_np(img_L, 1 / sf)

        # shape[:2] is (rows, cols); names kept from the original script.
        w, h = img_L.shape[:2]

        if save_L:
            util.imsave(
                util.single2uint(img_L),
                os.path.join(E_path, img_name + '_LR_x' + str(sf) + '.png'))

        # Pad the LR image: upsample, boundary-wrap to a padded size,
        # downsample again, then paste the true LR pixels into the
        # top-left corner.
        img = cv2.resize(img_L, (sf * h, sf * w),
                         interpolation=cv2.INTER_NEAREST)
        img = utils_deblur.wrap_boundary_liu(img, [
            int(np.ceil(sf * w / 8 + 2) * 8),
            int(np.ceil(sf * h / 8 + 2) * 8)
        ])
        img_wrap = sr.downsample_np(img, sf, center=False)
        img_wrap[:w, :h, :] = img_L
        img_L = img_wrap

        if show_img:
            util.imshow(util.single2uint(img_L),
                        title='LR image with noise level {}'.format(
                            noise_level_img))

        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        sigma = torch.tensor(noise_level_model).float().view([1, 1, 1, 1])
        [img_L, kernel,
         sigma] = [el.to(device) for el in [img_L, kernel, sigma]]

        # Pure inference: make sure no autograd graph is recorded.
        with torch.no_grad():
            img_E = model(img_L, kernel, sf, sigma)

        img_E = util.tensor2uint(img_E)
        # Crop the padding added above.
        img_E = img_E[:sf * w, :sf * h, :]

        if need_H:

            # --------------------------------
            # (3) img_H
            # --------------------------------
            img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
            img_H = img_H.squeeze()
            img_H = util.modcrop(img_H, sf)

            # --------------------------------
            # PSNR and SSIM
            # --------------------------------
            psnr = util.calculate_psnr(img_E, img_H, border=border)
            ssim = util.calculate_ssim(img_E, img_H, border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(
                img_name + ext, psnr, ssim))
            if show_img:
                util.imshow(np.concatenate([img_E, img_H], axis=1),
                            title='Recovered / Ground-truth')

            if np.ndim(img_H) == 3:  # RGB image
                img_E_y = util.rgb2ycbcr(img_E, only_y=True)
                img_H_y = util.rgb2ycbcr(img_H, only_y=True)
                psnr_y = util.calculate_psnr(img_E_y, img_H_y, border=border)
                ssim_y = util.calculate_ssim(img_E_y, img_H_y, border=border)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)

        # ------------------------------------
        # save results
        # ------------------------------------
        if save_E:
            util.imsave(
                img_E,
                os.path.join(
                    E_path,
                    img_name + '_x' + str(sf) + '_' + model_name + '.png'))

    # Averages are reported only when something was measured; this avoids
    # a ZeroDivisionError on an empty test set and no longer depends on
    # the `img_H` variable left over from the last loop iteration.
    if need_H and test_results['psnr']:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        logger.info(
            'Average PSNR/SSIM(RGB) - {} - x{} --PSNR: {:.2f} dB; SSIM: {:.4f}'
            .format(result_name, sf, ave_psnr, ave_ssim))
        if test_results['psnr_y']:
            ave_psnr_y = sum(test_results['psnr_y']) / len(
                test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(
                test_results['ssim_y'])
            logger.info(
                'Average PSNR/SSIM( Y ) - {} - x{} - PSNR: {:.2f} dB; SSIM: {:.4f}'
                .format(result_name, sf, ave_psnr_y, ave_ssim_y))
def main():
    """Super-resolve one real low-resolution image with a pretrained model.

    The script builds a shifted, normalized Gaussian blur kernel for the
    chosen scale factor, loads the network weights from ``model_zoo``,
    pads the low-resolution input so that boundary artifacts are reduced,
    runs the model once, and writes the estimated image (and optionally a
    side-by-side "LR / Recovered" image) into ``results/<testset>_<model>``.

    All settings are hard-coded in the "Preparation" section below; the
    function takes no arguments and returns nothing.
    """

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrnet'      # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set_real'  # test set,  'set_real'
    test_image = 'chip.png'    # 'chip.png', 'comic.png'
    # test_image = 'comic.png'

    sf = 4                     # scale factor, only from {1, 2, 3, 4}
    show_img = False           # default: False
    save_E = True              # save estimated image
    save_LE = True             # save zoomed LR, Estimated images

    # ----------------------------------------
    # set noise level and kernel
    # ----------------------------------------
    if 'chip' in test_image:
        noise_level_img = 15   # noise level for LR image, 15 for chip
        # Gaussian kernel widths for x1, x2, x3, x4
        kernel_width_default_x1234 = [0.6, 0.9, 1.7, 2.2]
    else:
        noise_level_img = 2    # noise level for LR image, 0.5~3 for clean images
        # default Gaussian kernel widths of clean/sharp images for x1, x2, x3, x4
        kernel_width_default_x1234 = [0.4, 0.7, 1.5, 2.0]

    noise_level_model = noise_level_img / 255.  # noise level of model
    kernel_width = kernel_width_default_x1234[sf - 1]

    # set your own kernel width
    # kernel_width = 2.2

    k = utils_deblur.fspecial('gaussian', 25, kernel_width)
    k = sr.shift_pixel(k, sf)  # shift the kernel
    k /= np.sum(k)             # renormalize so the kernel sums to 1
    if show_img:
        util.surf(k)
    # scio.savemat('kernel_realapplication.mat', {'kernel':k})

    # Alternative: load approximated bicubic kernels instead of the Gaussian.
    # kernels = hdf5storage.loadmat(os.path.join('kernels', 'kernel_bicubicx234.mat'))['kernels']
    # kernels = loadmat(os.path.join('kernels', 'kernel_bicubicx234.mat'))['kernels']
    # kernel = kernels[0, sf-2].astype(np.float64)

    kernel = util.single2tensor4(k[..., np.newaxis])

    n_channels = 1 if 'gray' in model_name else 3  # 3 for color image, 1 for grayscale image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'     # fixed
    results = 'results'       # fixed
    result_name = testset_name + '_' + model_name
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # L_path, fixed, for Low-quality images
    E_path = os.path.join(results, result_name)    # E_path, fixed, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    if 'tiny' in model_name:
        model = net(n_iter=6, h_nc=32, in_nc=4, out_nc=3, nc=[16, 32, 64, 64],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")
    else:
        model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")

    # Deserialize onto the CPU first: without map_location a checkpoint saved
    # from a GPU cannot be loaded on a CPU-only machine, which would defeat
    # the CPU fallback selected for `device` above. The model is moved to
    # `device` further down.
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False  # inference only

    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img))
    logger.info(L_path)

    # ------------------------------------
    # (1) img_L
    # ------------------------------------
    img_file = os.path.join(L_path, test_image)
    img_name, ext = os.path.splitext(os.path.basename(img_file))
    img_L = util.imread_uint(img_file, n_channels=n_channels)
    img_L = util.uint2single(img_L)

    if show_img:
        util.imshow(img_L)
    # shape[:2] is (rows, cols), i.e. (height, width) of the LR image.
    rows, cols = img_L.shape[:2]
    logger.info('{:>10s}--> ({:>4d}x{:<4d})'.format(img_name + ext, rows, cols))

    # boundary handling
    border = 8  # default setting for kernel size 25x25
    # cv2.resize takes dsize as (width, height), hence (cols, rows) order.
    img_up = cv2.resize(img_L, (sf * cols, sf * rows), interpolation=cv2.INTER_NEAREST)
    # Pad the upscaled image to a multiple of `border` (plus two extra
    # blocks) in each dimension.
    padded_size = [int(np.ceil(sf * rows / border + 2) * border),
                   int(np.ceil(sf * cols / border + 2) * border)]
    img_up = utils_deblur.wrap_boundary_liu(img_up, padded_size)
    img_wrap = sr.downsample_np(img_up, sf, center=False)
    # Restore the original LR pixels in the top-left corner; only the
    # padded margin comes from the wrapped/downsampled image.
    img_wrap[:rows, :cols, :] = img_L
    img_L = img_wrap

    if show_img:
        util.imshow(util.single2uint(img_L),
                    title='LR image with noise level {}'.format(noise_level_img))

    img_L = util.single2tensor4(img_L)

    # ------------------------------------
    # (2) img_E
    # ------------------------------------
    sigma = torch.tensor(noise_level_model).float().view([1, 1, 1, 1])
    img_L, kernel, sigma = [el.to(device) for el in (img_L, kernel, sigma)]

    img_E = model(img_L, kernel, sf, sigma)

    # Crop away the padded margin introduced by the boundary handling.
    img_E = util.tensor2uint(img_E)[:sf * rows, :sf * cols, ...]

    if save_E:
        util.imsave(img_E, os.path.join(E_path, img_name + '_x' + str(sf) + '_' + model_name + '.png'))

    # --------------------------------
    # (3) save img_LE
    # --------------------------------
    if save_LE:
        # Render the kernel as a small image and paste it into the
        # top-left corner of the nearest-neighbour zoomed LR image.
        k_v = k / np.max(k) * 1.2
        k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, 3]))
        k_factor = 3
        k_v = cv2.resize(k_v, (k_factor * k_v.shape[1], k_factor * k_v.shape[0]),
                         interpolation=cv2.INTER_NEAREST)
        img_L = util.tensor2uint(img_L)[:rows, :cols, ...]
        img_I = cv2.resize(img_L, (sf * img_L.shape[1], sf * img_L.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_I[:k_v.shape[0], :k_v.shape[1], :] = k_v
        img_LE = np.concatenate([img_I, img_E], axis=1)
        if show_img:
            util.imshow(img_LE, title='LR / Recovered')
        util.imsave(img_LE, os.path.join(E_path, img_name + '_x' + str(sf) + '_' + model_name + '_LE.png'))
示例#7
0
def main():
    """Super-resolve every image in a folder with a pretrained SRResNet.

    Command-line arguments are read from the module-level ``parser`` and
    stored in the global ``arg`` (``arg.sf``: scale factor, ``arg.load``:
    input folder, ``arg.save``: output folder). For each file in the input
    folder the script alternates an FFT-based data step with the
    super-resolver network for a fixed number of iterations and saves the
    result under the same file name in the output folder.

    The function takes no arguments and returns nothing.
    """

    # --------------------------------
    # let's start!
    # --------------------------------
    utils_logger.logger_info('test_dpsr_real', log_path='test_dpsr_real.log')
    logger = logging.getLogger('test_dpsr_real')
    global arg
    arg = parser.parse_args()

    # basic setting
    # ================================================
    sf = arg.sf
    show_img = False
    noise_level_img = 8. / 255.

    use_srganplus = False
    model_prefix = 'DPSRGAN' if use_srganplus and sf == 4 else 'DPSR'
    model_path = os.path.join('DPSR_models', model_prefix + 'x%01d.pth' % (sf))

    iter_num = 5    # number of iterations used with the fixed Gaussian kernel below
    n_channels = 3  # only color images, fixed

    # ================================================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------------------------------
    # (1) load trained model
    # --------------------------------
    model = SRResNet(in_nc=4,
                     out_nc=3,
                     nc=96,
                     nb=16,
                     upscale=sf,
                     act_mode='R',
                     upsample_mode='pixelshuffle')
    # Deserialize onto the CPU first: without map_location a checkpoint saved
    # from a GPU cannot be loaded on a CPU-only machine, which would defeat
    # the CPU fallback selected for `device` above.
    model.load_state_dict(torch.load(model_path, map_location='cpu'),
                          strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False  # inference only
    model = model.to(device)
    logger.info('Model path {:s}. Testing...'.format(model_path))

    # --------------------------------
    # (2) L_folder, E_folder
    # --------------------------------
    # --1--> L_folder, folder of Low-quality images
    L_folder = os.path.join(arg.load)  # L: Low quality

    # --2--> E_folder, folder of Estimated images
    E_folder = os.path.join(arg.save)
    util.mkdir(E_folder)

    logger.info(L_folder)

    # --------------------------------
    # get rhos and sigmas
    # --------------------------------
    # These depend only on the noise level and the iteration count, so they
    # are computed once instead of once per image.
    rhos, sigmas = utils_deblur.get_rho_sigma(
        sigma=max(0.255 / 255.0, noise_level_img), iter_num=iter_num)
    rhos = np.float32(rhos)
    sigmas = np.float32(sigmas)

    # NOTE(review): every entry of L_folder is read as an image; a non-image
    # file or sub-directory in that folder will presumably make imread_uint
    # fail -- confirm whether an extension filter is wanted here.
    for im in os.listdir(L_folder):
        # --------------------------------
        # (3) load low-resolution image
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(im))
        img = util.imread_uint(os.path.join(L_folder, im),
                               n_channels=n_channels)
        h, w = img.shape[:2]
        if show_img:
            util.imshow(img, title='Low-resolution image')
        # NOTE(review): spelled "unit2single" while the inverse call below is
        # "single2uint" -- confirm the helper really exists under this name.
        img = util.unit2single(img)

        # --------------------------------
        # (4) blur kernel
        # --------------------------------
        # A fixed 5x5 Gaussian is used for every image (no per-image kernel
        # file is loaded).
        k = utils_deblur.fspecial('gaussian', 5, 0.25)

        # --------------------------------
        # (5) handle boundary
        # --------------------------------
        img = utils_deblur.wrap_boundary_liu(
            img,
            utils_deblur.opt_fft_size(
                [img.shape[0] + k.shape[0] + 1,
                 img.shape[1] + k.shape[1] + 1]))

        # --------------------------------
        # (6) get upperleft, denominator
        # --------------------------------
        upperleft, denominator = utils_deblur.get_uperleft_denominator(img, k)

        # --------------------------------
        # (7) main iteration
        # --------------------------------
        z = img
        for i in range(iter_num):

            logger.info('Iter: {:->4d}--> {}'.format(i + 1, im))
            # --------------------------------
            # step 1, Eq. (9) // FFT
            # --------------------------------
            rho = rhos[i]
            if i != 0:
                # From the second iteration on, z is the network output of
                # the previous pass; resize it by 1/sf before the FFT step.
                z = util.imresize_np(z, 1 / sf, True)

            z = np.real(
                np.fft.ifft2((upperleft + rho * np.fft.fft2(z, axes=(0, 1))) /
                             (denominator + rho),
                             axes=(0, 1)))

            # --------------------------------
            # step 2, Eq. (12) // super-resolver
            # --------------------------------
            sigma = torch.from_numpy(np.array(sigmas[i]))
            img_L = util.single2tensor4(z)

            # Append a constant noise-level map as an extra input channel.
            noise_level_map = torch.ones((1, 1, img_L.size(2), img_L.size(3)),
                                         dtype=torch.float).mul_(sigma)
            img_L = torch.cat((img_L, noise_level_map), dim=1)
            img_L = img_L.to(device)
            z = model(img_L)
            z = util.tensor2single(z)

        # --------------------------------
        # (8) img_E
        # --------------------------------
        # Crop away the margin added by the boundary handling.
        img_E = util.single2uint(z[:h * sf, :w * sf])

        # Log the name the file is actually saved under (previously the
        # message advertised an "_x{sf}" suffix that was never written).
        out_name = img_name + ext
        logger.info('saving: sf = {}, {}.'.format(sf, out_name))
        util.imsave(img_E, os.path.join(E_folder, out_name))

        if show_img:
            util.imshow(img_E, title='Recovered image')