Example #1
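A minimal sketch of the imports all three examples rely on; the module paths
for net, util, and util_deblur are assumptions inferred from how they are used
(a USRNet-style deblurring codebase), as are the helpers load_kernels,
get_kernels, and rand_kernels:

import glob
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from models.network_usrnet import USRNet as net   # assumed module path
from utils import utils_image as util             # assumed module path
from utils import utils_deblur as util_deblur     # assumed module path
# load_kernels, get_kernels and rand_kernels are assumed project helpers.

Example #1 pretrains the model from scratch on synthetic blur, alternating
between measured kernels (get_kernels) and random ones (rand_kernels) for the
batch elements, and re-sampling both every 100 iterations.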
def main():
    logger = SummaryWriter('/home/xiu/databag/deblur/pretrain/full/')
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_PSFs = load_kernels('./data')
    # ----------------------------------------
    # build
    # ----------------------------------------
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=3,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    train_imgs = glob.glob('/home/xiu/databag/deblur/images/*/**.png',
                           recursive=True)
    train_imgs.sort()

    n_stage = 5
    n_batch = 3
    n_epoch = 200

    w_patch = 128
    n_patch = 2

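    # Learnable hyper-parameter buffer: presumably one entry per unfolding
    # stage (2 * n_stage + 1) and per colour channel for every patch;
    # softplus is applied later to keep the entries positive.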
    ab_buffer = np.ones((n_batch, n_patch * n_patch, 2 * n_stage + 1, 3),
                        dtype=np.float32) * 0.1
    ab_param = torch.tensor(ab_buffer, device=device, requires_grad=True)

    params = []
    params += [{"params": [ab_param], "lr": 1e-4}]
    for _, value in model.named_parameters():
        params += [{"params": [value], "lr": 1e-4}]
    optimizer = torch.optim.Adam(params, lr=1e-4)

    img_index = np.arange(len(train_imgs))

    global_iter = 0

    PSFs = []
    for i in range(n_batch):
        if i % 2 == 0:
            PSFs.append(get_kernels(all_PSFs, n_patch))
        else:
            PSFs.append(rand_kernels(n_patch))
    expands = []
    for i in range(n_batch):
        expands.append(PSFs[i].shape[2] // 2)

    for epoch in range(n_epoch):
        np.random.shuffle(img_index)
        for iteration in range(len(train_imgs) // n_batch):
            imgs = []
            for ii in range(n_batch):
                imgs.append(
                    cv2.imread(train_imgs[img_index[iteration * n_batch +
                                                    ii]]))

            global_iter += 1
            if global_iter % 100 == 0:
                PSFs = []
                for i in range(n_batch):
                    if i % 2 == 0:
                        PSFs.append(get_kernels(all_PSFs, n_patch))
                    else:
                        PSFs.append(rand_kernels(n_patch))
                expands = []
                for i in range(n_batch):
                    expands.append(PSFs[i].shape[2] // 2)

            x = []
            y = []
            k = []
            vis_L = []
            vis_H = []
            vis_E = []
            for img, expand, PSF in zip(imgs, expands, PSFs):
                w, h = img.shape[:2]  # note: w, h are rows and columns here
                x_start = np.random.randint(
                    0, w - w_patch * n_patch - expand * 2 + 1)
                y_start = np.random.randint(
                    0, h - w_patch * n_patch - expand * 2 + 1)
                patch_H = img[x_start:x_start + w_patch * n_patch + expand * 2,
                              y_start:y_start + w_patch * n_patch + expand * 2]
                patch_L = util_deblur.blockConv2d(patch_H, PSF, expand)

                vis_L.append(patch_L)
                vis_H.append(patch_H[expand:-expand, expand:-expand])

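                # Wrap the patch boundaries (Liu-style padding), then roll the
                # wrapped rows/columns to the front so the padding surrounds
                # the patch, approximating circular boundary conditions for
                # the deconvolution.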
                patch_L_wrap = util_deblur.wrap_boundary_liu(
                    patch_L, (w_patch * n_patch + expand * 2,
                              w_patch * n_patch + expand * 2))
                patch_L_wrap = np.hstack(
                    (patch_L_wrap[:, -expand:, :],
                     patch_L_wrap[:, :w_patch * n_patch + expand, :]))
                patch_L_wrap = np.vstack(
                    (patch_L_wrap[-expand:, :, :],
                     patch_L_wrap[:w_patch * n_patch + expand, :, :]))
                x_L = util.uint2single(patch_L_wrap)

                x_blocky = []
                for h_ in range(n_patch):
                    for w_ in range(n_patch):
                        x_blocky.append(
                            x_L[w_ * w_patch:w_ * w_patch + w_patch + expand * 2,
                                h_ * w_patch:h_ * w_patch + w_patch + expand * 2])
                x_blocky = [util.single2tensor4(el) for el in x_blocky]
                x_blocky = torch.cat(x_blocky, dim=0)

                k_all = []
                for w_ in range(n_patch):
                    for h_ in range(n_patch):
                        k_all.append(util.single2tensor4(PSF[h_, w_]))

                k_all = torch.cat(k_all, dim=0)

                x_gt = util.uint2single(patch_H[expand:-expand,
                                                expand:-expand])
                x_gt = util.single2tensor4(x_gt)
                y.append(x_blocky)
                x.append(x_gt)
                k.append(k_all)

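            # softplus keeps the learned hyper-parameters strictly positive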
            ab = F.softplus(ab_param)
            loss = 0
            for i in range(n_batch):
                yy = y[i].to(device)
                kk = k[i].to(device)
                xx = x[i].to(device)
                xE = model.forward_patchdeconv(yy, kk, ab[i],
                                               [n_patch, n_patch], w_patch)
                loss += F.l1_loss(xE[-2], xx)
                vis_E.append(util.tensor2uint(xE[-2]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if global_iter % 10 == 0:
                print('iter {}: loss {}.'.format(global_iter, loss.item()))
                logger.add_scalar('train-loss', loss.item(), global_iter)
                for i in range(n_batch):
                    show1 = np.hstack((vis_H[i], vis_L[i], vis_E[i]))
                    # pass global_iter so successive images don't all land on step 0
                    logger.add_image('show-{}'.format(i),
                                     util.uint2tensor3(show1[:, :, ::-1]),
                                     global_iter)
                    logger.flush()
        ab_numpy = ab.detach().cpu().numpy()[:, :, 0, 0]
        ab_numpy = ab_numpy.flatten()
        torch.save(model.state_dict(), 'usrnet_ours_epoch{}.pth'.format(epoch))
        np.savetxt('ab_ours.txt', ab_numpy)
Example #2
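Example #2 fine-tunes a checkpoint saved by Example #1
(usrnet_ours_epoch10.pth) on a measured PSF grid loaded from the ZEMAX .npz
file, with each kernel normalized to unit sum. Imports are the same as in
Example #1.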
def main():
    logger = SummaryWriter('/home/xiu/databag/deblur/pretrain/finetune/')
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    kw, kh = PSF_grid.shape[:2]
    for w_ in range(kw):
        for h_ in range(kh):
            PSF_grid[w_, h_] = PSF_grid[w_, h_] / np.sum(PSF_grid[w_, h_],
                                                         axis=(0, 1))
    # ----------------------------------------
    # build
    # ----------------------------------------
    model = net(n_iter=8,
                h_nc=64,
                in_nc=4,
                out_nc=3,
                nc=[64, 128, 256, 512],
                nb=3,
                act_mode="R",
                downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.load_state_dict(
        torch.load('/home/xiu/databag/deblur/usrnet_ours_epoch10.pth'),
        strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    train_imgs = glob.glob('/home/xiu/databag/deblur/images/*/**.png',
                           recursive=True)
    train_imgs.sort()

    n_stage = 5
    n_batch = 3
    n_epoch = 200

    w_patch = 128
    n_patch = 2

    ab_buffer = np.ones(
        (n_batch, PSF_grid.shape[0] * PSF_grid.shape[1], 2 * n_stage + 1, 3),
        dtype=np.float32) * 0.1
    ab_param = torch.tensor(ab_buffer, device=device, requires_grad=True)

    params = []
    params += [{"params": [ab_param], "lr": 1e-4}]
    for _, value in model.named_parameters():
        params += [{"params": [value], "lr": 1e-4}]
    optimizer = torch.optim.Adam(params, lr=1e-4)

    img_index = np.arange(len(train_imgs))

    global_iter = 0

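    # pad patches by half the PSF support on each side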
    expand = PSF_grid.shape[2] // 2

    for epoch in range(n_epoch):
        np.random.shuffle(img_index)
        for iteration in range(len(train_imgs) // n_batch):
            imgs = []
            for ii in range(n_batch):
                imgs.append(
                    cv2.imread(train_imgs[img_index[iteration * n_batch +
                                                    ii]]))
            global_iter += 1

            x = []
            y = []
            k = []
            vis_L = []
            vis_H = []
            vis_E = []
            for img in imgs:
                w, h = img.shape[:2]
                x_start = np.random.randint(
                    0, w - w_patch * n_patch - expand * 2 + 1)
                y_start = np.random.randint(
                    0, h - w_patch * n_patch - expand * 2 + 1)
                px = np.random.randint(0, PSF_grid.shape[0] - n_patch)
                py = np.random.randint(0, PSF_grid.shape[1] - n_patch)

                patch_H = img[x_start:x_start + w_patch * n_patch + expand * 2,
                              y_start:y_start + w_patch * n_patch + expand * 2]
                patch_L = util_deblur.blockConv2d(
                    patch_H, PSF_grid[px:px + n_patch, py:py + n_patch],
                    expand)

                vis_L.append(patch_L)
                vis_H.append(patch_H[expand:-expand, expand:-expand])

                patch_L_wrap = util_deblur.wrap_boundary_liu(
                    patch_L, (w_patch * n_patch + expand * 2,
                              w_patch * n_patch + expand * 2))
                patch_L_wrap = np.hstack(
                    (patch_L_wrap[:, -expand:, :],
                     patch_L_wrap[:, :w_patch * n_patch + expand, :]))
                patch_L_wrap = np.vstack(
                    (patch_L_wrap[-expand:, :, :],
                     patch_L_wrap[:w_patch * n_patch + expand, :, :]))
                x_L = util.uint2single(patch_L_wrap)

                x_blocky = []
                for h_ in range(n_patch):
                    for w_ in range(n_patch):
                        x_blocky.append(
                            x_L[w_ * w_patch:w_ * w_patch + w_patch + expand * 2,
                                h_ * w_patch:h_ * w_patch + w_patch + expand * 2])
                x_blocky = [util.single2tensor4(el) for el in x_blocky]
                x_blocky = torch.cat(x_blocky, dim=0)

                k_all = []
                for w_ in range(n_patch):
                    for h_ in range(n_patch):
                        k_all.append(
                            util.single2tensor4(PSF_grid[h_ + px, w_ + py]))

                k_all = torch.cat(k_all, dim=0)

                x_gt = util.uint2single(patch_H[expand:-expand,
                                                expand:-expand])
                x_gt = util.single2tensor4(x_gt)
                y.append(x_blocky)
                x.append(x_gt)
                k.append(k_all)

            ab = F.softplus(ab_param)
            loss = 0
            for i in range(n_batch):
                yy = y[i].to(device)
                kk = k[i].to(device)
                xx = x[i].to(device)
                xE = model.forward_patchdeconv(yy, kk, ab[i],
                                               [n_patch, n_patch], w_patch)
                loss += F.l1_loss(xE[-2], xx)
                vis_E.append(util.tensor2uint(xE[-2]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if global_iter % 10 == 0:
                print('iter {}: loss {}.'.format(global_iter, loss.item()))
                logger.add_scalar('finetune-loss', loss.item(), global_iter)
                for i in range(n_batch):
                    show1 = np.hstack((vis_H[i], vis_L[i], vis_E[i]))
                    # pass global_iter so successive images don't all land on step 0
                    logger.add_image('show-{}'.format(i),
                                     util.uint2tensor3(show1[:, :, ::-1]),
                                     global_iter)
                    logger.flush()
Example #3
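Example #3 trains a patchwise variant that blurs every patch with one fixed
kernel from the grid (PSF_grid[1, 2]) and learns a single shared
hyper-parameter set. Imports are the same as in Example #1 (os is used for the
output paths).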
def main():
	root = '/home/xiu/databag/deblur/pretrain/patchwise'
	logger = SummaryWriter(root)
	# ----------------------------------------
	# load kernels
	# ----------------------------------------
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

	PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
	PSF_grid = PSF_grid.astype(np.float32)
	kw,kh = PSF_grid.shape[:2]
	for w_ in range(kw):
		for h_ in range(kh):
			PSF_grid[w_,h_] = PSF_grid[w_,h_]/np.sum(PSF_grid[w_,h_],axis=(0,1))
	# ----------------------------------------
	# build
	# ----------------------------------------
	model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
	            nb=3, act_mode="R", downsample_mode='strideconv',
	            upsample_mode="convtranspose")
	model.train()
	for _, v in model.named_parameters():
		v.requires_grad = True
	model = model.to(device)

	train_imgs = glob.glob('/home/xiu/databag/deblur/images/*/**.png',recursive=True)
	train_imgs.sort()

	# global model (3x4 grid)
	n_epoch = 200
	n_stage = 5
	n_batch = 3
	w_patch = 128
	n_patch = 2
	# also 3x4, but with a different strategy
	
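	# a single shared hyper-parameter set, unlike the per-batch, per-patch
	# buffers of Examples #1 and #2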
	ab_buffer = np.ones((1,1,2*n_stage+1,3),dtype=np.float32)*0.1
	ab_param = torch.tensor(ab_buffer,device=device,requires_grad=True)

	params = []
	params += [{"params":[ab_param],"lr":1e-4}]
	for _, value in model.named_parameters():
		params += [{"params":[value],"lr":1e-4}]
	optimizer = torch.optim.Adam(params,lr=1e-4)

	img_index = np.arange(len(train_imgs))

	global_iter = 0
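	# note: this margin is the full kernel width (//2*2 rounds down to even),
	# not the half-support used in Example #2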
	expand = PSF_grid.shape[2]//2*2
	
	#using this small PSF only.
	PSF = PSF_grid[1,2]
	
	for epoch in range(n_epoch):
		np.random.shuffle(img_index)
		for iteration in range(len(train_imgs)//n_batch):
			imgs = []
			for ii in range(n_batch):
				imgs.append(cv2.imread(train_imgs[img_index[iteration*n_batch+ii]]))
			global_iter += 1

			vis_H = []
			vis_L = []
			vis_E = []

			x = []
			y = []
			for img in imgs:
				w,h = img.shape[:2]
				for _ in range(n_patch*n_patch):
					x_start = np.random.randint(0,w-w_patch-expand*2+1)
					y_start = np.random.randint(0,h-w_patch-expand*2+1)
					patch_H = img[x_start:x_start+w_patch+expand*2,y_start:y_start+w_patch+expand*2]
					patch_L = util_deblur.uniformConv2d(patch_H,PSF)

					vis_H.append(patch_H[expand:-expand,expand:-expand])
					vis_L.append(patch_L[expand//2:-expand//2,expand//2:-expand//2])
					x_L = util.uint2single(patch_L)
					x_L = util.single2tensor4(x_L)
					x_gt = util.uint2single(patch_H[expand:-expand,expand:-expand])
					x_gt = util.single2tensor4(x_gt)
					y.append(x_L)
					x.append(x_gt)
			
			ab = F.softplus(ab_param)
			loss = 0
			k = util.single2tensor4(PSF)
			k = k.to(device)
			for i in range(n_batch*n_patch*n_patch):
				yy = y[i].to(device)
				xx = x[i].to(device)
				#xE = model.forward_patchdeconv(yy,kk,ab[0],[1,1],w_patch)
				xE = model.forward_globaldeconv(yy,k,ab[0],w_patch)
				loss += F.l1_loss(xE[-2],xx)
				patch_E = xE[-2].clone().detach().cpu()
				vis_E.append(util.tensor2uint(patch_E))
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			if global_iter % 10 == 0:
				print('iter {}: loss {}.'.format(global_iter, loss.item()))
				logger.add_scalar('train-loss', loss.item(), global_iter)
				for i in range(n_batch):
					show1 = np.hstack((vis_H[i], vis_L[i], vis_E[i]))
					# pass global_iter so successive images don't all land on step 0
					logger.add_image('show-{}'.format(i), util.uint2tensor3(show1[:, :, ::-1]), global_iter)
					logger.flush()
		ab_numpy = ab.detach().cpu().numpy()[:,:,0,0]
		ab_numpy = ab_numpy.flatten()
		torch.save(model.state_dict(),os.path.join(root,'usrnet_patchwise1x2_epoch{}.pth'.format(epoch)))
		np.savetxt(os.path.join(root,'patchwise_ours.txt'),ab_numpy)