def main():
    """Deblur a test image with the plano-lens UABC model and show the result.

    The per-field PSF grid is normalised, the trained network and its
    positional HQS (alpha, beta) parameters are loaded once, and the whole
    PSF grid is then processed as a single patch covering the cropped,
    half-resolution input.  The padded input and the estimate are shown in
    OpenCV windows; press 'q' to stop.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    # PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = np.load('./data/Heide_PSF_plano_small.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            # Normalise each local PSF to unit sum over its spatial axes.
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    model_code = 'iter800'
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/plano/uabcnet_{}.pth'.format(
            model_code), map_location=device)
    model.load_state_dict(loaded_state, strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional alpha-beta parameters for HQS.  They do not depend on the
    # frame, so read them once instead of on every loop iteration.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/plano/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    # The whole PSF grid is processed as one patch anchored at cell (0, 0).
    num_patch = [gx, gy]
    px_start = 0
    py_start = 0

    for img_id in range(1, 237):
        # NOTE(review): this path does not depend on img_id, so every
        # iteration processes the same image.  A per-frame variant used
        # earlier was
        #   '/home/xiu/workspace/UABC/ICCV2021/video/{:08d}.bmp'.format(img_id)
        img_path = '/home/xiu/databag/deblur/ICCV2021/MPI_data/drain/blurry.jpg'
        img_L = cv2.imread(img_path)
        if img_L is None:
            # cv2.imread returns None instead of raising for unreadable files.
            raise FileNotFoundError(img_path)
        img_L = img_L.astype(np.float32)
        img_L = img_L[38:-39, 74:-74]
        img_L = cv2.resize(img_L, dsize=None, fx=0.5, fy=0.5)
        W, H = img_L.shape[:2]
        print(gx, gy)

        t0 = time.time()
        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H // gy,
                        :]
        p_W, p_H = patch_L.shape[:2]

        # Pad with a smooth wrap-around boundary, then roll the padding so
        # the original content sits in the centre ("centralize").
        block_expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.single2tensor4(util.uint2single(patch_L_wrap))

        # Kernels and HQS parameters share one ordering: the first grid
        # index (w_) varies fastest.
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)
        x, k = x.to(device), k.to(device)

        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        x_E = model.forward_patchwise(x, k, cd, num_patch, [W // gx, H // gy])
        # Crop the padding back off the estimate.
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]
        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)
        t1 = time.time()
        print(t1 - t0)

        xk = patch_E.astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', patch_L.astype(np.uint8))
        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
def main():
    """Deblur a 100-frame video tile-by-tile with the ZEMAX USRNet model.

    The 6 x 8 PSF grid is walked in non-overlapping 2 x 2 windows.  For each
    window the matching 256 x 256 image tile is padded and deblurred, and the
    model outputs are accumulated into full-frame buffers that are finally
    divided by how many tiles wrote each pixel.  The last denoising buffer
    is displayed next to the input.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/ZEMAX-AC254-075-A-new.npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            # Normalise each local PSF to unit sum over its spatial axes.
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    n_stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=n_stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    loaded_state = torch.load('./usrnet_ZEMAX.pth', map_location=device)
    model.load_state_dict(loaded_state, strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    patch_size = 2 * 128            # side length of one image tile (pixels)
    num_patch = 2                   # PSF cells per tile along each axis
    p_size = patch_size // num_patch
    expand = PSF_grid.shape[2]      # padding added around every tile

    # Positional HQS parameters indexed [x cell, y cell, 2 * stage, 3].
    # They are frame-independent, so load them once.
    ab_numpy = np.loadtxt('ab_ZEMAX.txt').astype(np.float32).reshape(
        6, 8, 16, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    for img_id in range(100):
        img_path = '/home/xiu/workspace/UABC/ICCV2021/video/{:08d}.bmp'.format(
            img_id)
        img_L = cv2.imread(img_path)
        if img_L is None:
            # cv2.imread returns None instead of raising for unreadable files.
            raise FileNotFoundError(img_path)
        img_L = img_L.astype(np.float32)
        weight_E = np.zeros_like(img_L)
        # One full-frame accumulator per stage for each output kind.
        img_E_deconv = [np.zeros_like(img_L) for _ in range(n_stage)]
        img_E_denoise = [np.zeros_like(img_L) for _ in range(n_stage)]

        t0 = time.time()
        for px_start in range(0, 6 - num_patch + 1, num_patch):
            for py_start in range(0, 8 - num_patch + 1, num_patch):
                PSF_patch = PSF_grid[px_start:px_start + num_patch,
                                     py_start:py_start + num_patch]
                rows = slice(px_start * p_size,
                             (px_start + num_patch) * p_size)
                cols = slice(py_start * p_size,
                             py_start * p_size + num_patch * p_size)
                patch_L = img_L[rows, cols, :]

                block_expand = expand
                if block_expand > 0:
                    patch_L_wrap = util_deblur.wrap_boundary_liu(
                        patch_L, (patch_size + block_expand * 2,
                                  patch_size + block_expand * 2))
                    # Roll the padding so the tile sits in the centre.
                    patch_L_wrap = np.hstack(
                        (patch_L_wrap[:, -block_expand:, :],
                         patch_L_wrap[:, :patch_size + block_expand, :]))
                    patch_L_wrap = np.vstack(
                        (patch_L_wrap[-block_expand:, :, :],
                         patch_L_wrap[:patch_size + block_expand, :, :]))
                else:
                    patch_L_wrap = patch_L
                x = util.single2tensor4(util.uint2single(patch_L_wrap))

                # Kernels ordered with the first (x) grid index fastest.
                k_all = []
                for h_ in range(num_patch):
                    for w_ in range(num_patch):
                        k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
                k = torch.cat(k_all, dim=0)
                x, k = x.to(device), k.to(device)

                # ab is indexed [x, y]; swap to [y, x] before flattening so
                # the HQS parameters follow the same x-fastest order as k.
                cd = F.softplus(ab[px_start:px_start + num_patch,
                                   py_start:py_start + num_patch])
                cd = cd.permute(1, 0, 2, 3).reshape(num_patch ** 2,
                                                    2 * n_stage, 3)

                x_E = model.forward_patchwise(
                    x, k, cd, [num_patch, num_patch],
                    [patch_size // num_patch, patch_size // num_patch])
                # NOTE(review): x_E is iterated and its last two entries are
                # used as deconvolution / denoising estimates; this assumes
                # forward_patchwise returns a sequence of intermediate outputs
                # in the model version this script targets -- confirm.
                patch_E_all = [util.tensor2uint(pp) for pp in x_E]
                deconv = patch_E_all[-2][expand:-expand, expand:-expand]
                denoise = patch_E_all[-1][expand:-expand, expand:-expand]
                # NOTE(review): every stage slot receives the same (last)
                # estimate, as in the original script.
                for i in range(n_stage):
                    img_E_deconv[i][rows, cols, :] += deconv
                    img_E_denoise[i][rows, cols, :] += denoise
                # Count each tile once so the division below is a mean.
                weight_E[rows, cols, :] += 1.0
        t1 = time.time()
        print(t1 - t0)

        img_E_deconv = [pp / weight_E for pp in img_E_deconv]
        img_E_denoise = [pp / weight_E for pp in img_E_denoise]
        xk = img_E_denoise[-1].astype(np.uint8)
        cv2.imshow('xx', xk)
        cv2.imshow('img_L', img_L.astype(np.uint8))
        cv2.waitKey(-1)
def main():
    """Deblur the captured AC254 test images with the ZEMAX-trained UABC model.

    Every ``suo_image/*/AC254-075-A-ML-Zemax(ZMX).bmp`` is processed as one
    patch spanning the whole PSF grid.  The result is shown next to the
    padded input; press 'q' to stop.
    """
    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]
    for xx in range(gx):
        for yy in range(gy):
            # Normalise each local PSF to unit sum over its spatial axes.
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))

    # ----------------------------------------
    # load model
    # ----------------------------------------
    stage = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    model_code = 'iter17000'
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    loaded_state = torch.load(
        '/home/xiu/databag/deblur/models/ZEMAX/uabcnet_{}.pth'.format(
            model_code), map_location=device)
    model.load_state_dict(loaded_state, strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional alpha-beta parameters for HQS; image-independent, so load
    # them once.
    ab_numpy = np.loadtxt(
        '/home/xiu/databag/deblur/models/ZEMAX/ab_{}.txt'.format(
            model_code)).astype(np.float32).reshape(gx, gy, stage * 2, 3)
    ab = torch.tensor(ab_numpy, device=device, requires_grad=False)

    # One patch covering the whole PSF grid.  It must match the (gx, gy)
    # layout that ab was reshaped to above.
    num_patch = [gx, gy]
    px_start = 0
    py_start = 0

    img_names = glob.glob(
        '/home/xiu/databag/deblur/ICCV2021/suo_image/*/AC254-075-A-ML-Zemax(ZMX).bmp'
    )
    img_names.sort()

    for img_id, img_name in enumerate(img_names):
        img_L = cv2.imread(img_name)
        if img_L is None:
            # cv2.imread returns None instead of raising for unreadable files.
            raise FileNotFoundError(img_name)
        img_L = img_L.astype(np.float32)
        W, H = img_L.shape[:2]

        t0 = time.time()
        PSF_patch = PSF_grid[px_start:px_start + num_patch[0],
                             py_start:py_start + num_patch[1]]
        patch_L = img_L[px_start * W // gx:(px_start + num_patch[0]) * W // gx,
                        py_start * H // gy:(py_start + num_patch[1]) * H // gy,
                        :]
        p_W, p_H = patch_L.shape[:2]

        # Pad with a smooth wrap-around boundary, then roll so the content
        # sits in the centre.
        block_expand = max(PSF_grid.shape[2] // 2, p_W // 16)
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (p_W + block_expand * 2, p_H + block_expand * 2))
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :p_H + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:p_W + block_expand, :, :]))
        x = util.single2tensor4(util.uint2single(patch_L_wrap))

        # Kernels and HQS parameters share the w_-fastest ordering.
        k_all = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                k_all.append(util.single2tensor4(PSF_patch[w_, h_]))
        k = torch.cat(k_all, dim=0)
        x, k = x.to(device), k.to(device)

        ab_patch = F.softplus(ab[px_start:px_start + num_patch[0],
                                 py_start:py_start + num_patch[1]])
        cd = []
        for h_ in range(num_patch[1]):
            for w_ in range(num_patch[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        x_E = model.forward_patchwise(x, k, cd, num_patch, [W // gx, H // gy])
        # Crop the padding back off the estimate.
        x_E = x_E[..., block_expand:block_expand + p_W,
                  block_expand:block_expand + p_H]
        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)
        t1 = time.time()
        # 1-based progress counter so the last frame reads N/N.
        print('[{}/{}]: {} s per frame'.format(img_id + 1, len(img_names),
                                               t1 - t0))

        xk = patch_E.astype(np.uint8)
        cv2.imshow('res', xk)
        cv2.imshow('input', patch_L.astype(np.uint8))
        key = cv2.waitKey(-1)
        if key == ord('q'):
            break
def main():
    """Train the ZEMAX deblurring model on synthetically blurred DIV2K crops.

    Each iteration does the following:

    * picks a random image and a random ``patch_num`` window of the PSF grid,
      biased towards the grid borders;
    * blurs the crop block-wise with the PSFs of that window;
    * trains the network and the positional HQS parameters ``ab`` with an L1
      loss, where both prediction and target are divided by the per-kernel
      ``inv_weight`` map.

    Checkpoints are written to ./ZEMAX_model every 100 iterations and again
    on exit (press 'q' in the preview window).
    """
    import os

    # ----------------------------------------
    # load kernels
    # ----------------------------------------
    PSF_grid = np.load('./data/AC254-075-A-ML-Zemax(ZMX).npz')['PSF']
    PSF_grid = PSF_grid.astype(np.float32)
    gx, gy = PSF_grid.shape[:2]
    # Flattened kernel stack: the PSF of grid cell (xx, yy) is stored at
    # index xx + yy * gx (xx varies fastest).
    k_tensor = []
    for yy in range(gy):
        for xx in range(gx):
            PSF_grid[xx, yy] = PSF_grid[xx, yy] / np.sum(PSF_grid[xx, yy],
                                                         axis=(0, 1))
            k_tensor.append(util.single2tensor4(PSF_grid[xx, yy]))
    k_tensor = torch.cat(k_tensor, dim=0)
    inv_weight = util_deblur.get_inv_spatial_weight(k_tensor)

    # ----------------------------------------
    # load model
    # ----------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    # The model is still on the CPU here, so load the weights there too.
    model.proj.load_state_dict(
        torch.load('./data/usrnet_pretrain.pth', map_location='cpu'),
        strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # ----------------------------------------
    # load training data
    # ----------------------------------------
    imgs = glob.glob('./DIV2K_train/*.png', recursive=True)
    imgs.sort()

    # ----------------------------------------
    # positional lambda/mu for HQS (trainable)
    # ----------------------------------------
    stage = 8
    ab_buffer = np.ones((gx, gy, 2 * stage, 3), dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    # ----------------------------------------
    # build optimizer
    # ----------------------------------------
    params = [{"params": [ab], "lr": 0.0005}]
    for _, value in model.named_parameters():
        params += [{"params": [value], "lr": 0.0001}]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    patch_size = [128, 128]
    expand = PSF_grid.shape[2] // 2
    patch_num = [2, 2]
    # Size of the sharp region covered by the patch_num sub-patches.
    crop_w = patch_size[0] * patch_num[0]
    crop_h = patch_size[1] * patch_num[1]

    # torch.save / np.savetxt fail if the checkpoint folder does not exist.
    os.makedirs('./ZEMAX_model', exist_ok=True)

    global_iter = 0
    while True:
        img = imgs[np.random.randint(len(imgs))]
        img_H = cv2.imread(img)
        w, h = img_H.shape[:2]

        # Focus on the edges: modes 0-3 pin the PSF window to one grid border,
        # mode 4 keeps the fully random position.
        mode = np.random.randint(5)
        px_start = np.random.randint(0, gx - patch_num[0] + 1)
        py_start = np.random.randint(0, gy - patch_num[1] + 1)
        if mode == 0:
            px_start = 0
        if mode == 1:
            px_start = gx - patch_num[0]
        if mode == 2:
            py_start = 0
        if mode == 3:
            py_start = gy - patch_num[1]

        x_start = np.random.randint(0, w - crop_w - expand * 2 + 1)
        y_start = np.random.randint(0, h - crop_h - expand * 2 + 1)
        PSF_patch = PSF_grid[px_start:px_start + patch_num[0],
                             py_start:py_start + patch_num[1]]

        patch_H = img_H[x_start:x_start + crop_w + expand * 2,
                        y_start:y_start + crop_h + expand * 2]
        patch_L = util_deblur.blockConv2d(patch_H, PSF_patch, expand)

        # Pad with a smooth wrap-around boundary, then roll so the content
        # sits in the centre.
        block_expand = max(patch_size[0] // 8, expand)
        patch_L_wrap = util_deblur.wrap_boundary_liu(
            patch_L, (crop_w + block_expand * 2, crop_h + block_expand * 2))
        patch_L_wrap = np.hstack((patch_L_wrap[:, -block_expand:, :],
                                  patch_L_wrap[:, :crop_h + block_expand, :]))
        patch_L_wrap = np.vstack((patch_L_wrap[-block_expand:, :, :],
                                  patch_L_wrap[:crop_w + block_expand, :, :]))
        x = util.single2tensor4(util.uint2single(patch_L_wrap))
        x_gt = util.single2tensor4(
            util.uint2single(patch_H[expand:-expand, expand:-expand]))

        inv_weight_patch = torch.ones_like(x_gt)
        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                # Global index of the PSF blockConv2d used for this sub-patch:
                # PSF_patch[w_, h_] is PSF_grid[px_start + w_, py_start + h_].
                idx = (px_start + w_) + (py_start + h_) * gx
                rows = slice(w_ * patch_size[0], (w_ + 1) * patch_size[0])
                cols = slice(h_ * patch_size[1], (h_ + 1) * patch_size[1])
                for c in range(3):
                    inv_weight_patch[0, c, rows, cols] = inv_weight[idx, c]
                k_local.append(k_tensor[idx:idx + 1])
        k = torch.cat(k_local, dim=0)
        x, x_gt, k, inv_weight_patch = [
            el.to(device) for el in (x, x_gt, k, inv_weight_patch)
        ]

        ab_patch = F.softplus(ab[px_start:px_start + patch_num[0],
                                 py_start:py_start + patch_num[1]])
        cd = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                cd.append(ab_patch[w_:w_ + 1, h_])
        cd = torch.cat(cd, dim=0)

        x_E = model.forward_patchwise(x, k, cd, patch_num, patch_size)
        predict = x_E[..., block_expand:block_expand + crop_w,
                      block_expand:block_expand + crop_h]

        loss = F.l1_loss(predict.div(inv_weight_patch),
                         x_gt.div(inv_weight_patch))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print('iter:{},loss {}'.format(global_iter + 1, loss.item()))

        patch_L = patch_L_wrap.astype(np.uint8)
        patch_E = util.tensor2uint(x_E)[block_expand:-block_expand,
                                        block_expand:-block_expand]
        show = np.hstack((patch_H[expand:-expand, expand:-expand],
                          patch_L[block_expand:-block_expand,
                                  block_expand:-block_expand], patch_E))
        cv2.imshow('HL', show)
        key = cv2.waitKey(1)
        global_iter += 1

        # change the save period here
        if global_iter % 100 == 0:
            ab_numpy = ab.detach().cpu().numpy().flatten()
            torch.save(
                model.state_dict(),
                './ZEMAX_model/usrnet_ZEMAX_iter{}.pth'.format(global_iter))
            np.savetxt('./ZEMAX_model/ab_ZEMAX_iter{}.txt'.format(global_iter),
                       ab_numpy)

        if key == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './ZEMAX_model/usrnet_ZEMAX.pth')
    np.savetxt('./ZEMAX_model/ab_ZEMAX.txt', ab_numpy)
def main():
    """Train the UABC super-resolution/deblurring network from scratch.

    Every iteration draws a random DIV2K image and a random PSF grid, builds
    a degraded/sharp pair with ``draw_training_pair``, and updates the
    network together with the positional HQS parameters ``ab`` under an L1
    loss.  Press 'q' in the preview window to stop or 's' to checkpoint.
    The network weights and ``ab`` are always saved to ./logs on exit.
    """
    import os

    # 0. global config
    sf = 4  # scale factor
    stage = 8
    patch_size = [32, 32]
    patch_num = [3, 3]

    # 1. local PSFs, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, sf=sf, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # Positional lambda, mu for HQS, trained as free parameters.
    ab_buffer = np.ones((patch_num[0], patch_num[1], 2 * stage, 3),
                        dtype=np.float32) * 0.1
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    params = [{"params": [ab], "lr": 0.0005}]
    for _, value in model.named_parameters():
        params += [{"params": [value], "lr": 0.0001}]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    # 3. load training data
    imgs_H = glob.glob('/home/xiu/databag/deblur/images/DIV2K_train/*.png',
                       recursive=True)
    imgs_H.sort()

    # torch.save / np.savetxt fail if the folder does not exist.
    os.makedirs('./logs', exist_ok=True)

    def save_checkpoint():
        """Write the network weights and the learned HQS parameters to ./logs."""
        torch.save(model.state_dict(), './logs/uabcnet.pth')
        # ab is trained alongside the network; store the raw values the loop
        # feeds through softplus.
        np.savetxt('./logs/ab.txt', ab.detach().cpu().numpy().flatten())

    N_maxiter = 200000
    for global_iter in range(N_maxiter):
        t0 = time.time()
        # draw a random image and a random kernel grid
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])
        PSF_grid = draw_random_kernel(all_PSFs, patch_num)
        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # Kernels and HQS parameters share the w_-fastest ordering.
        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                k_local.append(util.single2tensor4(patch_psf[w_, h_]))
        k = torch.cat(k_local, dim=0)
        x, x_gt, k = [el.to(device) for el in (x, x_gt, k)]

        ab_patch = F.softplus(ab)
        ab_patch_v = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                ab_patch_v.append(ab_patch[w_:w_ + 1, h_])
        ab_patch_v = torch.cat(ab_patch_v, dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        loss = F.l1_loss(x_E, x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        t_iter = time.time() - t0 - t_data
        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(global_iter + 1, loss.item(), t_data, t_iter))

        # Nearest-neighbour upsampling so L lines up with H and E in the preview.
        patch_L = cv2.resize(patch_L, dsize=None, fx=sf, fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        key = cv2.waitKey(1)

        if key == ord('q'):
            break
        if key == ord('s'):
            save_checkpoint()

    save_checkpoint()
def main():
    """Fine-tune the pre-trained UABC network with an L1 + LSGAN objective.

    The generator (``net``) starts from ./data/uabcnet_final.pth.  A patch
    discriminator is trained alternately against it, using an image pool of
    past generator outputs.  All training pairs are degraded with one PSF
    grid drawn once before the loop.  Press 'q' to stop.  The fine-tuned
    weights and the positional HQS parameters are saved to ./data at the end.
    """
    # 0. global config
    sf = 4
    stage = 8
    patch_size = [32, 32]
    patch_num = [2, 2]
    n_w, n_h = patch_num

    # 1. local PSFs
    all_PSFs = load_kernels('./data')

    # 2. generator
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, sf=sf, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    model.load_state_dict(torch.load('./data/uabcnet_final.pth'), strict=True)
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    model = model.to(device)

    # 3. discriminator, GAN criterion and history buffer of generated images
    model_D = gan.PatchDiscriminator(5).to(device)
    gan_loss = gan.GANLoss(mode='lsgan').to(device)
    fake_images = ImagePool(16)

    # Positional lambda, mu for HQS: even / odd slots along the 2*stage axis
    # start at 0.01 / 0.1 respectively.
    ab_init = np.zeros((n_w, n_h, 2 * stage, 3))
    ab_init[..., ::2, :] = 0.01
    ab_init[..., 1::2, :] = 0.1
    ab = torch.tensor(ab_init, dtype=torch.float32, device=device,
                      requires_grad=True)

    gen_groups = [{"params": [ab], "lr": 5e-4}]
    gen_groups += [{"params": [weight], "lr": 1e-5}
                   for _, weight in model.named_parameters()]
    optimizer = torch.optim.Adam(gen_groups, lr=1e-4, betas=(0.9, 0.999))
    optimizer_D = torch.optim.Adam(list(model_D.parameters()), lr=1e-4,
                                   betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    # 4. training data
    imgs_H = glob.glob('/home/xiu/databag/deblur/images/DIV2K_train/*.png',
                       recursive=True)
    imgs_H.sort()

    n_iters = 200000
    PSF_grid = draw_random_kernel(all_PSFs)
    for step in range(n_iters):
        t_start = time.time()
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])
        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t_start

        x = util.single2tensor4(util.uint2single(patch_L)).to(device)
        x_gt = util.single2tensor4(util.uint2single(patch_H)).to(device)
        # Kernels and HQS parameters: first grid index varies fastest.
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(n_h) for w_ in range(n_w)],
                      dim=0).to(device)
        ab_pos = F.softplus(ab)
        ab_patch_v = torch.cat([ab_pos[w_:w_ + 1, h_]
                                for h_ in range(n_h) for w_ in range(n_w)],
                               dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        # generator update: reconstruction + adversarial term
        loss = F.l1_loss(x_E, x_gt) + gan_loss(model_D(x_E), True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # discriminator update on real images vs. pooled fakes
        loss_D_real = gan_loss(model_D(x_gt), True)
        pooled = fake_images.query(x_E)
        loss_D_fake = gan_loss(model_D(pooled.detach()), False)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        scheduler.step()

        t_iter = time.time() - t_start - t_data
        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(step + 1, loss.item(), t_data, t_iter))

        patch_L = cv2.resize(patch_L, dsize=None, fx=sf, fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        cv2.imshow('H,L,E', np.hstack((patch_H, patch_L, patch_E)))
        if cv2.waitKey(1) == ord('q'):
            break

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
def main():
    """Fine-tune the pre-trained UABC model and its HQS parameters on one PSF grid.

    Training starts from ./logs/uabcnet_final.pth and from the single
    pre-trained (lambda, mu) set in ./logs/ab_pretrain.txt, replicated to
    every grid cell.  It runs for ``N_maxiter`` iterations on DIV2K crops
    degraded with one randomly drawn PSF grid.  Sample H/L/E images are
    written to ./result/test every 1000 iterations.  The fine-tuned weights
    and ``ab`` are saved to ./data at the end.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5
    patch_size = [32, 32]
    patch_num = [2, 2]

    # 1. local PSFs, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, sf=sf, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    model.load_state_dict(
        torch.load('./logs/uabcnet_final.pth', map_location=device),
        strict=True)
    model.train()
    for _, v in model.named_parameters():
        v.requires_grad = True
    model = model.to(device)

    # Positional lambda, mu for HQS, trained as free parameters.  Every grid
    # cell starts from the same pre-trained (1, 1) set.
    ab_pretrain = np.loadtxt('./logs/ab_pretrain.txt').reshape(
        (1, 1, 2 * stage, 3)).astype(np.float32)
    ab_buffer = np.tile(ab_pretrain, (patch_num[0], patch_num[1], 1, 1))
    ab = torch.tensor(ab_buffer, device=device, requires_grad=True)

    params = [{"params": [ab], "lr": 0.0001}]
    for _, value in model.named_parameters():
        params += [{"params": [value], "lr": 1e-6}]
    optimizer = torch.optim.Adam(params, lr=0.0001, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1000,
                                                gamma=0.9)

    # 3. load training data
    imgs_H = glob.glob('./DIV2K_train/*.png', recursive=True)
    imgs_H.sort()

    # cv2.imwrite typically returns False instead of raising when the target
    # folder is missing, so create it up front.
    out_dir = os.path.join('./result', 'test')
    os.makedirs(out_dir, exist_ok=True)

    N_maxiter = 4000
    PSF_grid = draw_random_kernel(all_PSFs, patch_num)
    for i in range(N_maxiter):
        t0 = time.time()
        # draw a random image
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])
        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)
        t_data = time.time() - t0

        x = util.single2tensor4(util.uint2single(patch_L))
        x_gt = util.single2tensor4(util.uint2single(patch_H))

        # Kernels and HQS parameters share the w_-fastest ordering.
        k_local = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                k_local.append(util.single2tensor4(patch_psf[w_, h_]))
        k = torch.cat(k_local, dim=0)
        x, x_gt, k = [el.to(device) for el in (x, x_gt, k)]

        ab_patch = F.softplus(ab)
        ab_patch_v = []
        for h_ in range(patch_num[1]):
            for w_ in range(patch_num[0]):
                ab_patch_v.append(ab_patch[w_:w_ + 1, h_])
        ab_patch_v = torch.cat(ab_patch_v, dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        loss = F.l1_loss(x_E, x_gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        t_iter = time.time() - t0 - t_data
        print('[iter:{}] loss:{:.4f}, data_time:{:.2f}s, net_time:{:.2f}s'.
              format(i + 1, loss.item(), t_data, t_iter))

        patch_L = cv2.resize(patch_L, dsize=None, fx=sf, fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        show = np.hstack((patch_H, patch_L, patch_E))
        cv2.imshow('H,L,E', show)
        cv2.waitKey(1)

        if i % 1000 == 0:
            cv2.imwrite(os.path.join(out_dir,
                                     'resultE-{:04d}.png'.format(i + 1)),
                        patch_E)
            cv2.imwrite(os.path.join(out_dir,
                                     'resultL-{:04d}.png'.format(i + 1)),
                        patch_L)
            cv2.imwrite(os.path.join(out_dir,
                                     'resultH-{:04d}.png'.format(i + 1)),
                        patch_H)

    ab_numpy = ab.detach().cpu().numpy().flatten()
    torch.save(model.state_dict(), './data/uabcnet_finetune.pth')
    np.savetxt('./data/ab_finetune.txt', ab_numpy)
def main():
    """Evaluate the fine-tuned UABC model on full DIV2K images with the AC254 lens.

    Each randomly chosen image is zero-padded by 12 px and split into
    overlapping crops with ``imgpatch``.  Every crop is degraded and restored
    patch-wise, the crops are merged back together, and PSNR/SSIM against the
    padded ground truth are printed.  Restored and reference images go to
    ./result/finetune, and all PSNR values to psnr.txt in the same folder.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5
    patch_size = [32, 32]
    patch_num = [2, 2]

    # 1. local PSFs, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=stage, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, sf=sf, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    # loaded_state_dict = torch.load('./data/uabcnet_final.pth')
    loaded_state_dict = torch.load('./data/uabcnet_finetune.pth',
                                   map_location=device)
    model.load_state_dict(loaded_state_dict, strict=True)
    model.eval()
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)

    # Positional lambda, mu for HQS.  The training scripts save ab *before*
    # softplus and apply softplus exactly once when using it, so it must be
    # applied exactly once here too.  It was previously applied both on
    # load and per patch.
    ab_buffer = np.loadtxt('./data/ab_finetune.txt').reshape(
        (patch_num[0], patch_num[1], 2 * stage, 3)).astype(np.float32)
    ab = torch.tensor(ab_buffer, device=device, requires_grad=False)
    ab_patch = F.softplus(ab)
    # Constant for the whole run, so flatten it once (w_ varies fastest, the
    # same order as the kernels below).
    ab_patch_v = torch.cat([ab_patch[w_:w_ + 1, h_]
                            for h_ in range(patch_num[1])
                            for w_ in range(patch_num[0])], dim=0)

    # 3. load test data
    imgs_H = glob.glob('./DIV2K_train/*.png', recursive=True)
    imgs_H.sort()

    N_maxiter = 1000
    PSF_grid = using_AC254_lens(all_PSFs, patch_num)
    all_PSNR = []
    out_folder = 'finetune'
    out_dir = os.path.join('./result', out_folder)
    # cv2.imwrite typically returns False instead of raising on a missing
    # folder, so create it up front.
    os.makedirs(out_dir, exist_ok=True)

    for i in range(N_maxiter):
        # draw a random image
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])
        img_H = np.pad(img_H, [(12, 12), (12, 12), (0, 0)])
        croppatch = imgpatch(img_H, 280, 280, 24)
        patches = croppatch.crop(img_H, 1)

        patch_E_list = []
        for piece, patch in enumerate(patches):
            patch_L, _, patch_psf = draw_testing_pair(
                patch, PSF_grid, sf, patch_num, patch_size)
            x = util.single2tensor4(util.uint2single(patch_L)).to(device)

            k_local = []
            for h_ in range(patch_num[1]):
                for w_ in range(patch_num[0]):
                    k_local.append(util.single2tensor4(patch_psf[w_, h_]))
            k = torch.cat(k_local, dim=0).to(device)

            x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                             [patch_size[0], patch_size[1]],
                                             sf)
            patch_E = util.tensor2uint(x_E)
            patch_E_list.append(patch_E[np.newaxis, :])
            print(piece)

        img_E = croppatch.merge(patch_E_list)
        psnr, ssim = cal_psnrssim(img_E, img_H, 255)
        print(psnr)
        print(ssim)
        cv2.imwrite(
            os.path.join(out_dir, 'resultE-{:04d}.png'.format(i + 1)), img_E)
        cv2.imwrite(
            os.path.join(out_dir, 'resultH-{:04d}.png'.format(i + 1)), img_H)
        all_PSNR.append(psnr)

    np.savetxt(os.path.join(out_dir, 'psnr.txt'), all_PSNR)
def main():
    """Score the pre-trained UABC model on random DIV2K crops with the AC254 lens.

    Each iteration degrades a random crop with the AC254 PSF grid, restores
    it, and records ``cv2.PSNR`` against the sharp crop.  An H | L | E strip
    is written to ./result/finetune every 250 iterations and once more after
    the loop.  All PSNR values are saved to psnr.txt in the same folder.
    """
    # 0. global config
    sf = 4  # scale factor
    stage = 5
    patch_size = [32, 32]
    patch_num = [2, 2]
    n_w, n_h = patch_num

    # 1. local PSFs, shape: gx, gy, kw, kw, 3
    all_PSFs = load_kernels('./data')

    # 2. local model (frozen)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = net(n_iter=5, h_nc=64, in_nc=4, out_nc=3,
                nc=[64, 128, 256, 512], nb=2, sf=sf, act_mode="R",
                downsample_mode='strideconv', upsample_mode="convtranspose")
    # alternative: torch.load('./data/uabcnet_finetune.pth')
    weights = torch.load('./logs/uabcnet_final.pth')
    model.load_state_dict(weights, strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    # Positional lambda, mu for HQS: one pre-trained (raw) set copied into
    # every grid cell; softplus is applied inside the loop.
    ab_pretrain = np.loadtxt('./logs/ab_pretrain.txt').reshape(
        (1, 1, 2 * stage, 3)).astype(np.float32)
    ab = torch.tensor(np.tile(ab_pretrain, (n_w, n_h, 1, 1)),
                      device=device, requires_grad=False)

    # 3. load data
    imgs_H = glob.glob('./DIV2K_train/*.png', recursive=True)
    imgs_H.sort()

    n_iters = 1000
    PSF_grid = using_AC254_lens(all_PSFs, patch_num)
    out_dir = os.path.join('./result', 'finetune')
    scores = []
    for i in range(n_iters):
        # draw a random image
        img_H = cv2.imread(imgs_H[np.random.randint(len(imgs_H))])
        patch_L, patch_H, patch_psf = draw_training_pair(
            img_H, PSF_grid, sf, patch_num, patch_size)

        x = util.single2tensor4(util.uint2single(patch_L)).to(device)
        x_gt = util.single2tensor4(util.uint2single(patch_H)).to(device)
        # Kernels and HQS parameters: first grid index varies fastest.
        k = torch.cat([util.single2tensor4(patch_psf[w_, h_])
                       for h_ in range(n_h) for w_ in range(n_w)],
                      dim=0).to(device)
        ab_pos = F.softplus(ab)
        ab_patch_v = torch.cat([ab_pos[w_:w_ + 1, h_]
                                for h_ in range(n_h) for w_ in range(n_w)],
                               dim=0)

        x_E = model.forward_patchwise_SR(x, k, ab_patch_v, patch_num,
                                         [patch_size[0], patch_size[1]], sf)

        patch_L = cv2.resize(patch_L, dsize=None, fx=sf, fy=sf,
                             interpolation=cv2.INTER_NEAREST)
        patch_E = util.tensor2uint(x_E)
        scores.append(cv2.PSNR(patch_E, patch_H))

        show = np.hstack((patch_H, patch_L, patch_E))
        if i % 250 == 0:
            cv2.imwrite(os.path.join(out_dir,
                                     'result-{:04d}.png'.format(i + 1)),
                        show)

    # Also keep the strip from the final iteration.
    cv2.imwrite(os.path.join(out_dir, 'result-{:04d}.png'.format(i + 1)),
                show)
    np.savetxt(os.path.join(out_dir, 'psnr.txt'), scores)