Code Example #1
    def test_FISTA_Denoising(self):
        print ("FISTA Denoising Poisson Noise Tikhonov")
        # adapted from demo FISTA_Tikhonov_Poisson_Denoising.py in CIL-Demos repository
        #loader = TestData(data_dir=os.path.join(sys.prefix, 'share','ccpi'))
        loader = TestData()
        data = loader.load(TestData.SHAPES)
        ig = data.geometry
        ag = ig
        N=300
        # Create Noisy data with Poisson noise
        scale = 5
        n1 = TestData.random_noise( data.as_array()/scale, mode = 'poisson', seed = 10)*scale
        noisy_data = ImageData(n1)

        # Regularisation Parameter
        alpha = 10

        # Setup and run the FISTA algorithm
        operator = Gradient(ig)
        fid = KullbackLeibler(b=noisy_data)
        reg = FunctionOperatorComposition(alpha * L2NormSquared(), operator)

        x_init = ig.allocate()
        fista = FISTA(x_init=x_init, f=reg, g=fid)
        fista.max_iteration = 3000
        fista.update_objective_interval = 500
        fista.run(verbose=True)
        rmse = (fista.get_output() - data).norm() / data.as_array().size
        print ("RMSE", rmse)
        self.assertLess(rmse, 4.2e-4)
Code Example #2
# Using the test data b, different reconstruction methods can now be set up as
# demonstrated in the rest of this file. In general all methods need an initial
# guess and some algorithm options to be set:
x_init = ig.allocate(0.0)
opt = {'tol': 1e-4, 'iter': 200}

# Create least squares object instance with projector, test data and a constant
# coefficient of 0.5. Note it is least squares over all channels:
#f = Norm2Sq(Aop,b,c=0.5)
f = FunctionOperatorComposition(L2NormSquared(b=b), Aop)
# Run FISTA for least squares without regularization
FISTA_alg = FISTA()
FISTA_alg.set_up(x_init=x_init, f=f, g=ZeroFunction())
FISTA_alg.max_iteration = 2000
FISTA_alg.run(opt['iter'])
x_FISTA = FISTA_alg.get_output()

# Display reconstruction and criterion
ff0, axarrf0 = plt.subplots(1, numchannels)
for k in numpy.arange(numchannels):
    axarrf0[k].imshow(x_FISTA.as_array()[k], vmin=0, vmax=2.5)
plt.show()

plt.figure()
plt.semilogy(FISTA_alg.objective)
plt.title('Criterion vs iterations, least squares')
plt.show()

# FISTA can also solve regularised forms by specifying a second function object,
# such as 1-norm regularisation with a choice of regularisation parameter lam.
# Again the regulariser acts over all channels, as sketched below:
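# A minimal sketch of the 1-norm regularised variant described above, assuming the
# legacy ccpi.optimisation API; the value of lam is illustrative only.
from ccpi.optimisation.functions import L1Norm

lam = 0.1
g1 = lam * L1Norm()

# Run FISTA for least squares plus 1-norm regularisation over all channels
FISTA_L1 = FISTA()
FISTA_L1.set_up(x_init=x_init, f=f, g=g1)
FISTA_L1.max_iteration = 2000
FISTA_L1.run(opt['iter'])
x_FISTA_L1 = FISTA_L1.get_output()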
Code Example #3
# Allocate space for the channel-wise reconstruction
fista_sol_TV_channel_wise = A3D_chan.volume_geometry.allocate()

for i in range(ag.channels):

    # Set up the L2NormSquared fidelity term for each channel
    f = FunctionOperatorComposition(
        0.5 * L2NormSquared(b=data.subset(channel=i)), A3D)

    # Run FISTA
    fista = FISTA(x_init=x_init, f=f, g=g)
    fista.max_iteration = 100
    fista.update_objective_interval = 50
    fista.run(400, verbose=True, callback=show_data_3D)
    np.copyto(fista_sol_TV_channel_wise.array[i], fista.get_output().array)

#%% show reconstruction

show_4D_channel_slice(fista_sol_TV_channel_wise, 5,
                      'FISTA TV channel-wise reconstruction')
show_4D_channel_slice(fista_sol_TV_channel_wise, 10,
                      'FISTA TV channel-wise reconstruction')
show_4D_channel_slice(fista_sol_TV_channel_wise, 15,
                      'FISTA TV channel-wise reconstruction')

#%% Coupled total variation reconstruction on the full 4D volume. For this case there is
# no GPU implementation, but we can use another algorithm, PDHG (primal-dual hybrid
# gradient); the remaining setup is sketched below.

# Set up operators: Projection and Gradient
op1 = A3D_chan
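# A minimal sketch of how the coupled 4D TV problem is typically completed with PDHG,
# assuming the legacy ccpi BlockOperator/BlockFunction API; alpha_tv, the iteration
# counts and the correlation='SpaceChannels' coupling choice are illustrative assumptions.
op2 = Gradient(A3D_chan.volume_geometry, correlation='SpaceChannels')
operator = BlockOperator(op1, op2, shape=(2, 1))

alpha_tv = 0.05
f_pdhg = BlockFunction(0.5 * L2NormSquared(b=data), alpha_tv * MixedL21Norm())
g_pdhg = ZeroFunction()

# Step sizes chosen so that sigma * tau * ||K||^2 <= 1
normK = operator.norm()
sigma = 1
tau = 1 / (sigma * normK ** 2)

pdhg = PDHG(f=f_pdhg, g=g_pdhg, operator=operator, tau=tau, sigma=sigma)
pdhg.max_iteration = 1000
pdhg.update_objective_interval = 100
pdhg.run(1000, verbose=True)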
Code Example #4
sigma = 1
tau = 1 / (sigma * normK**2)
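# Note: PDHG requires sigma * tau * ||K||^2 <= 1 for convergence; with sigma = 1 the
# choice of tau above saturates that bound. Here normK is the operator norm of the
# forward operator, in the ccpi demos typically computed as normK = operator.norm().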

pdhg = PDHG(f=f, g=g, operator=operator, tau=tau, sigma=sigma, memopt=True)
pdhg.max_iteration = 2000
pdhg.update_objective_interval = 200
pdhg.run(2000, verbose=False)
###############################################################################

# Show results

plt.figure(figsize=(10, 10))

plt.subplot(2, 1, 1)
plt.imshow(pdhg.get_output().as_array())
plt.title('PDHG reconstruction')

plt.subplot(2, 1, 2)
plt.imshow(fista.get_output().as_array())
plt.title('FISTA reconstruction')

plt.show()

diff1 = pdhg.get_output() - fista.get_output()

plt.imshow(diff1.abs().as_array())
plt.title('Diff PDHG vs FISTA')
plt.colorbar()
plt.show()
Code Example #5

fig, ax = plt.subplots(1, 3)
img1 = ax[0].imshow(data.as_array())
ax[0].set_title('Ground Truth')
colorbar(img1)
img2 = ax[1].imshow(noisy_data.as_array())
ax[1].set_title('Projection Data')
colorbar(img2)
img3 = ax[2].imshow(back_proj.as_array())
ax[2].set_title('BackProjection')
colorbar(img3)
plt.tight_layout(h_pad=1.5)

fig1, ax1 = plt.subplots(1, 3)
img4 = ax1[0].imshow(fista.get_output().as_array())
ax1[0].set_title('LS unconstrained')
colorbar(img4)
img5 = ax1[1].imshow(fista0.get_output().as_array())
ax1[1].set_title('LS constrained [0,1]')
colorbar(img5)
img6 = ax1[2].imshow(fista1.get_output().as_array())
ax1[2].set_title('L2-Regularised LS')
colorbar(img6)
plt.tight_layout(h_pad=1.5)
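# A minimal sketch of how the three FISTA instances plotted above are typically set up,
# assuming the legacy ccpi API; the projector name Aop, the initial guess x0 and the
# regularisation weight alpha are illustrative assumptions.
f_ls = LeastSquares(Aop, noisy_data, c=0.5)

fista = FISTA(x_init=x0, f=f_ls, g=ZeroFunction())                    # LS, unconstrained
fista0 = FISTA(x_init=x0, f=f_ls, g=IndicatorBox(lower=0, upper=1))   # LS, constrained to [0, 1]
fista1 = FISTA(x_init=x0, f=f_ls, g=alpha * L2NormSquared())          # L2-regularised LS
# Each instance is then given a max_iteration and run(...) as in the other examples.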

#%% Check with CVX solution

import astra
import numpy
from ccpi.optimisation.operators import SparseFiniteDiff
Code Example #6
fista.max_iteration = 3000
fista.update_objective_interval = 500
fista.run(3000, verbose=True)

# Show results
plt.figure(figsize=(15, 15))
plt.subplot(3, 1, 1)
plt.imshow(data.as_array())
plt.title('Ground Truth')
plt.colorbar()
plt.subplot(3, 1, 2)
plt.imshow(noisy_data.as_array())
plt.title('Noisy Data')
plt.colorbar()
plt.subplot(3, 1, 3)
plt.imshow(fista.get_output().as_array())
plt.title('Reconstruction')
plt.colorbar()
plt.show()

plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]),
         data.as_array()[int(N / 2), :],
         label='GTruth')
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]),
         fista.get_output().as_array()[int(N / 2), :],
         label='Reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()

#%% Check with CVX solution
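# A hedged sketch of the CVXPY cross-check this heading refers to, assuming a
# TV-regularised denoising formulation with an L2 fidelity; alpha is assumed to be the
# regularisation parameter used by the FISTA reconstruction above, and the model should
# be adjusted if a different fidelity (e.g. Kullback-Leibler) was used.
import cvxpy as cp

u_cvx = cp.Variable(noisy_data.shape)
fidelity = 0.5 * cp.sum_squares(u_cvx - noisy_data.as_array())
regulariser = alpha * cp.tv(u_cvx)

prob = cp.Problem(cp.Minimize(fidelity + regulariser))
prob.solve(solver=cp.SCS, verbose=False)

print('CVXPY objective value:', prob.value)
print('Max abs difference vs FISTA:',
      np.abs(u_cvx.value - fista.get_output().as_array()).max())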
Code Example #7
           max_iteration=10000,
           update_objective_interval=100)
fi.run(verbose=True)

## Show FISTA reconstruction results
plt.figure(figsize=(20, 5))
plt.subplot(1, 4, 1)
plt.imshow(data.as_array())
plt.title('Ground Truth')
plt.colorbar()
plt.subplot(1, 4, 2)
plt.imshow(noisy_data.as_array())
plt.title('Noisy Data')
plt.colorbar()
plt.subplot(1, 4, 3)
plt.imshow(fi.get_output().as_array())
plt.title('FISTA Reconstruction')
plt.colorbar()
plt.subplot(1, 4, 4)
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]),
         data.as_array()[int(ig.shape[0] / 2), :],
         label='GTruth')
plt.plot(np.linspace(0, ig.shape[1], ig.shape[1]),
         fi.get_output().as_array()[int(ig.shape[0] / 2), :],
         label='TV reconstruction')
plt.legend()
plt.title('Middle Line Profiles')
plt.show()

#%% Use PDHG to solve non-smooth version of problem for comparison
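# A minimal sketch of the PDHG comparison the heading above refers to, following the
# standard ccpi explicit TV pattern for a denoising problem; alpha_tv and the L2 fidelity
# are assumptions and should match the problem solved by FISTA above. If the problem
# involves a projector, a BlockOperator of the projector and a Gradient would be used
# instead, as in the coupled example earlier.
op_grad = Gradient(ig)

alpha_tv = 0.1
f_pdhg = alpha_tv * MixedL21Norm()
g_pdhg = 0.5 * L2NormSquared(b=noisy_data)

normK = op_grad.norm()
sigma = 1
tau = 1 / (sigma * normK ** 2)

pdhg = PDHG(f=f_pdhg, g=g_pdhg, operator=op_grad, tau=tau, sigma=sigma)
pdhg.max_iteration = 2000
pdhg.update_objective_interval = 200
pdhg.run(2000, verbose=False)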
Code Example #8
 def test_FISTA_cvx(self):
     
     if False:
         if not cvx_not_installable:
             try:
                 # Problem data.
                 m = 30
                 n = 20
                 np.random.seed(1)
                 Amat = np.random.randn(m, n)
                 A = LinearOperatorMatrix(Amat)
                 bmat = np.random.randn(m)
                 bmat.shape = (bmat.shape[0], 1)
 
                 # A = Identity()
                 # Change n to equal to m.
 
                 #b = DataContainer(bmat)
                 vg = VectorGeometry(m)
 
                 b = vg.allocate('random')
 
                 # Regularization parameter
                 lam = 10
                 opt = {'memopt': True}
                 # Create object instances with the test data A and b.
                 f = LeastSquares(A, b, c=0.5)
                 g0 = ZeroFunction()
 
                 # Initial guess
                 #x_init = DataContainer(np.zeros((n, 1)))
                 x_init = vg.allocate()
                 f.gradient(x_init, out = x_init)
 
                 # Run FISTA for least squares plus zero function.
                 #x_fista0, it0, timing0, criter0 = FISTA(x_init, f, g0, opt=opt)
                 fa = FISTA(x_init=x_init, f=f, g=g0)
                 fa.max_iteration = 10
                 fa.run(10)
                 
                 # Print solution and final objective/criterion value for comparison
                 print("FISTA least squares plus zero function solution and objective value:")
                 print(fa.get_output())
                 print(fa.get_last_objective())
 
                 # Compare to CVXPY
 
                 # Construct the problem.
                 x0 = Variable(n)
                 objective0 = Minimize(0.5*sum_squares(Amat*x0 - b.as_array()))  # use the same b as the FISTA problem
                 prob0 = Problem(objective0)
 
                 # The optimal objective is returned by prob.solve().
                 result0 = prob0.solve(verbose=False, solver=SCS, eps=1e-9)
 
                 # The optimal solution for x is stored in x.value and optimal objective value
                 # is in result as well as in objective.value
                 print("CVXPY least squares plus zero function solution and objective value:")
                 print(x0.value)
                 print(objective0.value)
                 self.assertNumpyArrayAlmostEqual(
                     numpy.squeeze(fa.get_output().as_array()), x0.value, 6)
             except SolverError as se:
                 print (str(se))
                 self.assertTrue(True)
         else:
             self.assertTrue(cvx_not_installable)
Code Example #9
    def stest_FISTA_Norm1_cvx(self):
        if not cvx_not_installable:
            try:
                opt = {'memopt': True}
                # Problem data.
                m = 30
                n = 20
                np.random.seed(1)
                Amat = np.random.randn(m, n)
                A = LinearOperatorMatrix(Amat)
                bmat = np.random.randn(m)
                #bmat.shape = (bmat.shape[0], 1)

                # A = Identity()
                # Change n to equal to m.
                vgb = VectorGeometry(m)
                vgx = VectorGeometry(n)
                b = vgb.allocate()
                b.fill(bmat)
                #b = DataContainer(bmat)

                # Regularization parameter
                lam = 10
                opt = {'memopt': True}
                # Create object instances with the test data A and b.
                f = LeastSquares(A, b, c=0.5)
                g0 = ZeroFunction()

                # Initial guess
                #x_init = DataContainer(np.zeros((n, 1)))
                x_init = vgx.allocate()

                # Create 1-norm object instance
                g1 = lam * L1Norm()

                g1(x_init)
                g1.prox(x_init, 0.02)

                # Combine with least squares and solve using generic FISTA implementation
                #x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g1, opt=opt)
                fa = FISTA(x_init=x_init, f=f, g=g1)
                fa.max_iteration = 10
                fa.run(10)
                

                # Print for comparison
                print("FISTA least squares plus 1-norm solution and objective value:")
                print(fa.get_output())
                print(fa.get_last_objective())

                # Compare to CVXPY

                # Construct the problem.
                x1 = Variable(n)
                objective1 = Minimize(
                    0.5*sum_squares(Amat*x1 - bmat) + lam*norm(x1, 1))  # bmat is 1-D here, matching b
                prob1 = Problem(objective1)

                # The optimal objective is returned by prob.solve().
                result1 = prob1.solve(verbose=False, solver=SCS, eps=1e-9)

                # The optimal solution for x is stored in x.value and optimal objective value
                # is in result as well as in objective.value
                print("CVXPY least squares plus 1-norm solution and objective value:")
                print(x1.value)
                print(objective1.value)

                self.assertNumpyArrayAlmostEqual(
                    numpy.squeeze(fa.get_output().as_array()), x1.value, 6)
            except SolverError as se:
                print (str(se))
                self.assertTrue(True)
        else:
            self.assertTrue(cvx_not_installable)
Code Example #10
fista.max_iteration = 500
fista.update_objective_interval = 100
fista.run(500, verbose=True)

#%% Show results

plt.figure(figsize=(10, 10))
plt.suptitle('Reconstructions ', fontsize=16)

plt.subplot(2, 2, 1)
plt.imshow(cgls.get_output().as_array())
plt.colorbar()
plt.title('CGLS reconstruction')

plt.subplot(2, 2, 2)
plt.imshow(fista.get_output().as_array())
plt.colorbar()
plt.title('FISTA reconstruction')

plt.subplot(2, 2, 3)
plt.imshow(pdhg.get_output().as_array())
plt.colorbar()
plt.title('PDHG reconstruction')

plt.subplot(2, 2, 4)
plt.imshow(recon_cgls_astra.as_array())
plt.colorbar()
plt.title('CGLS astra')

plt.show()