Пример #1
0
def main():
    """Train the deblurring network on randomly cropped patches blurred by one PSF.

    Every training image contributes ``n_patch * n_patch`` random crops per
    iteration.  Each crop is blurred with ``PSF_grid[1, 2]``, restored with
    ``model.forward_globaldeconv`` and compared (L1) against the sharp centre
    of the crop.  The network weights and the ``ab_param`` tensor (passed
    through softplus before use) are optimised jointly with Adam.

    Side effects: writes TensorBoard events, one model checkpoint per epoch
    and ``patchwise_ours.txt`` into ``root``; prints the loss every 10
    iterations.

    Raises:
        FileNotFoundError: if a training image cannot be decoded by OpenCV.
    """
    root = '/home/xiu/databag/deblur/pretrain/patchwise'
    logger = SummaryWriter(root)
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Use the archive as a context manager so its file handle is released.
    with np.load('./data/ZEMAX-AC254-075-A-new.npz') as psf_file:
        PSF_grid = psf_file['PSF'].astype(np.float32)
    # Normalise every kernel of the grid so that it sums to one over its two
    # leading (spatial) axes.
    kw, kh = PSF_grid.shape[:2]
    for w_ in range(kw):
        for h_ in range(kh):
            PSF_grid[w_, h_] = PSF_grid[w_, h_] / np.sum(PSF_grid[w_, h_], axis=(0, 1))
    # ----------------------------------------
    # build
    # ----------------------------------------
    model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                nb=3, act_mode="R", downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    model = model.to(device)

    train_imgs = glob.glob('/home/xiu/databag/deblur/images/*/**.png', recursive=True)
    train_imgs.sort()

    n_epoch = 200
    n_stage = 5
    n_batch = 3      # images per iteration
    w_patch = 128    # side length of the supervised (ground-truth) patch
    n_patch = 2      # n_patch * n_patch random crops are drawn per image
    log_every = 10   # iterations between console / TensorBoard reports

    # Raw (pre-softplus) parameters shared by all patches.
    ab_buffer = np.ones((1, 1, 2 * n_stage + 1, 3), dtype=np.float32) * 0.1
    ab_param = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # Two parameter groups with the same learning rate; equivalent to one
    # group per tensor but much cheaper for the optimizer to iterate.
    optimizer = torch.optim.Adam(
        [
            {"params": [ab_param], "lr": 1e-4},
            {"params": list(model.parameters()), "lr": 1e-4},
        ],
        lr=1e-4,
    )

    img_index = np.arange(len(train_imgs))
    n_iters_per_epoch = len(train_imgs) // n_batch

    global_iter = 0
    # Always even, so half of it can be trimmed from each side symmetrically.
    expand = PSF_grid.shape[2] // 2 * 2
    half_expand = expand // 2
    crop = w_patch + expand * 2

    # Only this single PSF of the grid is used for training.
    PSF = PSF_grid[1, 2]

    try:
        for epoch in range(n_epoch):
            np.random.shuffle(img_index)
            for iteration in range(n_iters_per_epoch):
                imgs = []
                for idx in img_index[iteration * n_batch:(iteration + 1) * n_batch]:
                    img = cv2.imread(train_imgs[idx])
                    if img is None:
                        # cv2.imread signals failure by returning None.
                        raise FileNotFoundError(
                            'cannot read training image: {}'.format(train_imgs[idx]))
                    imgs.append(img)
                global_iter += 1
                log_now = global_iter % log_every == 0

                vis_H = []
                vis_L = []
                vis_E = []

                x = []
                y = []
                for img in imgs:
                    w, h = img.shape[:2]
                    for _ in range(n_patch * n_patch):
                        x_start = np.random.randint(0, w - crop + 1)
                        y_start = np.random.randint(0, h - crop + 1)
                        patch_H = img[x_start:x_start + crop, y_start:y_start + crop]
                        patch_L = util_deblur.uniformConv2d(patch_H, PSF)
                        sharp_center = patch_H[expand:-expand, expand:-expand]

                        vis_H.append(sharp_center)
                        vis_L.append(patch_L[half_expand:-half_expand,
                                             half_expand:-half_expand])
                        y.append(util.single2tensor4(util.uint2single(patch_L)))
                        x.append(util.single2tensor4(util.uint2single(sharp_center)))

                ab = F.softplus(ab_param)
                loss = 0
                k = util.single2tensor4(PSF)
                k = k.to(device)
                for i, (blurred, sharp) in enumerate(zip(y, x)):
                    yy = blurred.to(device)
                    xx = sharp.to(device)
                    xE = model.forward_globaldeconv(yy, k, ab[0], w_patch)
                    loss += F.l1_loss(xE[-2], xx)
                    # Previews are only consumed on logging iterations and
                    # only for the first n_batch patches, so skip the
                    # GPU->CPU copy everywhere else.
                    if log_now and i < n_batch:
                        patch_E = xE[-2].clone().detach().cpu()
                        vis_E.append(util.tensor2uint(patch_E))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if log_now:
                    print('iter {}: loss{}.'.format(global_iter, loss.item()))
                    logger.add_scalar('train-loss', loss.item(), global_iter)
                    for i in range(n_batch):
                        show1 = np.hstack((vis_H[i], vis_L[i], vis_E[i]))
                        # Tag each preview with the iteration; without
                        # global_step every image is written at the same step.
                        logger.add_image('show-{}'.format(i),
                                         util.uint2tensor3(show1[:, :, ::-1]),
                                         global_step=global_iter)
                    logger.flush()
            # Recompute from the current parameters so the saved values match
            # the checkpointed weights (and exist even if the epoch had no
            # full batch).
            # NOTE(review): the [:, :, 0, 0] slice keeps a single entry of the
            # (1, 1, 2*n_stage+1, 3) tensor - confirm that this is intended.
            ab_numpy = F.softplus(ab_param).detach().cpu().numpy()[:, :, 0, 0]
            ab_numpy = ab_numpy.flatten()
            torch.save(model.state_dict(),
                       os.path.join(root, 'usrnet_patchwise1x2_epoch{}.pth'.format(epoch)))
            np.savetxt(os.path.join(root, 'patchwise_ours.txt'), ab_numpy)
    finally:
        # Flush and release the event file even if training is interrupted.
        logger.close()
Пример #2
0
def main():
    """Train the network with ``forward_patchtranslate`` on blurred crops.

    Each iteration reads ``n_batch`` images, takes one random crop per image,
    blurs it with the kernel paired with that batch slot and minimises the L1
    distance between the network output and the sharp centre of the crop.
    ``ab_param`` is created with ``requires_grad=False`` and is not handed to
    the optimizer, so only the network weights are trained.

    Side effects: prints the loss and shows a ground-truth / prediction
    preview window every iteration; writes ``usrnet_bench.pth`` and
    ``ab_bench.txt`` to the working directory after every epoch.

    Raises:
        FileNotFoundError: if a training image cannot be decoded by OpenCV.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_PSFs = load_kernels('./data')
    # ----------------------------------------
    # build
    # ----------------------------------------
    model = net(n_iter=8,
                h_nc=64,
                in_nc=3,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=3,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    model = model.to(device)

    train_imgs = glob.glob('/home/xiu/databag/deblur/images/*/**.png',
                           recursive=True)
    train_imgs.sort()

    n_stage = 5
    n_batch = 9
    n_epoch = 200

    w_patch = 128
    n_patch = 1
    kernel_refresh_every = 100  # iterations between kernel re-selections

    ab_buffer = np.ones((n_batch, n_patch * n_patch, 2 * n_stage + 1, 3),
                        dtype=np.float32) * 0.1
    ab_param = torch.tensor(ab_buffer, device=device, requires_grad=False)

    # A single parameter group is equivalent to one group per tensor here,
    # because every group used the same learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    img_index = np.arange(len(train_imgs))

    global_iter = 0

    def select_kernels():
        """Return one kernel per batch slot and the matching crop margins."""
        kernels = [all_PSFs[0][0:1, 0:1] for _ in range(n_batch)]
        margins = [kernel.shape[2] // 2 for kernel in kernels]
        return kernels, margins

    PSFs, expands = select_kernels()

    for epoch in range(n_epoch):
        np.random.shuffle(img_index)
        for iteration in range(len(train_imgs) // n_batch):
            imgs = []
            for idx in img_index[iteration * n_batch:(iteration + 1) * n_batch]:
                img = cv2.imread(train_imgs[idx])
                if img is None:
                    # cv2.imread signals failure by returning None.
                    raise FileNotFoundError(
                        'cannot read training image: {}'.format(train_imgs[idx]))
                imgs.append(img)

            global_iter += 1
            if global_iter % kernel_refresh_every == 0:
                # Periodic kernel refresh; currently it re-selects the same
                # fixed kernel for every slot.
                PSFs, expands = select_kernels()

            x = []
            y = []
            for img, expand, PSF in zip(imgs, expands, PSFs):
                w, h = img.shape[:2]
                crop = w_patch * n_patch + expand * 2
                x_start = np.random.randint(0, w - crop + 1)
                y_start = np.random.randint(0, h - crop + 1)
                patch_H = img[x_start:x_start + crop, y_start:y_start + crop]
                patch_L = util_deblur.uniformConv2d(patch_H, PSF)

                y.append(util.single2tensor4(util.uint2single(patch_L)))
                x.append(util.single2tensor4(
                    util.uint2single(patch_H[expand:-expand, expand:-expand])))

            ab = F.softplus(ab_param)
            loss = 0
            last_prediction = None
            for i in range(n_batch):
                yy = y[i].to(device)
                xx = x[i].to(device)
                xE = model.forward_patchtranslate(yy, ab[i])
                loss += F.l1_loss(xE, xx)
                last_prediction = xE
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('loss {}'.format(loss.item()))

            gt = util.tensor2uint(x[-1])
            # Only the last prediction is displayed.  Convert a detached CPU
            # copy, and do it after the backward pass, so the conversion
            # helper (defined elsewhere) can never touch a tensor that
            # autograd still needs.
            pred = util.tensor2uint(last_prediction.clone().detach().cpu())
            cv2.imshow('res', np.hstack((gt, pred)))
            cv2.waitKey(1)

        # Recomputed here so the value exists even if the epoch had no batch.
        ab_numpy = F.softplus(ab_param).detach().cpu().numpy()[:, :, 0, 0]
        torch.save(model.state_dict(), 'usrnet_bench.pth')
        np.savetxt('ab_bench.txt', ab_numpy)
Пример #3
0
def main():
    """Run tiled deblurring inference on two test images and save the result.

    For each image the script pads it, blurs it with ``PSF_grid[1, 2]``,
    restores it tile by tile with ``model.forward_globaldeconv`` and
    accumulates the per-stage outputs into full-size buffers that are
    averaged by tile coverage.  The third-from-last deconvolution buffer is
    written to disk as ``fakepatch1x2-<id>.png``.

    Raises:
        FileNotFoundError: if an input image cannot be decoded by OpenCV.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    # Use the archive as a context manager so its file handle is released.
    with np.load('./data/ZEMAX-AC254-075-A-new.npz') as psf_file:
        PSF_grid = psf_file['PSF'].astype(np.float32)

    # Normalise every kernel of the grid so that it sums to one over its two
    # leading (spatial) axes.
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=3,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host,
    # which the device selection above explicitly allows for.
    model.load_state_dict(
        torch.load('/home/xiu/databag/deblur/patchwise1x2.pth',
                   map_location=device),
        strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    n_stage = 8                    # matches n_iter of the network above
    grid_rows, grid_cols = 6, 8    # layout of the stored alpha/beta table
    patch_size = 2 * 128
    num_patch = 2
    p_size = patch_size // num_patch
    expand = PSF_grid.shape[2] // 2
    # Same value the averaged kernel's half-size would give: the mean over
    # the two grid axes keeps the kernel's spatial extent.
    k_size = expand
    block_expand = expand

    # Positional alpha/beta parameters; the file does not change between
    # images, so it is read once.
    ab_numpy = np.loadtxt('ab_ZEMAX_finetune.txt').astype(np.float32).reshape(
        grid_rows, grid_cols, 2 * n_stage + 1, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    for img_id in range(8, 10):
        img_path = '/home/xiu/workspace/dwgan/new_image/image/{}_new.jpg'.format(
            img_id)
        img_H = cv2.imread(img_path)
        if img_H is None:
            # cv2.imread signals failure by returning None.
            raise FileNotFoundError('cannot read image: {}'.format(img_path))
        img_H = img_H.astype(np.float32)
        img_H = np.pad(img_H, ((k_size, k_size), (k_size, k_size), (0, 0)))
        img_L = util_deblur.uniformConv2d(img_H, PSF_grid[1, 2])
        img_L = img_L.astype(np.float32)

        # One accumulation buffer per stage for each of the two interleaved
        # outputs returned by the network.
        img_E_deconv = [np.zeros_like(img_L) for _ in range(n_stage)]
        img_E_denoise = [np.zeros_like(img_L) for _ in range(n_stage)]
        weight_E = np.zeros_like(img_L)

        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            for px_start in range(0, grid_rows - num_patch + 1, num_patch):
                for py_start in range(0, grid_cols - num_patch + 1, num_patch):
                    rows = slice(px_start * p_size,
                                 (px_start + num_patch) * p_size)
                    cols = slice(py_start * p_size,
                                 py_start * p_size + num_patch * p_size)
                    patch_L = img_L[rows, cols, :]

                    if block_expand > 0:
                        patch_L_wrap = util_deblur.wrap_boundary_liu(
                            patch_L, (patch_size + block_expand * 2,
                                      patch_size + block_expand * 2))
                        # Centralise: move the trailing block_expand
                        # columns/rows of the wrapped patch to the front.
                        patch_L_wrap = np.hstack(
                            (patch_L_wrap[:, -block_expand:, :],
                             patch_L_wrap[:, :patch_size + block_expand, :]))
                        patch_L_wrap = np.vstack(
                            (patch_L_wrap[-block_expand:, :, :],
                             patch_L_wrap[:patch_size + block_expand, :, :]))
                    else:
                        patch_L_wrap = patch_L

                    k = util.single2tensor4(PSF_grid[1, 2]).to(device)
                    x = util.single2tensor4(
                        util.uint2single(patch_L_wrap)).to(device)

                    cd = F.softplus(ab[px_start:px_start + num_patch,
                                       py_start:py_start + num_patch])
                    cd = cd.reshape(num_patch**2, 2 * n_stage + 1, 3)

                    x_E = model.forward_globaldeconv(x,
                                                     k,
                                                     cd,
                                                     patch_sz=patch_size)
                    x_E = x_E[:-1]
                    patch_E_all = [util.tensor2uint(pp) for pp in x_E]

                    # Even entries go to the deconvolution buffers, odd
                    # entries to the denoising buffers.
                    for i in range(n_stage):
                        img_E_deconv[i][rows, cols, :] += patch_E_all[2 * i]
                        img_E_denoise[i][rows, cols, :] += patch_E_all[2 * i + 1]
                    weight_E[rows, cols, :] += 1.0

        # Pixels that no tile covered have zero weight; clamp to one so they
        # stay 0 instead of turning into NaN (0 / 0) before the uint8 cast.
        coverage = np.maximum(weight_E, 1.0)
        img_E_deconv = [pp / coverage for pp in img_E_deconv]
        img_E_denoise = [pp / coverage for pp in img_E_denoise]

        # Debug leftover kept for identical console output: prints the index
        # of the last stage processed in the accumulation loop.
        print(i)

        xk = img_E_deconv[-3]
        xk = xk.astype(np.uint8)
        cv2.imwrite(
            '/home/xiu/workspace/dwgan/new_image/image/fakepatch1x2-{}.png'.
            format(img_id), xk)