예제 #1
0
    def test_reconstruction_fista_fft2(self):
        """ Test the FISTA reconstruction with the FFT2 operator.

        For every image in ``self.images``, every number of scales in
        ``self.nb_scales`` and every wavelet name in ``self.names``, run
        ``sparse_rec_fista`` with a ``Threshold(0)`` proximity operator and
        compare the result with the adjoint of the masked Fourier data of
        the image. The mean square error against that reference is printed
        for information only.
        """
        print("Process test FFT2 FISTA::")
        for image in self.images:
            fourier = FFT2(samples=convert_mask_to_locations(
                fftshift(self.mask)),
                           shape=image.shape)

            # Observed data used by the gradient operator.
            data = fourier.op(image.data)
            fourier_op = FFT2(convert_mask_to_locations(fftshift(self.mask)),
                              shape=image.shape)

            print("Process test with image '{0}'...".format(
                image.metadata["path"]))
            for nb_scale in self.nb_scales:
                print("- Number of scales: {0}".format(nb_scale))
                for name in self.names:
                    print("    Transform: {0}".format(name))
                    # Use the scale of the current iteration: a fixed value
                    # here would make the loop over ``self.nb_scales`` run
                    # the very same configuration each time.
                    linear_op = WaveletN(wavelet_name=name,
                                         nb_scale=nb_scale)
                    gradient_op = GradSynthesis2(data=data,
                                                 fourier_op=fourier_op,
                                                 linear_op=linear_op)
                    prox_op = Threshold(0)
                    x_final, transform, _, _ = sparse_rec_fista(
                        gradient_op=gradient_op,
                        linear_op=linear_op,
                        prox_op=prox_op,
                        cost_op=None,
                        lambda_init=1.0,
                        max_nb_of_iter=self.nb_iter,
                        verbose=0)
                    fourier_0 = FFT2(samples=convert_mask_to_locations(
                        fftshift(self.mask)),
                                     shape=image.shape)
                    data_0 = fourier_0.op(numpy.fft.fftshift(image.data))
                    # Reference image, computed once and shared by the
                    # assertion and the error report below.
                    reference = numpy.fft.ifftshift(fourier_0.adj_op(data_0))
                    # NOTE(review): only the boolean ``.any()`` reductions
                    # are compared here, not the arrays element-wise.
                    # Confirm the expected accuracy before tightening this
                    # into a full array comparison.
                    self.assertTrue(
                        numpy.allclose(x_final.any(),
                                       reference.any(),
                                       rtol=1e-10))
                    mean_square_error = numpy.mean(
                        numpy.abs(x_final - reference)**2)
                    print("      Mean Square Error = ", mean_square_error)
def reco_wav(kspace,
             gradient_op,
             mu=1 * 1e-8,
             max_iter=10,
             nb_scales=4,
             wavelet_name='db4',
             crop_size=320):
    """Reconstruct a magnitude image from k-space data with FISTA.

    A decimated wavelet operator is wrapped, together with a soft
    ``SparseThreshold``, into a ``LinearCompositionProx`` and handed to
    ``sparse_rec_fista`` (analysis point of view, greedy restart).

    Parameters
    ----------
    kspace:
        Observed k-space data. It is assigned to ``gradient_op.obs_data``,
        so the operator given by the caller is modified in place.
    gradient_op:
        Gradient operator forwarded to ``sparse_rec_fista``.
    mu: float, default 1e-8
        Value forwarded as ``mu`` to ``sparse_rec_fista`` (presumably the
        regularisation weight -- confirm against the installed pysap-mri).
    max_iter: int, default 10
        Maximum number of FISTA iterations.
    nb_scales: int, default 4
        Number of scales of the wavelet decomposition.
    wavelet_name: str, default 'db4'
        Name of the wavelet given to ``WaveletDecimated``.
    crop_size: int, default 320
        Size passed to ``crop_center`` for the final crop. The default
        keeps the previously hard-coded value.

    Returns
    -------
    The absolute value of the FISTA solution after ``crop_center``.
    """
    # for now this is only working with my fork of pysap-fastMRI
    # I will get it changed soon so that we don't need to ask for a specific
    # pysap-mri install
    from ..wavelets import WaveletDecimated
    from mri.numerics.reconstruct import sparse_rec_fista

    linear_op = WaveletDecimated(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name,
        padding='periodization',
    )

    # The wavelet transform lives inside the proximity operator, which is
    # why the solver itself is given an identity linear operator below.
    prox_op = LinearCompositionProx(
        linear_op=linear_op,
        prox_op=SparseThreshold(Identity(), None, thresh_type="soft"),
    )
    # Side effect: the caller's gradient operator now points at ``kspace``.
    gradient_op.obs_data = kspace
    cost_op = None
    x_final, _, _, _ = sparse_rec_fista(
        gradient_op=gradient_op,
        linear_op=Identity(),
        prox_op=prox_op,
        cost_op=cost_op,
        xi_restart=0.96,
        s_greedy=1.1,
        mu=mu,
        restart_strategy='greedy',
        pov='analysis',
        max_nb_of_iter=max_iter,
        metrics=None,
        metric_call_period=1,
        verbose=0,
        progress=False,
    )
    # Keep the magnitude only, then crop to the requested size.
    x_final = np.abs(x_final)
    x_final = crop_center(x_final, crop_size)
    return x_final
예제 #3
0
# Build the gradient, linear, proximity and cost operators from the observed
# k-space samples, using the synthesis formulation.
gradient_op, linear_op, prox_op, cost_op = generate_operators(
    data=kspace_obs,
    samples=kspace_loc,
    wavelet_name="BsplineWaveletTransformATrousAlgorithm",
    nb_scales=4,
    non_cartesian=True,
    uniform_data_shape=image.shape,
    gradient_space="synthesis",
)

# Run the FISTA reconstruction. The cost operator slot is given None, so
# the cost operator generated above is not used by this call.
max_iter = 20
x_final, transform, costs, metrics = sparse_rec_fista(
    gradient_op=gradient_op,
    linear_op=linear_op,
    prox_op=prox_op,
    cost_op=None,
    mu=1e-9,
    lambda_init=1.0,
    max_nb_of_iter=max_iter,
    atol=1e-4,
    verbose=1,
)

# Display the magnitude of the reconstructed image.
image_rec = pysap.Image(data=np.abs(x_final))
image_rec.show()

#############################################################################
# Condat-Vu optimization
# ----------------------
#
# We now want to refine the zero order solution using a Condat-Vu
# optimization.
# Here no cost function is set, and the optimization will reach the
# maximum number of iterations. Feel free to play with this parameter.
예제 #4
0
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Build the gradient, linear, proximity and cost operators from the observed
# k-space samples, using the synthesis formulation.
gradient_op, linear_op, prox_op, cost_op = generate_operators(
    data=kspace_obs,
    samples=kspace_loc,
    wavelet_name="sym8",
    mu=6 * 1e-7,
    nb_scales=4,
    non_cartesian=True,
    uniform_data_shape=image.shape,
    gradient_space="synthesis",
)

# Run the FISTA reconstruction with the generated cost operator.
max_iter = 200
x_final, costs, metrics = sparse_rec_fista(
    gradient_op=gradient_op,
    linear_op=linear_op,
    prox_op=prox_op,
    cost_op=cost_op,
    lambda_init=1.0,
    max_nb_of_iter=max_iter,
    atol=1e-4,
    verbose=1,
)

# Wrap the magnitude of the solution in an image (display is disabled) and
# report its SSIM against the reference image.
image_rec = pysap.Image(data=np.abs(x_final))
# image_rec.show()
recon_ssim = ssim(image_rec, image)
print('The Reconstruction SSIM is : ' + str(recon_ssim))