Ejemplo n.º 1
0
    def test_FISTA_Denoising(self):
        """FISTA denoises Poisson-noisy SHAPES data with a Tikhonov-gradient term."""
        if debug_print:
            print("FISTA Denoising Poisson Noise Tikhonov")
        # adapted from demo FISTA_Tikhonov_Poisson_Denoising.py in CIL-Demos repository
        data = dataexample.SHAPES.get()
        ig = data.geometry

        # Create noisy data with Poisson noise: sample on the data scaled down
        # by `scale`, then scale the noisy result back up.
        scale = 5
        noisy_data = applynoise.poisson(data / scale, seed=10) * scale

        # Regularisation parameter
        alpha = 10

        # Set up FISTA with f = alpha * ||Grad x||^2 (L2NormSquared composed
        # with the gradient operator) and g = KullbackLeibler fidelity to the
        # noisy data.
        operator = GradientOperator(ig)
        fid = KullbackLeibler(b=noisy_data)
        reg = OperatorCompositionFunction(alpha * L2NormSquared(), operator)

        initial = ig.allocate()
        fista = FISTA(initial=initial, f=reg, g=fid)
        fista.max_iteration = 3000
        fista.update_objective_interval = 500
        # No iteration count is passed; presumably run() goes up to
        # max_iteration -- confirm against Algorithm.run.
        fista.run(verbose=0)

        # NOTE(review): despite the name, this is ||x - data||_2 divided by the
        # number of elements, not a true RMSE (which would divide by the square
        # root of the count). The 4.2e-4 threshold below is calibrated to this
        # definition, so the formula is intentionally left unchanged.
        rmse = (fista.get_output() - data).norm() / data.as_array().size
        if debug_print:
            print("RMSE", rmse)
        self.assertLess(rmse, 4.2e-4)
Ejemplo n.º 2
0
    def test_FISTA_Norm2Sq(self):
        """FISTA with f = LeastSquares(identity, b) and g = 0 recovers b."""
        if debug_print:
            print("Test FISTA Norm2Sq")
        ig = ImageGeometry(127, 139, 149)
        # Target data and starting point, both filled with random numbers.
        b = ig.allocate(ImageGeometry.RANDOM)
        initial = ig.allocate(ImageGeometry.RANDOM)
        identity = IdentityOperator(ig)

        # With the identity operator the expected solution is x = b; the
        # assertions below compare the FISTA output against b.
        norm2sq = LeastSquares(identity, b)
        if debug_print:
            print("initial objective", norm2sq(initial))

        # Case 1: max_iteration configured by attribute assignment.
        # NOTE(review): run(20) is requested while max_iteration is 2;
        # presumably max_iteration caps the run -- confirm against Algorithm.run.
        alg = FISTA(initial=initial, f=norm2sq, g=ZeroFunction())
        alg.max_iteration = 2
        alg.run(20, verbose=0)
        self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())

        # Case 2: same problem, configured via constructor keyword arguments,
        # which must be stored on the algorithm.
        alg = FISTA(initial=initial,
                    f=norm2sq,
                    g=ZeroFunction(),
                    max_iteration=2,
                    update_objective_interval=3)
        self.assertEqual(alg.max_iteration, 2)
        self.assertEqual(alg.update_objective_interval, 3)

        alg.run(20, verbose=0)
        self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())