def main():
    """Jointly train the deconvolution network and one shared alpha/beta tensor.

    Every training crop is blurred with the single PSF ``PSF_grid[1, 2]`` from
    ``./data/ZEMAX-AC254-075-A-new.npz``. The network weights and the
    softplus-parameterised ``ab_param`` are optimised together with an L1 loss
    on output ``xE[-2]`` of ``model.forward_globaldeconv``. Every 10 iterations
    the loss and (sharp | blurred | restored) image triplets are written to
    TensorBoard under ``root``; a checkpoint and the ``ab`` values are saved at
    the end of every epoch.

    Raises:
        RuntimeError: if fewer than ``n_batch`` training images are found
            (otherwise no iteration runs and ``ab`` is undefined at save time).
        OSError: if ``cv2.imread`` cannot decode a training image (it returns
            None rather than raising).
    """
    root = '/home/xiu/databag/deblur/pretrain/patchwise'
    logger = SummaryWriter(root)
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    # Normalise every PSF of the (grid_x, grid_y, ...) grid to unit mass over
    # its two spatial axes (axes 2 and 3 of the grid).
    PSF_grid = PSF_grid / np.sum(PSF_grid, axis=(2, 3), keepdims=True)
    # ----------------------------------------
    # build
    # ----------------------------------------
    model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                nb=3, act_mode="R", downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # NOTE(review): '**.png' is not a standalone '**' path component, so it
    # behaves like '*.png' and recursive=True has no effect here: only files
    # exactly one directory below images/ are matched.
    train_imgs = sorted(glob.glob('/home/xiu/databag/deblur/images/*/**.png',
                                  recursive=True))

    # global model. 3x4
    n_epoch = 200
    n_stage = 5
    n_batch = 3
    w_patch = 128
    n_patch = 2  # each image contributes n_patch * n_patch random crops
    # also 3x4, but different strategy

    if len(train_imgs) < n_batch:
        raise RuntimeError('need at least {} training images, found {}'.format(
            n_batch, len(train_imgs)))

    # NOTE(review): the model is built with n_iter=8 while ab carries
    # 2 * n_stage + 1 = 11 entries along axis 2 — confirm this is the length
    # forward_globaldeconv expects.
    ab_buffer = np.ones((1, 1, 2 * n_stage + 1, 3), dtype=np.float32) * 0.1
    ab_param = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # ab_param is optimised alongside every network weight.
    params = [{"params": [ab_param], "lr": 1e-4}]
    params += [{"params": [value], "lr": 1e-4}
               for _, value in model.named_parameters()]
    optimizer = torch.optim.Adam(params, lr=1e-4)

    img_index = np.arange(len(train_imgs))
    global_iter = 0

    # Crop margin on each side: the PSF size along axis 2, rounded down to an
    # even number, so that expand // 2 below is exact.
    expand = PSF_grid.shape[2] // 2 * 2

    # using this small PSF only.
    PSF = PSF_grid[1, 2]
    # The kernel never changes, so build its device tensor once.
    k = util.single2tensor4(PSF).to(device)

    for epoch in range(n_epoch):
        np.random.shuffle(img_index)
        for iteration in range(len(train_imgs) // n_batch):
            imgs = []
            for idx in img_index[iteration * n_batch:(iteration + 1) * n_batch]:
                img = cv2.imread(train_imgs[idx])
                if img is None:
                    raise OSError('cv2.imread could not read {}'.format(train_imgs[idx]))
                imgs.append(img)

            global_iter += 1
            log_now = global_iter % 10 == 0

            vis_H = []
            vis_L = []
            vis_E = []
            x = []
            y = []
            for img in imgs:
                w, h = img.shape[:2]
                for _ in range(n_patch * n_patch):
                    x_start = np.random.randint(0, w - w_patch - expand * 2 + 1)
                    y_start = np.random.randint(0, h - w_patch - expand * 2 + 1)
                    patch_H = img[x_start:x_start + w_patch + expand * 2,
                                  y_start:y_start + w_patch + expand * 2]
                    patch_L = util_deblur.uniformConv2d(patch_H, PSF)

                    vis_H.append(patch_H[expand:-expand, expand:-expand])
                    vis_L.append(patch_L[expand // 2:-expand // 2,
                                         expand // 2:-expand // 2])
                    y.append(util.single2tensor4(util.uint2single(patch_L)))
                    x.append(util.single2tensor4(
                        util.uint2single(patch_H[expand:-expand, expand:-expand])))

            ab = F.softplus(ab_param)
            loss = 0
            for yy, xx in zip(y, x):
                yy = yy.to(device)
                xx = xx.to(device)
                # xE = model.forward_patchdeconv(yy,kk,ab[0],[1,1],w_patch)
                xE = model.forward_globaldeconv(yy, k, ab[0], w_patch)
                loss += F.l1_loss(xE[-2], xx)
                if log_now:
                    # Only needed for logging; clone so that any in-place op
                    # inside util.tensor2uint cannot alter the graph tensor.
                    patch_E = xE[-2].clone().detach().cpu()
                    vis_E.append(util.tensor2uint(patch_E))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_now:
                loss_value = loss.item()
                print('iter {}: loss{}.'.format(global_iter, loss_value))
                logger.add_scalar('train-loss', loss_value, global_iter)
                # Shows the first n_batch crops; with n_patch=2 and n_batch=3
                # these are all crops of the first image in the batch.
                for i in range(n_batch):
                    show1 = np.hstack((vis_H[i], vis_L[i], vis_E[i]))
                    # cv2 images are BGR; reverse channels before logging.
                    logger.add_image('show-{}'.format(i),
                                     util.uint2tensor3(show1[:, :, ::-1]))
                logger.flush()

        # NOTE(review): only ab[:, :, 0, 0] is kept, which is a single value
        # for this (1, 1, 2 * n_stage + 1, 3) tensor — confirm that dropping
        # the rest of the learned schedule is intended.
        ab_numpy = ab.detach().cpu().numpy()[:, :, 0, 0].flatten()
        torch.save(model.state_dict(),
                   os.path.join(root, 'usrnet_patchwise1x2_epoch{}.pth'.format(epoch)))
        np.savetxt(os.path.join(root, 'patchwise_ours.txt'), ab_numpy)

    logger.close()
def main():
    """Train the network's ``forward_patchtranslate`` path on blurred crops.

    Each iteration reads ``n_batch`` images and takes one random crop per
    image (``n_patch = 1``). It blurs the crop with that batch slot's PSF and
    regresses the sharp centre with an L1 loss. ``ab`` is frozen at
    ``softplus(0.1)``: ``ab_param`` has ``requires_grad=False`` and is not
    handed to the optimizer. After every step the last sample's ground truth
    and prediction are shown in an OpenCV window. Weights and ``ab`` are
    written to the working directory at the end of every epoch.

    Raises:
        RuntimeError: if fewer than ``n_batch`` training images are found.
        OSError: if ``cv2.imread`` cannot decode a training image.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    all_PSFs = load_kernels('./data')
    # ----------------------------------------
    # build
    # ----------------------------------------
    model = net(n_iter=8, h_nc=64, in_nc=3, out_nc=3, nc=[64, 128, 256, 512],
                nb=3, act_mode="R", downsample_mode='strideconv',
                upsample_mode="convtranspose")
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # NOTE(review): '**.png' is not a standalone '**' path component, so
    # recursive=True has no effect; only files one directory below images/
    # are matched.
    train_imgs = sorted(glob.glob('/home/xiu/databag/deblur/images/*/**.png',
                                  recursive=True))

    n_stage = 5
    n_batch = 9
    n_epoch = 200

    w_patch = 128
    n_patch = 1
    crop = w_patch * n_patch  # side length of the sharp target crop

    if len(train_imgs) < n_batch:
        raise RuntimeError('need at least {} training images, found {}'.format(
            n_batch, len(train_imgs)))

    ab_buffer = np.ones((n_batch, n_patch * n_patch, 2 * n_stage + 1, 3),
                        dtype=np.float32) * 0.1
    ab_param = torch.tensor(ab_buffer, device=device, requires_grad=False)
    # ab_param is frozen and never optimised, so its softplus is constant.
    ab = F.softplus(ab_param)

    # Only the network weights are optimised (ab_param deliberately excluded).
    params = [{"params": [value], "lr": 1e-4}
              for _, value in model.named_parameters()]
    optimizer = torch.optim.Adam(params, lr=1e-4)

    def pick_kernels():
        """Return one PSF per batch slot and each PSF's crop margin."""
        # Every slot currently uses the same kernel; a variant drawing
        # rand_kernels(n_patch) for odd slots was sketched but is disabled.
        psfs = [all_PSFs[0][0:1, 0:1] for _ in range(n_batch)]
        # Margin = half of PSF axis 2 (presumably the kernel height).
        return psfs, [p.shape[2] // 2 for p in psfs]

    img_index = np.arange(len(train_imgs))
    global_iter = 0
    PSFs, expands = pick_kernels()

    for epoch in range(n_epoch):
        np.random.shuffle(img_index)
        for iteration in range(len(train_imgs) // n_batch):
            imgs = []
            for idx in img_index[iteration * n_batch:(iteration + 1) * n_batch]:
                img = cv2.imread(train_imgs[idx])
                if img is None:
                    raise OSError('cv2.imread could not read {}'.format(train_imgs[idx]))
                imgs.append(img)

            global_iter += 1
            if global_iter % 100 == 0:
                # get new kernel (identical while pick_kernels is deterministic).
                PSFs, expands = pick_kernels()

            x = []
            y = []
            for img, expand, PSF in zip(imgs, expands, PSFs):
                w, h = img.shape[:2]
                x_start = np.random.randint(0, w - crop - expand * 2 + 1)
                y_start = np.random.randint(0, h - crop - expand * 2 + 1)
                patch_H = img[x_start:x_start + crop + expand * 2,
                              y_start:y_start + crop + expand * 2]
                patch_L = util_deblur.uniformConv2d(patch_H, PSF)

                y.append(util.single2tensor4(util.uint2single(patch_L)))
                x.append(util.single2tensor4(
                    util.uint2single(patch_H[expand:-expand, expand:-expand])))

            loss = 0
            xE = None
            for i, (yy, xx) in enumerate(zip(y, x)):
                xE = model.forward_patchtranslate(yy.to(device), ab[i])
                loss += F.l1_loss(xE, xx.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('loss {}'.format(loss.item()))

            # Only the last sample is displayed, so only it is converted; use
            # a detached clone so util.tensor2uint cannot touch graph tensors.
            gt = util.tensor2uint(x[-1])
            pred = util.tensor2uint(xE.detach().clone())
            cv2.imshow('res', np.hstack((gt, pred)))
            cv2.waitKey(1)

        ab_numpy = ab.detach().cpu().numpy()[:, :, 0, 0]
        torch.save(model.state_dict(), 'usrnet_bench.pth')
        np.savetxt('ab_bench.txt', ab_numpy)
def main():
    """Deblur test images tile by tile with a pretrained patch-wise model.

    For each ``img_id`` in 8..9 the sharp image is read and padded by the PSF
    radius. It is blurred with ``PSF_grid[1, 2]`` and then restored in
    non-overlapping ``patch_size`` x ``patch_size`` tiles. The tiles are laid
    over the 6 x 8 alpha/beta grid loaded from ``ab_ZEMAX_finetune.txt``, with
    ``num_patch`` x ``num_patch`` grid cells per tile.

    Each tile is wrap-padded by the PSF radius before
    ``model.forward_globaldeconv`` is called. The per-stage outputs are
    stitched into full-size deconvolution / denoising images, and stage
    ``img_E_deconv[-3]`` is written as ``fakepatch1x2-<id>.png``.

    Raises:
        OSError: if an input image cannot be decoded by ``cv2.imread``.
        ValueError: if the blurred image is smaller than the tiled area.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    # Other PSF sets tried: ./data/Schuler_PSF01.npz, Schuler_PSF_facade.npz,
    # Schuler_PSF03.npz, PSF.npz
    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    # Normalise every PSF of the grid to unit mass over its two spatial axes.
    PSF_grid = PSF_grid / np.sum(PSF_grid, axis=(2, 3), keepdims=True)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    n_iter = 8  # HQS stages; also fixes the 2 * n_iter + 1 alpha/beta length
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=n_iter, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=3, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(
        torch.load('/home/xiu/databag/deblur/patchwise1x2.pth',
                   map_location=device),
        strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # PSF radius: used both to pad the sharp image and to wrap-pad each tile.
    expand = PSF_grid.shape[2] // 2
    k_size = expand

    patch_size = 2 * 128              # tile side in pixels
    num_patch = 2                     # alpha/beta grid cells per tile side
    p_size = patch_size // num_patch  # pixels per alpha/beta grid cell

    # positional alpha-beta parameters for HQS, one (2 * n_iter + 1, 3) set
    # per cell of a 6 x 8 grid. Loaded once: identical for every image.
    ab_numpy = np.loadtxt('ab_ZEMAX_finetune.txt').astype(np.float32).reshape(
        6, 8, 2 * n_iter + 1, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)
    grid_x, grid_y = ab_numpy.shape[:2]

    k = util.single2tensor4(PSF_grid[1, 2]).to(device)

    for img_id in range(8, 10):
        img_path = '/home/xiu/workspace/dwgan/new_image/image/{}_new.jpg'.format(img_id)
        img_H = cv2.imread(img_path)
        if img_H is None:
            raise OSError('cv2.imread could not read {}'.format(img_path))
        img_H = img_H.astype(np.float32)
        img_H = np.pad(img_H, ((k_size, k_size), (k_size, k_size), (0, 0)))
        img_L = util_deblur.uniformConv2d(img_H, PSF_grid[1, 2]).astype(np.float32)

        need_h, need_w = grid_x * p_size, grid_y * p_size
        if img_L.shape[0] < need_h or img_L.shape[1] < need_w:
            raise ValueError('blurred image {} is {}x{}, smaller than the {}x{} tiled area'.format(
                img_path, img_L.shape[0], img_L.shape[1], need_h, need_w))

        # Per-stage accumulators plus a coverage count for averaging.
        img_E_deconv = [np.zeros_like(img_L) for _ in range(n_iter)]
        img_E_denoise = [np.zeros_like(img_L) for _ in range(n_iter)]
        weight_E = np.zeros_like(img_L)

        # Non-overlapping tiles: step num_patch grid cells in each direction.
        for px_start in range(0, grid_x - num_patch + 1, num_patch):
            for py_start in range(0, grid_y - num_patch + 1, num_patch):
                rows = slice(px_start * p_size, (px_start + num_patch) * p_size)
                cols = slice(py_start * p_size, (py_start + num_patch) * p_size)
                patch_L = img_L[rows, cols, :]

                block_expand = expand
                if block_expand > 0:
                    padded = patch_size + block_expand * 2
                    patch_L_wrap = util_deblur.wrap_boundary_liu(patch_L, (padded, padded))
                    # centralize: move the last block_expand columns/rows to
                    # the front so the original tile sits in the middle.
                    patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                              patch_L_wrap[:, :patch_size + block_expand, :]))
                    patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                              patch_L_wrap[:patch_size + block_expand, :, :]))
                else:
                    patch_L_wrap = patch_L

                x = util.single2tensor4(util.uint2single(patch_L_wrap)).to(device)

                # alpha/beta for the num_patch x num_patch cells under this tile.
                cd = F.softplus(ab[px_start:px_start + num_patch,
                                   py_start:py_start + num_patch])
                cd = cd.view(num_patch ** 2, 2 * n_iter + 1, 3)

                x_E = model.forward_globaldeconv(x, k, cd, patch_sz=patch_size)
                x_E = x_E[:-1]

                # Entry 2*i is accumulated as the stage-i deconvolution
                # result, entry 2*i+1 as the stage-i denoising result.
                for i in range(n_iter):
                    img_E_deconv[i][rows, cols, :] += util.tensor2uint(x_E[2 * i])
                    img_E_denoise[i][rows, cols, :] += util.tensor2uint(x_E[2 * i + 1])
                weight_E[rows, cols, :] += 1.0

        # Average the accumulated tiles. Pixels no tile covered (weight 0)
        # become 0 instead of 0/0 = NaN, which has no defined uint8 value.
        covered = weight_E > 0
        img_E_deconv = [np.divide(pp, weight_E, out=np.zeros_like(pp), where=covered)
                        for pp in img_E_deconv]
        img_E_denoise = [np.divide(pp, weight_E, out=np.zeros_like(pp), where=covered)
                         for pp in img_E_denoise]

        # NOTE(review): stage index -3 (stage n_iter - 3) is the one saved —
        # confirm this is the intended stage.
        xk = img_E_deconv[-3].astype(np.uint8)
        out_path = '/home/xiu/workspace/dwgan/new_image/image/fakepatch1x2-{}.png'.format(img_id)
        cv2.imwrite(out_path, xk)
        print(out_path)