def inpainting_lista(missing_pixels, dicsize, n_nonzero_coefs, blksize, overlap, maxiter, grayimg):
    """Inpaint an image with a LISTA network trained on the image's own patches.

    :param missing_pixels: forwarded to generate_inpainting_data.
    :param dicsize: number of dictionary atoms (n).
    :param n_nonzero_coefs: target number of non-zero coefficients; the ratio
        n_nonzero_coefs / dicsize is the sparsity ``s`` passed to train_lista.
    :param blksize: patch size, forwarded to data generation, dictionary
        initialisation and patch reassembly.
    :param overlap: patch overlap, forwarded likewise.
    :param maxiter: only printed; the LISTA pipeline does not use it.
    :param grayimg: forwarded to generate_inpainting_data.
        NOTE(review): the patches are used as a 2-D [m, K] matrix below, so
        presumably only grayscale input is supported here -- confirm.
    :return: (img_ori, masked_img, recovered_img). The same three arrays are
        also saved under ./save_lista/.
    """
    img_ori, masked_img, patchvec, maskvec = generate_inpainting_data(
        missing_pixels=missing_pixels, blksize=blksize, overlap=overlap,
        grayimg=grayimg)

    # settings
    m = patchvec.shape[0]
    n = dicsize
    s = n_nonzero_coefs / n  # fraction of non-zero coefficients, forwarded to train_lista
    print('[INFO] Dictionary size ({}, {}).'.format(m, n))
    print('[INFO] Sparsity {}.'.format(n_nonzero_coefs))
    print('[INFO] OMP max iter {}.'.format(maxiter))
    D = get_init_dict(dicsize=dicsize, blksize=blksize)
    Y = patchvec

    # training
    model = train_lista(D, Y, maskvec, s, img_ori, blksize, overlap)

    # testing: preprocess exactly as train_lista does (scale by 1/255, zero the
    # missing pixels, float32, batch first). Feeding float64 to a float32
    # network fails with a dtype mismatch.
    with torch.no_grad():
        Y_in = (Y / 255.0) * maskvec
        Y_pt = torch.from_numpy(Y_in.T).float().cuda()
        coef = model(Y_pt).detach().cpu().numpy().T

    recovered_img = patch2im_for_inpainting(
        patch_vecs=255.0 * np.matmul(D, coef), mask_vecs=maskvec,
        imgsize=img_ori.shape, blksize=blksize, overlap=overlap)

    # save results (makedirs with exist_ok avoids the exists/mkdir race)
    os.makedirs('./save_lista', exist_ok=True)
    np.save('./save_lista/img_ori', img_ori)
    np.save('./save_lista/masked_img', masked_img)
    np.save('./save_lista/recovered_img', recovered_img)
    return img_ori, masked_img, recovered_img
def inpainting_ksvd(missing_pixels, dicsize, n_nonzero_coefs, blksize, overlap, maxiter, grayimg=False,
                    load_results=True):
    """Inpaint an image with K-SVD dictionary learning and OMP sparse coding.

    If ``load_results`` is true (the default, matching the previous hard-coded
    behaviour), the dictionaries and coefficients saved by an earlier run are
    loaded from the working directory instead of being recomputed:
    './D.npy' and './coef.npy' for grayscale; './D{R,G,B}.npy' and
    './coef{R,G,B}.npy' for colour.

    :param missing_pixels: forwarded to generate_inpainting_data.
    :param dicsize: number of dictionary atoms.
    :param n_nonzero_coefs: sparsity level passed to inpainting_ksvd_once.
    :param blksize: patch size.
    :param overlap: patch overlap.
    :param maxiter: iteration limit passed to inpainting_ksvd_once.
    :param grayimg: if True, process one channel; otherwise process R, G and B
        independently and stack the results along axis 2.
    :param load_results: load previously saved results rather than run K-SVD.
    :return: (img_ori, masked_img, recovered_img)
    """
    img_ori, masked_img, patchvec, maskvec = generate_inpainting_data(
        missing_pixels=missing_pixels, blksize=blksize, overlap=overlap,
        grayimg=grayimg)

    # settings
    h, w = img_ori.shape[0], img_ori.shape[1]
    m = patchvec.shape[0]
    n = dicsize
    print('[INFO] Dictionary size ({}, {}).'.format(m, n))
    print('[INFO] Sparsity {}.'.format(n_nonzero_coefs))
    print('[INFO] OMP max iter {}.'.format(maxiter))

    rgb_names = ('R', 'G', 'B')

    def _reconstruct(dictionary, coef):
        # Rebuild one (h, w) plane from its sparse code. Colour images are
        # rebuilt channel by channel, so the 2-D size is always used.
        return patch2im_for_inpainting(
            patch_vecs=np.matmul(dictionary, coef), mask_vecs=maskvec,
            imgsize=(h, w), blksize=blksize, overlap=overlap)

    if load_results:
        print('[INFO] Load OMP results.')
        if grayimg:
            recovered_img = _reconstruct(np.load('./D.npy'), np.load('./coef.npy'))
        else:
            recovered_img = np.stack(
                [_reconstruct(np.load('./D{}.npy'.format(name)),
                              np.load('./coef{}.npy'.format(name)))
                 for name in rgb_names],
                axis=2)
        return img_ori, masked_img, recovered_img

    # OMP for inpainting
    print('[INFO] Run OMP for image inpainting')
    D = get_init_dict(dicsize=dicsize, blksize=blksize)
    if grayimg:
        newD, coef = inpainting_ksvd_once(patchvec, D, maskvec,
                                          n_nonzero_coefs=n_nonzero_coefs,
                                          maxiter=maxiter)
        print('[INFO] Saving dictionary ...')
        np.save('./D', newD)
        print('[INFO] Saving coefficients ...')
        np.save('./coef', coef)
        # recover data
        print('[INFO] Reconstruction ...')
        recovered_img = _reconstruct(newD, coef)
    else:
        # Collect float channels and stack them, as the load path does, rather
        # than writing floats into zeros_like(img_ori), which would truncate or
        # wrap values if the image has an integer dtype.
        channels = []
        for channel, name in enumerate(rgb_names):
            Y = patchvec[:, :, channel]
            newD, coef = inpainting_ksvd_once(Y, D, maskvec,
                                              n_nonzero_coefs=n_nonzero_coefs,
                                              maxiter=maxiter)
            print('[INFO] Saving dictionary ...')
            np.save('./D{}'.format(name), newD)
            print('[INFO] Saving coefficients ...')
            np.save('./coef{}'.format(name), coef)
            # recover data
            print('[INFO] Reconstruction ...')
            channels.append(_reconstruct(newD, coef))
        recovered_img = np.stack(channels, axis=2)
    return img_ori, masked_img, recovered_img
def train_lista(D, Y, maskvec, s, img_ori, blksize, overlap, epochs=50, bs=128, learning_rate=1e-2):
    """Train a LISTA network to sparse-code masked image patches.

    The loss is ``MSE(Y * mask, (X @ D.T) * mask) + (1 - s) * mean|X|``, where X
    is the network's code for a batch. After every epoch the whole image is
    reconstructed and compared with ``img_ori`` by PSNR. Training stops early
    once the PSNR exceeds 30.

    :param D: dictionary, numpy array of size [m, n], e.g. [64, 256].
    :param Y: patch matrix of size [m, K], e.g. [64, 10000]. It is divided by
        255 here, so presumably pixel values are in [0, 255] -- confirm.
    :param maskvec: numpy mask multiplied element-wise with Y; presumably 1 for
        observed and 0 for missing pixels.
    :param s: sparsity ratio. It sets the initial threshold ``theta = s / L``
        and the L1 weight ``1 - s``.
    :param img_ori: reference image, used only for per-epoch PSNR.
    :param blksize: patch size used to reassemble the validation image.
    :param overlap: patch overlap used to reassemble the validation image.
    :param epochs: maximum number of training epochs.
    :param bs: mini-batch size.
    :param learning_rate: SGD learning rate (momentum 0.9).
    :return: the trained LISTA network (on the GPU).
    """
    m, n = D.shape
    K = Y.shape[1]
    # Run at least one batch per epoch. Otherwise K < bs would skip training
    # and divide by zero when the epoch losses are averaged.
    iter_epoch = max(K // bs, 1)

    # D^T D is symmetric PSD, so eigvalsh gives real eigenvalues. np.linalg.eig
    # can return a complex array here (D^T D is rank-deficient when n > m),
    # which would make L and theta complex.
    L = float(np.linalg.eigvalsh(D.T.dot(D)).max())  # Lipschitz constant
    theta = s / L

    # Normalise and mask the patches, then convert to float32 CUDA tensors with
    # the batch dimension first, as PyTorch expects. The numpy `maskvec` is
    # kept under its own name because patch2im_for_inpainting needs it below.
    Y_pt = torch.from_numpy(((Y / 255.0) * maskvec).T).float().cuda()
    D_pt = torch.from_numpy(D).float().cuda()
    mask_pt = torch.from_numpy(maskvec.T).float().cuda()

    # network initialization
    ListaNet = LISTA(m, n, D_pt, L, theta, max_iter=30)
    ListaNet.weights_init()
    ListaNet = ListaNet.cuda()

    # build the optimizer and criterion
    criterion1 = nn.MSELoss()
    criterion2 = nn.L1Loss()
    optimizer = torch.optim.SGD(ListaNet.parameters(), lr=learning_rate, momentum=0.9)

    psnr_list = np.zeros(epochs)
    for i_epoch in range(epochs):
        indexs = np.random.permutation(K)
        epoch_loss = 0.0
        epoch_loss1 = 0.0
        epoch_loss2 = 0.0
        for i_iter in range(iter_epoch):
            batch_idx = indexs[i_iter * bs:(i_iter + 1) * bs]
            Y_batch = Y_pt[batch_idx]
            mask = mask_pt[batch_idx]
            # get the outputs
            X = ListaNet(Y_batch)
            Y_recons = torch.mm(X, D_pt.T)
            # compute the loss. zeros_like keeps the L1 target the same shape
            # as X, even for a short final batch.
            loss1 = criterion1(Y_batch * mask, Y_recons * mask)
            loss2 = (1 - s) * criterion2(X, torch.zeros_like(X))
            loss = loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss1 += loss1.item()
            epoch_loss2 += loss2.item()
            epoch_loss += loss.item()

        # validation: reconstruct the full image from the current codes.
        # D_pt is re-read every epoch in case the network updates it in place.
        with torch.no_grad():
            coef = ListaNet(Y_pt).detach().cpu().numpy().T
            recovered_img = patch2im_for_inpainting(
                patch_vecs=255.0 * np.matmul(D_pt.detach().cpu().numpy(), coef),
                mask_vecs=maskvec, imgsize=img_ori.shape,
                blksize=blksize, overlap=overlap)
        psnr = PSNR(img_ori, recovered_img)
        psnr_list[i_epoch] = psnr
        print('[INFO] Epoch: {}/{} loss: {:.3f} ({:f}, {:f}) PSNR: {:.2f}'.format(
            i_epoch, epochs, epoch_loss / iter_epoch,
            epoch_loss1 / iter_epoch, epoch_loss2 / iter_epoch, psnr))
        if psnr > 30:
            print('[INFO] => ListaNet best PSNR:{:.2f}'.format(psnr_list.max()))
            return ListaNet
    print('[INFO] => ListaNet best PSNR:{:.2f}'.format(psnr_list.max()))
    return ListaNet