Example #1
0
# Compare two ways of applying the NUFFT normal operator to ``image``:
#   p0 -- NufftObj.adjoint(NufftObj.forward(image))        ("Toeplitz")
#   p1 -- NufftObj.k2xx(|W| * NufftObj.xx2k(image))       ("Ueckers inverse function")
# NOTE(review): NufftObj, image and W0 are defined outside this snippet; W0 is
# passed to adjoint(), so it is presumably non-uniform sample data -- confirm.
W = NufftObj.xx2k(NufftObj.adjoint(W0))

# |W| = sqrt(W * conj(W)); computed once and reused for the plot and for p1.
W_mag = (W * W.conj())**0.5

# Only the (real-valued) magnitude of W is displayed, not its real part.
matplotlib.pyplot.imshow(numpy.real(W_mag))
matplotlib.pyplot.title('Ueckers inverse function (magnitude)')
matplotlib.pyplot.show()

p0 = NufftObj.adjoint(NufftObj.forward(image))
p1 = NufftObj.k2xx(W_mag * NufftObj.xx2k(image))

print('error between Toeplitz and Inverse reconstruction',
      numpy.linalg.norm(p1 - p0) / numpy.linalg.norm(p0))

# Pixel-wise relative difference |p0 - p1| / |p1|.  Pixels where p1 == 0 are
# set to 0 rather than producing inf/nan (and a RuntimeWarning), which would
# otherwise break the colour scaling of the difference image.
abs_p1 = numpy.abs(p1)
rel_diff = numpy.divide(numpy.abs(p0 - p1), abs_p1,
                        out=numpy.zeros(abs_p1.shape), where=abs_p1 > 0)

matplotlib.pyplot.subplot(1, 3, 1)
matplotlib.pyplot.imshow(numpy.real(p0))
matplotlib.pyplot.title('Toeplitz')
matplotlib.pyplot.subplot(1, 3, 2)
matplotlib.pyplot.imshow(numpy.real(p1))
matplotlib.pyplot.title('Ueckers inverse function')
matplotlib.pyplot.subplot(1, 3, 3)
matplotlib.pyplot.imshow(rel_diff)
matplotlib.pyplot.title('Difference')
matplotlib.pyplot.show()
Example #2
0
def _load_test_image(Nd):
    """Return SciPy's ``ascent`` sample image resampled to shape ``Nd``.

    The result is float64 and divided by its maximum, so its peak value is 1.

    ``scipy.misc.imresize`` was removed in SciPy 1.3, and ``ascent`` moved
    from ``scipy.misc`` to ``scipy.datasets`` in SciPy 1.10, so the newer
    location is tried first and the old one is used as a fallback.
    """
    import numpy
    import scipy.ndimage

    try:
        from scipy.datasets import ascent  # SciPy >= 1.10 (needs ``pooch``)
        raw = ascent()
    except ImportError:
        from scipy.misc import ascent  # older SciPy
        raw = ascent()

    image = numpy.asarray(raw, dtype=numpy.float64)
    # Linear-interpolation resample to the requested image size; float() keeps
    # the zoom factors fractional even under Python 2 integer division.
    factors = (float(Nd[0]) / image.shape[0], float(Nd[1]) / image.shape[1])
    image = scipy.ndimage.zoom(image, factors, order=1)
    return image / numpy.max(image)


def test_2D():
    """Interactive 2-D demo of ``pynufft.NUFFT_cpu``.

    Every ``matplotlib.pyplot.show()`` call blocks until its window is closed.
    The demo:

    1. loads a test image and the non-uniform coordinates ``om`` (from the
       ``om2D.npz`` file shipped with pynufft) and plots both;
    2. plans the NUFFT and simulates non-uniform data ``y = forward(image)``;
    3. iterates a k-space weight ``W`` and plots it;
    4. shows the spectrum of a CG reconstruction, then compares the original
       image with a ``W``-weighted adjoint and an ``'L1TVOLS'`` reconstruction;
    5. runs each listed ``solve`` method for ``maxiter`` iterations and plots
       the results side by side.
    """
    import pkg_resources
    import numpy
    import matplotlib.pyplot
    from pynufft import NUFFT_cpu

    DATA_PATH = pkg_resources.resource_filename('pynufft', 'src/data/')

    Nd = (256, 256)  # image size
    Kd = (512, 512)  # k-space size
    Jd = (6, 6)  # interpolation size

    def _show_gray(img, title):
        """Show ``img.real`` in grey scale with a fixed [0, 1] display range."""
        matplotlib.pyplot.imshow(img.real,
                                 cmap=matplotlib.cm.gray,
                                 norm=matplotlib.colors.Normalize(vmin=0.0,
                                                                  vmax=1))
        matplotlib.pyplot.title(title)

    # load example image
    print('loading image...')
    image = _load_test_image(Nd)
    matplotlib.pyplot.imshow(image, cmap=matplotlib.cm.gray)
    matplotlib.pyplot.show()

    print('setting image dimension Nd...', Nd)
    print('setting spectrum dimension Kd...', Kd)
    print('setting interpolation size Jd...', Jd)

    # load k-space points; the context manager closes the underlying .npz file
    with numpy.load(DATA_PATH + 'om2D.npz') as npz:
        om = npz['arr_0']
    print('setting non-uniform coordinates...')
    matplotlib.pyplot.plot(om[::10, 0], om[::10, 1], 'o')  # every 10th point
    matplotlib.pyplot.title('non-uniform coordinates')
    matplotlib.pyplot.xlabel('axis 0')
    matplotlib.pyplot.ylabel('axis 1')
    matplotlib.pyplot.show()

    NufftObj = NUFFT_cpu()
    NufftObj.plan(om, Nd, Kd, Jd)

    y = NufftObj.forward(image)
    print('setting non-uniform data')
    print('y is an (M,) list', type(y), y.shape)

    # Fixed-point iteration for a k-space weight W: each step maps W through
    # k2xx -> forward -> adjoint -> xx2k, takes the magnitude W2 and updates
    # W = (W + 0.9) / (W2 + 0.9).
    # NOTE(review): the 0.9 offset and 200 iterations are hard-coded
    # heuristics whose origin is not documented here.
    W = numpy.ones(Kd, dtype=numpy.complex64)
    for _ in range(200):
        W2 = NufftObj.xx2k(NufftObj.adjoint(NufftObj.forward(
            NufftObj.k2xx(W))))
        W2 = (W2 * W2.conj())**0.5
        W = (W + 0.9) / (W2 + 0.9)

    matplotlib.pyplot.subplot(1, 2, 1)
    matplotlib.pyplot.imshow(W2.real)
    matplotlib.pyplot.title('W2 (last iteration)')
    matplotlib.pyplot.subplot(1, 2, 2)
    matplotlib.pyplot.imshow((W / W2).real)
    matplotlib.pyplot.title('W / W2')
    matplotlib.pyplot.show()

    image_restore = NufftObj.solve(y, solver='cg', maxiter=10)
    shifted_kspectrum = numpy.fft.fftshift(
        numpy.fft.fftn(numpy.fft.fftshift(image_restore)))
    print('getting the k-space spectrum, shape =', shifted_kspectrum.shape)
    print('Showing the shifted k-space spectrum')

    matplotlib.pyplot.imshow(shifted_kspectrum.real,
                             cmap=matplotlib.cm.gray,
                             norm=matplotlib.colors.Normalize(vmin=-100,
                                                              vmax=100))
    matplotlib.pyplot.title('shifted k-space spectrum')
    matplotlib.pyplot.show()

    # Reconstruction by weighting the gridded adjoint with W in k-space.
    image3 = NufftObj.k2xx(NufftObj.xx2k(NufftObj.adjoint(y)) * W)
    print(image3.shape)
    image4 = NufftObj.solve(y, 'L1TVOLS', maxiter=100, rho=1)

    matplotlib.pyplot.subplot(1, 3, 1)
    _show_gray(image, 'original')
    matplotlib.pyplot.subplot(1, 3, 2)
    _show_gray(image3, 'W-weighted adjoint')
    matplotlib.pyplot.subplot(1, 3, 3)
    _show_gray(image4, 'L1TVOLS')
    matplotlib.pyplot.show()

    maxiter = 25
    solvers = ('dc', 'bicg', 'bicgstab', 'cg', 'gmres', 'lgmres', 'lsmr',
               'lsqr')
    for counter, solver in enumerate(solvers, start=1):
        print(counter, solver)
        # lsqr takes ``iter_lim`` instead of ``maxiter`` (as in
        # scipy.sparse.linalg.lsqr); every other solver uses ``maxiter``.
        if solver == 'lsqr':
            image2 = NufftObj.solve(y, solver, iter_lim=maxiter)
        else:
            image2 = NufftObj.solve(y, solver, maxiter=maxiter)
        matplotlib.pyplot.subplot(2, 4, counter)
        _show_gray(image2, solver)
    matplotlib.pyplot.show()