def test_FISTA_Denoising(self):
    """FISTA denoising of Poisson-noisy SHAPES data with Tikhonov regularisation.

    Adapted from the demo FISTA_Tikhonov_Poisson_Denoising.py in the
    CIL-Demos repository. The smooth term handed to FISTA as ``f`` is
    ``alpha * ||Gradient(x)||_2^2`` and the proximable term ``g`` is the
    Kullback-Leibler fidelity to the noisy data.
    """
    print ("FISTA Denoising Poisson Noise Tikhonov")
    loader = TestData()
    data = loader.load(TestData.SHAPES)
    ig = data.geometry

    # Create noisy data with Poisson noise: the noise is sampled on
    # data/scale and the result is rescaled by scale.
    scale = 5
    n1 = TestData.random_noise(data.as_array() / scale, mode='poisson', seed=10) * scale
    noisy_data = ImageData(n1)

    # Regularisation parameter
    alpha = 10

    # Setup and run the FISTA algorithm
    operator = Gradient(ig)
    fid = KullbackLeibler(b=noisy_data)
    reg = FunctionOperatorComposition(alpha * L2NormSquared(), operator)

    x_init = ig.allocate()
    fista = FISTA(x_init=x_init, f=reg, g=fid)
    fista.max_iteration = 3000
    fista.update_objective_interval = 500
    fista.run(verbose=True)

    # NOTE: despite the name, this is the L2 norm of the error divided by the
    # number of pixels, not a true root-mean-square error. The threshold
    # below appears calibrated against this quantity, so it is kept as is.
    rmse = (fista.get_output() - data).norm() / data.as_array().size
    print ("RMSE", rmse)
    self.assertLess(rmse, 4.2e-4)
# Using the test data b, different reconstruction methods can now be set up as
# demonstrated in the rest of this file. In general all methods need an initial
# guess and some algorithm options to be set:
x_init = ig.allocate(0.0)
opt = {'tol': 1e-4, 'iter': 200}

# Create a least squares object instance with projector and test data,
# f(x) = ||Aop x - b||^2 (no 0.5 coefficient is applied here).
# Note it is least squares over all channels:
f = FunctionOperatorComposition(L2NormSquared(b=b), Aop)

# Run FISTA for least squares without regularization.
# max_iteration is only an upper cap; run() performs opt['iter'] iterations.
FISTA_alg = FISTA()
FISTA_alg.set_up(x_init=x_init, f=f, g=ZeroFunction())
FISTA_alg.max_iteration = 2000
FISTA_alg.run(opt['iter'])
x_FISTA = FISTA_alg.get_output()

# Display reconstruction and criterion: one panel per channel, iterating over
# exactly as many axes as were created (previously hard-coded to 3).
ff0, axarrf0 = plt.subplots(1, numchannels)
for k in range(numchannels):
    axarrf0[k].imshow(x_FISTA.as_array()[k], vmin=0, vmax=2.5)
plt.show()

plt.figure()
plt.semilogy(FISTA_alg.objective)
plt.title('Criterion vs iterations, least squares')
plt.show()

# FISTA can also solve regularised forms by specifying a second function object
# such as 1-norm regularisation with choice of regularisation parameter lam.
# Again the regulariser is over all channels:
# Allocate space for the channel-wise reconstruction
fista_sol_TV_channel_wise = A3D_chan.volume_geometry.allocate()

# Iterations per channel. Previously max_iteration was 100 while run(400) was
# requested; FISTA stops at max_iteration, so only 100 iterations were ever
# performed. A single value keeps the cap and the request in sync.
num_iter_per_channel = 100

for i in range(ag.channels):
    # Setup L2NormSquared fidelity term, for each channel
    f = FunctionOperatorComposition(
        0.5 * L2NormSquared(b=data.subset(channel=i)), A3D)
    # Run FISTA
    fista = FISTA(x_init=x_init, f=f, g=g)
    fista.max_iteration = num_iter_per_channel
    fista.update_objective_interval = 50
    fista.run(num_iter_per_channel, verbose=True, callback=show_data_3D)
    # Store this channel's result into slice i of the 4D container.
    np.copyto(fista_sol_TV_channel_wise.array[i], fista.get_output().array)

#%% show reconstruction
show_4D_channel_slice(fista_sol_TV_channel_wise, 5, 'FISTA TV channel-wise reconstruction')
show_4D_channel_slice(fista_sol_TV_channel_wise, 10, 'FISTA TV channel-wise reconstruction')
show_4D_channel_slice(fista_sol_TV_channel_wise, 15, 'FISTA TV channel-wise reconstruction')

#%% Coupling Total variation reconstruction in 4D volume. For this case there is no GPU implementation
# But we can use another algorithm called PDHG ( primal - dual hybrid gradient)
# Set up operators: Projection and Gradient
op1 = A3D_chan
# PDHG step sizes: sigma fixed to 1, tau chosen so that sigma * tau * normK**2 == 1.
sigma = 1
tau = 1 / (sigma * normK**2)

pdhg = PDHG(f=f, g=g, operator=operator, tau=tau, sigma=sigma, memopt=True)
pdhg.max_iteration = 2000
pdhg.update_objective_interval = 200
pdhg.run(2000, verbose=False)

###############################################################################
# Show results: PDHG on top, FISTA below.
plt.figure(figsize=(10, 10))
_comparison_rows = ((pdhg, 'PDHG reconstruction'),
                    (fista, 'FISTA reconstruction'))
for _row, (_solver, _heading) in enumerate(_comparison_rows, start=1):
    plt.subplot(2, 1, _row)
    plt.imshow(_solver.get_output().as_array())
    plt.title(_heading)
plt.show()

# Absolute pixel-wise difference between the two solutions.
diff1 = pdhg.get_output() - fista.get_output()
plt.imshow(diff1.abs().as_array())
plt.title('Diff PDHG vs FISTA')
plt.colorbar()
plt.show()
# Figure 1: ground truth, measured data and its back-projection side by side.
fig, ax = plt.subplots(1, 3)
img1 = ax[0].imshow(data.as_array())
ax[0].set_title('Ground Truth')
# `colorbar` is a helper defined/imported elsewhere — presumably it attaches a
# colorbar next to the given image's axes; confirm.
colorbar(img1)
img2 = ax[1].imshow(noisy_data.as_array())
# NOTE(review): noisy_data is titled 'Projection Data' — confirm it holds the
# (noisy) projection data rather than a noisy image.
ax[1].set_title('Projection Data')
colorbar(img2)
img3 = ax[2].imshow(back_proj.as_array())
ax[2].set_title('BackProjection')
colorbar(img3)
plt.tight_layout(h_pad=1.5)

# Figure 2: outputs of the three FISTA runs, labelled per the subplot titles
# (unconstrained LS, LS constrained to [0, 1], L2-regularised LS).
fig1, ax1 = plt.subplots(1, 3)
img4 = ax1[0].imshow(fista.get_output().as_array())
ax1[0].set_title('LS unconstrained')
colorbar(img4)
img5 = ax1[1].imshow(fista0.get_output().as_array())
ax1[1].set_title('LS constrained [0,1]')
colorbar(img5)
img6 = ax1[2].imshow(fista1.get_output().as_array())
ax1[2].set_title('L2-Regularised LS')
colorbar(img6)
plt.tight_layout(h_pad=1.5)

#%% Check with CVX solution
# Imports for the following comparison section.
import astra
import numpy
from ccpi.optimisation.operators import SparseFiniteDiff
fista.max_iteration = 3000
fista.update_objective_interval = 500
fista.run(3000, verbose=True)

# Show results: ground truth, noisy input and reconstruction stacked vertically.
plt.figure(figsize=(15, 15))
plt.subplot(3, 1, 1)
plt.imshow(data.as_array())
plt.title('Ground Truth')
plt.colorbar()
plt.subplot(3, 1, 2)
plt.imshow(noisy_data.as_array())
plt.title('Noisy Data')
plt.colorbar()
plt.subplot(3, 1, 3)
plt.imshow(fista.get_output().as_array())
plt.title('Reconstruction')
plt.colorbar()
plt.show()

# Middle line profiles: the middle row (axis 0 index ig.shape[0] // 2) has
# ig.shape[1] samples, so the x axis spans [0, ig.shape[1]]. The previous
# version used ig.shape[0] as the x extent and N / 2 as the row, which is
# off-centre / mis-scaled for non-square images.
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]), data.as_array()[ig.shape[0] // 2, :], label='GTruth')
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]), fista.get_output().as_array()[ig.shape[0] // 2, :], label='Reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()

#%% Check with CVX solution
max_iteration=10000, update_objective_interval=100)
# NOTE(review): the statement closed above begins before this chunk;
# presumably it constructs the FISTA solver `fi` with a 10000-iteration cap
# and an objective evaluation every 100 iterations.
fi.run(verbose=True)

## Show FISTA reconstruction results
# Four panels: ground truth, noisy input, FISTA output and middle-row profiles.
plt.figure(figsize=(20, 5))
plt.subplot(1, 4, 1)
plt.imshow(data.as_array())
plt.title('Ground Truth')
plt.colorbar()
plt.subplot(1, 4, 2)
plt.imshow(noisy_data.as_array())
plt.title('Noisy Data')
plt.colorbar()
plt.subplot(1, 4, 3)
plt.imshow(fi.get_output().as_array())
plt.title('FISTA Reconstruction')
plt.colorbar()
plt.subplot(1, 4, 4)
# Middle row (axis 0 index ig.shape[0] / 2) plotted against ig.shape[1]
# evenly spaced positions in [0, ig.shape[1]].
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]), data.as_array()[int(ig.shape[0] / 2), :], label='GTruth')
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]), fi.get_output().as_array()[int(ig.shape[0] / 2), :], label='TV reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()

#%% Use PDHG to solve non-smooth version of problem for comparison
def test_FISTA_cvx(self):
    """Compare FISTA on least squares (+ zero function) with a CVXPY solution.

    The whole body is currently disabled by ``if False``. When re-enabled it
    builds a random least-squares problem min 0.5 * ||A x - b||^2 with
    A of size m x n, solves it with FISTA and with CVXPY (SCS solver) on the
    *same* data, and compares the two solutions.
    """
    if False:
        if not cvx_not_installable:
            try:
                # Problem data.
                m = 30
                n = 20
                np.random.seed(1)
                Amat = np.random.randn(m, n)
                A = LinearOperatorMatrix(Amat)
                bmat = np.random.randn(m)

                # Data lives in a space of size m, the unknown in a space of
                # size n (previously x_init was wrongly allocated with size m).
                vgb = VectorGeometry(m)
                vgx = VectorGeometry(n)
                # Feed FISTA the same right-hand side as CVXPY; previously b
                # was an unrelated random vector, making the comparison void.
                b = vgb.allocate()
                b.fill(bmat)

                # Create object instances with the test data A and b.
                f = LeastSquares(A, b, c=0.5)
                g0 = ZeroFunction()

                # Initial guess (zeros in the domain space).
                x_init = vgx.allocate()
                # Smoke-test the gradient without overwriting the initial guess.
                f.gradient(x_init)

                # Run FISTA for least squares plus zero function.
                fa = FISTA(x_init=x_init, f=f, g=g0)
                fa.max_iteration = 10
                fa.run(10)

                # Print solution and final objective/criterion value for comparison
                print("FISTA least squares plus zero function solution and objective value:")
                print(fa.get_output())
                print(fa.get_last_objective())

                # Compare to CVXPY
                # Construct the problem.
                x0 = Variable(n)
                objective0 = Minimize(0.5*sum_squares(Amat*x0 - bmat))
                prob0 = Problem(objective0)

                # The optimal objective is returned by prob.solve().
                result0 = prob0.solve(verbose=False, solver=SCS, eps=1e-9)

                # The optimal solution for x is stored in x.value and optimal objective value
                # is in result as well as in objective.value
                print("CVXPY least squares plus zero function solution and objective value:")
                print(x0.value)
                print(objective0.value)

                # Compare against the FISTA output (previously referenced the
                # undefined name x_fista0, which raised NameError).
                self.assertNumpyArrayAlmostEqual(
                    numpy.squeeze(fa.get_output().array), x0.value, 6)
            except SolverError as se:
                print (str(se))
                self.assertTrue(True)
        else:
            self.assertTrue(cvx_not_installable)
def stest_FISTA_Norm1_cvx(self):
    """Compare FISTA on least squares + lam * ||x||_1 with a CVXPY solution.

    The ``stest_`` prefix keeps unittest from collecting this method (the
    default test-method prefix is ``test``). When run, it solves
    min 0.5 * ||A x - b||^2 + lam * ||x||_1 with FISTA and with CVXPY (SCS)
    on the same data and compares the solutions.
    """
    if not cvx_not_installable:
        try:
            # Problem data.
            m = 30
            n = 20
            np.random.seed(1)
            Amat = np.random.randn(m, n)
            A = LinearOperatorMatrix(Amat)
            bmat = np.random.randn(m)

            # Data space (size m) and solution space (size n).
            vgb = VectorGeometry(m)
            vgx = VectorGeometry(n)
            b = vgb.allocate()
            b.fill(bmat)

            # Regularization parameter
            lam = 10

            # Create object instances with the test data A and b.
            f = LeastSquares(A, b, c=0.5)

            # Initial guess (zeros in the solution space).
            x_init = vgx.allocate()

            # Create 1-norm object instance and smoke-test call/prox.
            g1 = lam * L1Norm()
            g1(x_init)
            g1.prox(x_init, 0.02)

            # Combine with least squares and solve using generic FISTA implementation
            fa = FISTA(x_init=x_init, f=f, g=g1)
            fa.max_iteration = 10
            fa.run(10)

            # Print for comparison
            print("FISTA least squares plus 1-norm solution and objective value:")
            print(fa.get_output())
            print(fa.get_last_objective())

            # Compare to CVXPY
            # Construct the problem. bmat is 1-D here, so it is used directly;
            # the previous bmat.T[0] reduced it to its first scalar entry and
            # solved a different problem.
            x1 = Variable(n)
            objective1 = Minimize(
                0.5*sum_squares(Amat*x1 - bmat) + lam*norm(x1, 1))
            prob1 = Problem(objective1)

            # The optimal objective is returned by prob.solve().
            result1 = prob1.solve(verbose=False, solver=SCS, eps=1e-9)

            # The optimal solution for x is stored in x.value and optimal objective value
            # is in result as well as in objective.value
            print("CVXPY least squares plus 1-norm solution and objective value:")
            print(x1.value)
            print(objective1.value)

            # Compare against the FISTA output (previously referenced the
            # undefined name x_fista1, which raised NameError).
            self.assertNumpyArrayAlmostEqual(
                numpy.squeeze(fa.get_output().array), x1.value, 6)
        except SolverError as se:
            print (str(se))
            self.assertTrue(True)
    else:
        self.assertTrue(cvx_not_installable)
fista.max_iteration = 500
fista.update_objective_interval = 100
fista.run(500, verbose=True)

#%% Show results
# 2x2 grid: CGLS, FISTA, PDHG and the ASTRA CGLS reference. Each entry holds
# a zero-argument callable so outputs are fetched when their panel is drawn.
plt.figure(figsize=(10, 10))
plt.suptitle('Reconstructions ', fontsize=16)

_recon_panels = [
    (cgls.get_output, 'CGLS reconstruction'),
    (fista.get_output, 'FISTA reconstruction'),
    (pdhg.get_output, 'PDHG reconstruction'),
    (lambda: recon_cgls_astra, 'CGLS astra'),
]
for _panel_idx, (_fetch, _caption) in enumerate(_recon_panels, start=1):
    plt.subplot(2, 2, _panel_idx)
    plt.imshow(_fetch().as_array())
    plt.colorbar()
    plt.title(_caption)
plt.show()