예제 #1
0
 def test_MFI_given_rawData(self):
     """MFI correction of one raw-data slice yields an image of the field map's in-plane shape."""
     slice_raw = raw_data[:, :, 0]
     slice_fmap = field_map[:, :, 0]
     corrected = ORC.MFI(dataIn=slice_raw, dataInType='raw', kt=ktraj,
                         df=slice_fmap, Lx=1, nonCart=1, params=acq_params)
     expected_shape = field_map.shape[:-1]
     self.assertEqual(corrected.shape, expected_shape)  # Dimensions agree
예제 #2
0
 def test_CPR_wrongInputDimensions_rawdata(self):
     """CPR must raise ValueError when the raw data shape disagrees with ktraj."""
     n = acq_params['N']
     # One extra column so the raw data dimensions do not match ktraj.
     mismatched = np.ones((n, n + 1))
     with self.assertRaises(ValueError):
         # Cartesian call: no nonCart flag / params passed.
         ORC.CPR(dataIn=mismatched, dataInType='raw', kt=ktraj,
                 df=field_map_CPR)
예제 #3
0
 def test_fsCPR_given_im(self):
     """Run fs-CPR on raw data, then again on the resulting image, and check the output."""
     common = dict(kt=ktraj, Lx=1, nonCart=1, params=acq_params)
     from_raw = ORC.fs_CPR(dataIn=raw_data[:, :, 0], dataInType='raw',
                           df=field_map, **common)
     # Second pass works in the image domain with a squeezed field map.
     from_im = ORC.fs_CPR(dataIn=from_raw, dataInType='im',
                          df=np.squeeze(field_map), **common)
     self.assertEqual(from_im.shape, field_map.shape[:-1])  # Dimensions match
     # NOTE(review): ndarray.all() reduces each image to one bool, so this only
     # checks that both images agree on "every pixel is non-zero"; it does not
     # compare pixel values. Confirm what this assertion is meant to verify.
     self.assertEqual(from_im.all(), from_raw.all())
예제 #4
0
def numsim_cartesian():
    """Simulate off-resonance on a Cartesian acquisition and compare corrections.

    A Shepp-Logan phantom is corrupted with ``ORC.add_or_CPR`` using a
    simulated field map for each maximum frequency in ``fmax_v``. Each
    corrupted image is then corrected with CPR, fs-CPR and MFI. Intermediate
    images are shown with matplotlib, and the final comparison grid is drawn
    by ``plot_correction_results``.
    """
    N = 192
    ph = resize(shepp_logan_phantom(), (N, N))

    ph = ph.astype(complex)
    plt.imshow(np.abs(ph), cmap='gray')
    plt.title('Original phantom')
    plt.axis('off')
    plt.colorbar()
    plt.show()

    brain_mask = mask_by_threshold(ph)

    # Floodfill from point (0, 0). For integers ~x + 2 == 1 - x, so ph_holes
    # is the complement of the flood-filled image (presumably the interior
    # holes of the phantom).
    ph_holes = ~(flood_fill(brain_mask, (0, 0), 1).astype(int)) + 2
    mask = brain_mask + ph_holes

    ##
    # Cartesian k-space trajectory
    ##
    dt = 10e-6  # grad raster time
    # Build sample times from integer indices. np.arange(0, N * dt, dt) with a
    # float step can return N + 1 samples because of rounding, which would
    # break the reshape to (1, N). The sample values are identical.
    ktraj_cart = (np.arange(N) * dt).reshape(1, N)
    ktraj_cart = np.tile(ktraj_cart, (N, 1))
    plt.imshow(ktraj_cart, cmap='gray')
    plt.title('Cartesian trajectory')
    plt.show()

    ##
    # Simulated field map
    ##
    fmax_v = [1600, 3200, 4800]  # Hz corresponding to 25, 50 and 75 ppm at 3T

    or_corrupted = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_CPR = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_fsCPR = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_MFI = np.zeros((N, N, len(fmax_v)), dtype='complex')
    for i, fmax in enumerate(fmax_v):
        field_map = fieldmap_sim.realistic(np.abs(ph), mask, fmax)

        plt.imshow(field_map, cmap='gray')
        plt.title('Field Map')
        plt.colorbar()
        plt.axis('off')
        plt.show()

        ##
        # Corrupted images
        ##
        or_corrupted[:, :, i], _ = ORC.add_or_CPR(ph, ktraj_cart, field_map)
        # Min-max normalise the magnitude for display only.
        magnitude = np.abs(or_corrupted[:, :, i])
        corrupt = (magnitude - magnitude.min()) / (magnitude.max() -
                                                   magnitude.min())
        plt.imshow(corrupt, cmap='gray')
        plt.colorbar()
        plt.title('Corrupted Image')
        plt.axis('off')
        plt.show()

        ###
        # Corrected images (image-domain input; 2 is the L factor for
        # fs-CPR and MFI)
        ###
        or_corrected_CPR[:, :, i] = ORC.CPR(or_corrupted[:, :, i], 'im',
                                            ktraj_cart, field_map)
        or_corrected_fsCPR[:, :, i] = ORC.fs_CPR(or_corrupted[:, :, i], 'im',
                                                 ktraj_cart, field_map, 2)
        or_corrected_MFI[:, :, i] = ORC.MFI(or_corrupted[:, :, i], 'im',
                                            ktraj_cart, field_map, 2)

    ##
    # Plot
    ##
    im_stack = np.stack(
        (np.squeeze(or_corrupted), np.squeeze(or_corrected_CPR),
         np.squeeze(or_corrected_fsCPR), np.squeeze(or_corrected_MFI)))
    cols = ('Corrupted Image', 'CPR Correction', 'fs-CPR Correction',
            'MFI Correction')
    # Derived from fmax_v so the labels cannot drift out of sync with it.
    row_names = tuple('-/+ ' + str(fmax) + ' Hz' for fmax in fmax_v)
    plot_correction_results(im_stack, cols, row_names)
예제 #5
0
def numsim_epi():
    """Simulate off-resonance on a single-shot EPI acquisition and compare corrections.

    For each maximum frequency in ``fmax_v``, a smooth field map scaled to
    +/- ``fmax`` Hz corrupts a Shepp-Logan phantom via ``ORC.add_or_CPR``.
    The corrupted k-space data is then corrected with CPR, fs-CPR and MFI
    through ``ORC.correct_from_kdat``. Intermediate images are shown with
    matplotlib, and the final comparison grid is drawn by
    ``plot_correction_results``.
    """
    ##
    # Original image: Shepp-Logan Phantom
    ##
    N = 128
    FOV = 256e-3
    ph = resize(shepp_logan_phantom(), (N, N))  #.astype(complex)
    plt.imshow(np.abs(ph), cmap='gray')
    plt.title('Original phantom')
    plt.axis('off')
    plt.colorbar()
    plt.show()

    brain_mask = mask_by_threshold(ph)

    # Floodfill from point (0, 0). For integers ~x + 2 == 1 - x, so ph_holes
    # is the complement of the flood-filled image.
    ph_holes = ~(flood_fill(brain_mask, (0, 0), 1).astype(int)) + 2
    mask = brain_mask + ph_holes

    ##
    # EPI k-space trajectory
    ##
    dt = 4e-6
    ktraj = ssEPI_2d(N, FOV)  # k-space trajectory

    plt.plot(ktraj[:, 0], ktraj[:, 1])
    plt.title('EPI trajectory')
    plt.show()

    Ta = ktraj.shape[0] * dt
    T = (np.arange(ktraj.shape[0]) * dt).reshape(ktraj.shape[0], 1)
    seq_params = {
        'N': N,
        'Npoints': ktraj.shape[0],
        'Nshots': 1,
        't_readout': Ta,
        't_vector': T
    }

    ##
    # Simulated field map
    ##
    # fmax_v = [50, 75, 100]  # Hz
    fmax_v = [100, 150, 200]  # Hz

    or_corrupted = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_CPR = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_fsCPR = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_MFI = np.zeros((N, N, len(fmax_v)), dtype='complex')
    # (method name for ORC.correct_from_kdat, output volume), in run order.
    corrections = (('CPR', or_corrected_CPR),
                   ('fsCPR', or_corrected_fsCPR),
                   ('MFI', or_corrected_MFI))
    for i, fmax in enumerate(fmax_v):
        # field_map = fieldmap_sim.realistic(np.abs(ph), mask, fmax)

        # Smooth the phantom, rescale it to [-fmax, fmax] Hz, round it and
        # restrict it to the mask.
        SL_smooth = gaussian(ph, sigma=3)
        field_map = cv2.normalize(SL_smooth, None, -fmax, fmax,
                                  cv2.NORM_MINMAX)
        field_map = np.round(field_map * mask)
        # -0.0 == 0.0, so every zero (including negative zeros from rounding)
        # is rewritten as +0.
        field_map[np.where(field_map == -0.0)] = 0

        plt.imshow(field_map, cmap='gray')
        plt.title('Field Map +/-' + str(fmax) + ' Hz')
        plt.colorbar()
        plt.axis('off')
        plt.show()

        ##
        # Corrupted images
        ##
        or_corrupted[:, :, i], EPI_ksp = ORC.add_or_CPR(
            ph, ktraj, field_map, 'EPI', seq_params)
        # Min-max normalise the magnitude for display only.
        magnitude = np.abs(or_corrupted[:, :, i])
        corrupt = (magnitude - magnitude.min()) / (magnitude.max() -
                                                   magnitude.min())
        plt.imshow(corrupt, cmap='gray')
        plt.colorbar()
        plt.title('Corrupted Image')
        plt.axis('off')
        plt.show()

        ###
        # Corrected images (from the corrupted k-space data)
        ###
        for method, result in corrections:
            result[:, :, i] = ORC.correct_from_kdat(method, EPI_ksp, ktraj,
                                                    field_map, seq_params,
                                                    'EPI')

        # Image-domain alternative:
        # or_corrected_CPR[:, :, i] = ORC.CPR(or_corrupted[:, :, i], 'im', ktraj, field_map, 'EPI', seq_params)
        # or_corrected_fsCPR[:, :, i] = ORC.fs_CPR(or_corrupted[:, :, i], 'im', ktraj, field_map, 2, 'EPI', seq_params)
        # or_corrected_MFI[:, :, i] = ORC.MFI(or_corrupted[:, :, i], 'im', ktraj, field_map, 2, 'EPI', seq_params)

    ##
    # Plot
    ##
    im_stack = np.stack(
        (np.squeeze(or_corrupted), np.squeeze(or_corrected_CPR),
         np.squeeze(or_corrected_fsCPR), np.squeeze(or_corrected_MFI)))
    # np.save('im_stack.npy',im_stack)
    cols = ('Corrupted Image', 'CPR Correction', 'fs-CPR Correction',
            'MFI Correction')
    # Derived from fmax_v so the labels cannot drift out of sync with it.
    row_names = tuple('-/+ ' + str(fmax) + ' Hz' for fmax in fmax_v)
    plot_correction_results(im_stack, cols, row_names)
예제 #6
0
def numsim_spiral():
    """Simulate off-resonance on a spiral acquisition and compare corrections.

    The spiral trajectory is loaded from ``sample_data``. For each maximum
    frequency in ``fmax_v``, a smooth field map scaled to +/- ``fmax`` Hz
    corrupts a Shepp-Logan phantom via ``ORC.add_or_CPR``. The corrupted
    k-space data is then corrected with CPR, fs-CPR and MFI through
    ``ORC.correct_from_kdat``. Intermediate images are shown with matplotlib,
    and the final comparison grid is drawn by ``plot_correction_results``.
    """
    N = 128  # ph.shape[0]
    ph = resize(shepp_logan_phantom(), (N, N))  #.astype(complex)
    plt.imshow(np.abs(ph), cmap='gray')
    plt.title('Original phantom')
    plt.axis('off')
    plt.colorbar()
    plt.show()

    brain_mask = mask_by_threshold(ph)

    # Floodfill from point (0, 0). For integers ~x + 2 == 1 - x, so ph_holes
    # is the complement of the flood-filled image.
    ph_holes = ~(flood_fill(brain_mask, (0, 0), 1).astype(int)) + 2
    mask = brain_mask + ph_holes

    ##
    # Spiral k-space trajectory
    ##
    dt = 4e-6
    ktraj = np.load('sample_data/SS_sprial_ktraj.npy')  # k-space trajectory

    # The trajectory is stored as complex values; plot real vs. imaginary.
    plt.plot(ktraj.real, ktraj.imag)
    plt.title('Spiral trajectory')
    plt.show()

    #ktraj_dcf = np.load('test_data/ktraj_noncart_dcf.npy').flatten() # density compensation factor
    t_ro = ktraj.shape[0] * dt
    T = (np.arange(ktraj.shape[0]) * dt).reshape(ktraj.shape[0], 1)

    seq_params = {
        'N': ph.shape[0],
        'Npoints': ktraj.shape[0],
        'Nshots': ktraj.shape[1],
        't_readout': t_ro,
        't_vector': T
    }  #, 'dcf': ktraj_dcf}
    ##
    # Simulated field map
    ##
    fmax_v = [100, 150, 200]  # Hz

    or_corrupted = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_CPR = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_fsCPR = np.zeros((N, N, len(fmax_v)), dtype='complex')
    or_corrected_MFI = np.zeros((N, N, len(fmax_v)), dtype='complex')
    # (method name for ORC.correct_from_kdat, output volume), in run order.
    corrections = (('CPR', or_corrected_CPR),
                   ('fsCPR', or_corrected_fsCPR),
                   ('MFI', or_corrected_MFI))
    for i, fmax in enumerate(fmax_v):
        # Smooth the phantom, rescale it to [-fmax, fmax] Hz, round it and
        # restrict it to the mask.
        SL_smooth = gaussian(ph, sigma=3)
        field_map = cv2.normalize(SL_smooth, None, -fmax, fmax,
                                  cv2.NORM_MINMAX)
        field_map = np.round(field_map * mask)
        # -0.0 == 0.0, so every zero (including negative zeros from rounding)
        # is rewritten as +0.
        field_map[np.where(field_map == -0.0)] = 0

        plt.imshow(field_map, cmap='gray')
        plt.title('Field Map +/-' + str(fmax) + ' Hz')
        plt.colorbar()
        plt.axis('off')
        plt.show()

        ##
        # Corrupted images
        ##
        or_corrupted[:, :, i], ksp_corrupted = ORC.add_or_CPR(
            ph, ktraj, field_map, 1, seq_params)
        # Min-max normalise the magnitude for display only.
        magnitude = np.abs(or_corrupted[:, :, i])
        corrupt = (magnitude - magnitude.min()) / (magnitude.max() -
                                                   magnitude.min())
        plt.imshow(corrupt, cmap='gray')
        plt.colorbar()
        plt.title('Corrupted Image')
        plt.axis('off')
        plt.show()

        ###
        # Corrected images (from the corrupted k-space data)
        ###
        for method, result in corrections:
            result[:, :, i] = ORC.correct_from_kdat(method, ksp_corrupted,
                                                    ktraj, field_map,
                                                    seq_params, 1)

        # Image-domain alternative:
        # or_corrected_CPR[:, :, i] = ORC.CPR(or_corrupted[:, :, i], 'im', ktraj, field_map, 1, seq_params)
        # or_corrected_fsCPR[:, :, i] = ORC.fs_CPR(or_corrupted[:, :, i], 'im', ktraj, field_map, 2, 1, seq_params)
        # or_corrected_MFI[:,:,i] = ORC.MFI(or_corrupted[:,:,i], 'im', ktraj, field_map, 2, 1, seq_params)

    ##
    # Plot
    ##
    im_stack = np.stack(
        (np.squeeze(or_corrupted), np.squeeze(or_corrected_CPR),
         np.squeeze(or_corrected_fsCPR), np.squeeze(or_corrected_MFI)))
    #np.save('im_stack.npy',im_stack)
    cols = ('Corrupted Image', 'CPR Correction', 'fs-CPR Correction',
            'MFI Correction')
    # Derived from fmax_v so the labels cannot drift out of sync with it.
    row_names = tuple('-/+ ' + str(fmax) + ' Hz' for fmax in fmax_v)
    plot_correction_results(im_stack, cols, row_names)
예제 #7
0
def main():
    """Command-line entry point: apply an off-resonance correction to raw data.

    Loads the corrupted raw data, the k-space trajectory and the field map
    (rad/s) from ``path``, then validates their dimensions. It runs the
    selected correction method (CPR, fs-CPR or MFI) on every slice and
    channel, and saves the result as ``ORC_result_<method>`` in ``path``
    via ``np.save``.

    Raises:
        ValueError: if the inputs have incompatible dimensions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("path", help="Path containing the input data")
    parser.add_argument(
        "datain",
        help=
        "Corrupted data with dimensions [Npoints OR N(Cartesian), Nshots OR N(Cartesian), Nslices, Nchannels]"
    )
    parser.add_argument(
        "ktraj",
        help=
        "K-space trajectory: [Npoints OR N(Cartesian), Nshots OR N(Cartesian)]"
    )
    parser.add_argument("fmap", help="Field map in rad/s")
    parser.add_argument("ORCmethod",
                        choices=['CPR', 'fsCPR', 'MFI'],
                        help="Correction method")
    # Explicit type= is required for the options below. Without it argparse
    # returns the raw string, so "--cart 1" failed the integer `choices`
    # check and numeric options arrived as str (e.g. Npoints * "1e-5").
    parser.add_argument("--cart",
                        default=0,
                        type=int,
                        choices=[0, 1],
                        dest="cart_opt",
                        help="0: non-cartesian data, 1: cartesian data")
    parser.add_argument(
        '--Lx',
        default=2,
        type=int,
        dest='Lx',
        help="L(frequency bins) factor with respect to minimum L")
    parser.add_argument('--grad_raster',
                        default=10e-6,
                        type=float,
                        dest='grad_raster',
                        help="Gradient raster time")
    parser.add_argument('--TE',
                        default=0,
                        type=float,
                        dest='TE',
                        help="Echo time")
    parser.add_argument(
        '--dcf',
        default=None,
        dest='dcf',
        help="Density compensation factor for non-cartesian trajectories")

    args = parser.parse_args()
    data_in = get_data_from_file(os.path.join(args.path, args.datain))
    ktraj = get_data_from_file(os.path.join(args.path, args.ktraj))
    # The field map is given in rad/s; dividing by 2*pi converts it to Hz.
    fmap = get_data_from_file(os.path.join(args.path,
                                           args.fmap)) / (2 * math.pi)

    ## Dimensions check
    print('Checking dimensions...')
    # The correction loop indexes data_in[:, :, slice, channel].
    if len(data_in.shape) != 4:
        raise ValueError('The raw data should have dimensions '
                         '[Npoints, Nshots, Nslices, Nchannels]')
    if data_in.shape[0] != ktraj.shape[0] or data_in.shape[1] != ktraj.shape[1]:
        raise ValueError(
            'The raw data does not agree with the k-space trajectory')
    if fmap.shape[0] != fmap.shape[1]:
        raise ValueError(
            'Image and field map should have square dimensions (NxN)')
    if len(fmap.shape) == 2:
        # A single 2-D field map is only valid for single-slice data.
        fmap = np.expand_dims(fmap, -1)
        if data_in.shape[2] != 1:
            raise ValueError(
                'The field map dimensions do not agree with the raw data')
    elif fmap.shape[-1] != data_in.shape[2]:
        raise ValueError(
            'The field map dimensions do not agree with the raw data')

    print('OK')

    ## Sequence parameters
    N = fmap.shape[0]
    Npoints = data_in.shape[0]
    Nshots = data_in.shape[1]
    Nslices = data_in.shape[-2]
    Nchannels = data_in.shape[-1]

    t_ro = Npoints * args.grad_raster
    T = np.linspace(args.TE, args.TE + t_ro, Npoints).reshape((Npoints, 1))

    # Positional arguments after (data, 'raw', ktraj, fmap): fs-CPR and MFI
    # take the L factor first; CPR does not. Non-Cartesian runs additionally
    # pass the non-Cartesian flag (1) and the sequence parameters.
    extra_args = [] if args.ORCmethod == 'CPR' else [args.Lx]
    if args.cart_opt == 0:
        seq_params = {
            'N': N,
            'Npoints': Npoints,
            'Nshots': Nshots,
            't_vector': T,
            't_readout': t_ro
        }

        if args.dcf is not None:
            dcf = get_data_from_file(os.path.join(args.path,
                                                  args.dcf)).flatten()
            seq_params.update({'dcf': dcf})
        extra_args += [1, seq_params]

    correctors = {'CPR': ORC.CPR, 'fsCPR': ORC.fs_CPR, 'MFI': ORC.MFI}
    correct = correctors[args.ORCmethod]

    ORC_result = np.zeros((N, N, Nslices, Nchannels), dtype=complex)
    for ch in tqdm(range(Nchannels)):
        for sl in range(Nslices):
            ORC_result[:, :, sl, ch] = correct(
                np.squeeze(data_in[:, :, sl, ch]), 'raw', ktraj,
                np.squeeze(fmap[:, :, sl]), *extra_args)

    np.save(os.path.join(args.path, 'ORC_result_' + args.ORCmethod),
            ORC_result)
예제 #8
0
Lx = 2  # Frequency segments for fs-CPR and MFI. L = Lx * Lmin
if Lx < 1:
    raise ValueError('The L factor cannot be lower than 1 (minimum L)')

# One complex result volume per method, shaped (N, N, Nslices, Nchannels).
CPR_result = np.zeros((N, N, Nslices, Nchannels), dtype=complex)
fsCPR_result = np.zeros((N, N, Nslices, Nchannels), dtype=complex)
MFI_result = np.zeros((N, N, Nslices, Nchannels), dtype=complex)
# Accumulated wall-clock seconds (time.time() differences) spent per method.
CPR_timing = 0
fsCPR_timing = 0
MFI_timing = 0
# For every channel and slice, run the corrections (CPR, fs-CPR, ...) on the
# raw data and add each call's wall-clock time to that method's *_timing.
for ch in range(Nchannels):
    for sl in range(Nslices):
        before = time.time()
        # NOTE(review): the positional 1 presumably selects the non-Cartesian
        # path (with seq_params) — confirm against ORC.CPR's signature.
        CPR_result[:, :, sl, ch] = ORC.CPR(np.squeeze(rawdata[:, :, sl, ch]),
                                           'raw', ktraj,
                                           np.squeeze(fmap[:, :,
                                                           sl]), 1, seq_params)
        CPR_timing += time.time() - before

        print('CPR: Done with slice:' + str(sl + 1) + ', channel:' +
              str(ch + 1))
        # Re-saves the whole (partially filled) CPR volume after every slice,
        # presumably as a checkpoint in case the run is interrupted.
        np.save(outputs['path_correction_folder'] + 'CPR', CPR_result)

        before = time.time()
        # fs-CPR with Lx frequency-segment factor; same trailing args as CPR.
        fsCPR_result[:, :, sl,
                     ch] = ORC.fs_CPR(np.squeeze(rawdata[:, :, sl,
                                                         ch]), 'raw', ktraj,
                                      np.squeeze(fmap[:, :, sl]), Lx, 1,
                                      seq_params)
        fsCPR_timing += time.time() - before