Example #1
0
 def test_MFI_given_im(self):
     """Check that MFI accepts image-domain input (dataInType='im').

     One slice of raw data is corrected first. The resulting image is then
     passed back to MFI as image data with the same trajectory, field-map
     slice and parameters.
     """
     fmap_slice = field_map[:, :, 0]
     shared = dict(kt=ktraj, df=fmap_slice, Lx=1, nonCart=1,
                   params=acq_params)
     from_raw = ORC.MFI(dataIn=raw_data[:, :, 0], dataInType='raw', **shared)
     from_im = ORC.MFI(dataIn=from_raw, dataInType='im', **shared)
     self.assertEqual(from_im.shape, field_map.shape[:-1])
     # NOTE(review): this compares the scalar results of .all() (whether every
     # pixel is non-zero), not the pixel values, so it does not show the two
     # images agree. Replace with an explicit value comparison if that is the
     # intent.
     self.assertEqual(from_im.all(), from_raw.all())
Example #2
0
 def test_MFI_wrongtype(self):
     """MFI must raise ValueError when dataInType does not describe dataIn.

     Every call, failing or valid, uses the same single-slice field map
     (field_map[:, :, 0]). The only thing wrong in each assertRaises call is
     therefore the dataIn/dataInType pairing. Passing the full multi-slice
     field_map here may instead raise ValueError for a field-map shape
     mismatch, which would let the test pass for the wrong reason.
     """
     raw_slice = raw_data[:, :, 0]
     fmap_slice = field_map[:, :, 0]
     # Error when dataIn is raw data but dataInType is 'im'
     with self.assertRaises(ValueError):
         ORC.MFI(dataIn=raw_slice, dataInType='im', kt=ktraj,
                 df=fmap_slice, Lx=1, nonCart=1, params=acq_params)
     or_corr_im = ORC.MFI(dataIn=raw_slice, dataInType='raw', kt=ktraj,
                          df=fmap_slice, Lx=1, nonCart=1, params=acq_params)
     # Error when dataIn is image data but dataInType is 'raw'
     with self.assertRaises(ValueError):
         ORC.MFI(dataIn=or_corr_im, dataInType='raw', kt=ktraj,
                 df=fmap_slice, Lx=1, nonCart=1, params=acq_params)
     # Error when dataInType is other than 'raw' or 'im'
     with self.assertRaises(ValueError):
         ORC.MFI(dataIn=raw_slice, dataInType='other', kt=ktraj,
                 df=fmap_slice, Lx=1, nonCart=1, params=acq_params)
Example #3
0
 def test_Npoints(self):
     """CPR, fs_CPR and MFI must each raise ValueError for a bad 'Npoints'.

     'Npoints' is overridden to 10 on a shallow copy of acq_params
     (presumably a value that does not match raw_data — confirm).
     """
     bad_params = acq_params.copy()
     bad_params['Npoints'] = 10
     raw_slice = raw_data[:, :, 0]
     calls = (
         lambda: ORC.CPR(dataIn=raw_slice, dataInType='raw', kt=ktraj,
                         df=field_map_CPR, nonCart=1, params=bad_params),
         lambda: ORC.fs_CPR(dataIn=raw_slice, dataInType='raw', kt=ktraj,
                            df=field_map_CPR, Lx=1, nonCart=1,
                            params=bad_params),
         lambda: ORC.MFI(dataIn=raw_slice, dataInType='raw', kt=ktraj,
                         df=field_map_CPR, Lx=1, nonCart=1,
                         params=bad_params),
     )
     for call in calls:
         with self.assertRaises(ValueError):
             call()
Example #4
0
 def test_N(self):
     """A params 'N' that does not match the image dimensions raises ValueError.

     CPR, fs_CPR and MFI are each called with 'N' overridden to 1 on a
     shallow copy of acq_params.
     """
     params_dict = acq_params.copy()
     params_dict['N'] = 1
     shared = dict(dataIn=raw_data[:, :, 0], dataInType='raw', kt=ktraj,
                   df=field_map_CPR, nonCart=1, params=params_dict)
     # fs_CPR and MFI additionally take the Lx factor; CPR does not.
     for method, extra in ((ORC.CPR, {}), (ORC.fs_CPR, {'Lx': 1}),
                           (ORC.MFI, {'Lx': 1})):
         with self.assertRaises(ValueError):
             method(**shared, **extra)
Example #5
0
 def test_MFI_given_rawData(self):
     """MFI on one slice of raw data returns an image shaped like one field-map slice."""
     corrected = ORC.MFI(
         dataIn=raw_data[:, :, 0],
         dataInType='raw',
         kt=ktraj,
         df=field_map[:, :, 0],
         Lx=1,
         nonCart=1,
         params=acq_params,
     )
     expected_shape = field_map.shape[:-1]
     self.assertEqual(corrected.shape, expected_shape)  # Dimensions agree
def numsim_cartesian():
    """Simulate off-resonance on a Shepp-Logan phantom and compare corrections.

    A 192x192 Shepp-Logan phantom is acquired on a Cartesian trajectory and
    corrupted with simulated field maps of maximum frequency 1600, 3200 and
    4800 Hz. Each corrupted image is then corrected with CPR, fs-CPR (Lx=2)
    and MFI (Lx=2). Intermediate images are shown with matplotlib, and the
    final comparison grid is drawn by plot_correction_results.
    """

    def show_image(image, title, *, colorbar=False, axis_off=False):
        """Draw *image* in grayscale with *title*, then call plt.show()."""
        plt.imshow(image, cmap='gray')
        plt.title(title)
        if axis_off:
            plt.axis('off')
        if colorbar:
            plt.colorbar()
        plt.show()

    N = 192
    ph = resize(shepp_logan_phantom(), (N, N)).astype(complex)
    show_image(np.abs(ph), 'Original phantom', colorbar=True, axis_off=True)

    brain_mask = mask_by_threshold(ph)

    # Floodfill from point (0, 0). For a 0/1 result, ``1 - filled`` is 1
    # exactly where the fill left a 0 (presumably holes enclosed by the
    # phantom — verify); it equals ``~filled + 2`` for any integer array.
    filled = flood_fill(brain_mask, (0, 0), 1).astype(int)
    ph_holes = 1 - filled
    mask = brain_mask + ph_holes

    ##
    # Cartesian k-space trajectory
    ##
    dt = 10e-6  # grad raster time
    # Sample times 0, dt, ..., (N-1)*dt repeated on every row. np.arange(N) * dt
    # always yields exactly N samples. np.arange(0, N * dt, dt) uses a float
    # step and can produce N + 1 samples through rounding.
    ktraj_cart = np.tile(np.arange(N) * dt, (N, 1))
    show_image(ktraj_cart, 'Cartesian trajectory')

    ##
    # Simulated field map
    ##
    fmax_v = [1600, 3200, 4800]  # Hz, corresponding to 25, 50 and 75 ppm at 3T

    stack_shape = (N, N, len(fmax_v))
    or_corrupted = np.zeros(stack_shape, dtype='complex')
    or_corrected_CPR = np.zeros(stack_shape, dtype='complex')
    or_corrected_fsCPR = np.zeros(stack_shape, dtype='complex')
    or_corrected_MFI = np.zeros(stack_shape, dtype='complex')
    for i, fmax in enumerate(fmax_v):
        field_map = fieldmap_sim.realistic(np.abs(ph), mask, fmax)
        show_image(field_map, 'Field Map', colorbar=True, axis_off=True)

        ##
        # Corrupted images
        ##
        or_corrupted[:, :, i], _ = ORC.add_or_CPR(ph, ktraj_cart, field_map)
        # Min-max normalise the magnitude to [0, 1]; used for display only.
        magnitude = np.abs(or_corrupted[:, :, i])
        corrupt = (magnitude - magnitude.min()) / (magnitude.max() -
                                                   magnitude.min())
        show_image(corrupt, 'Corrupted Image', colorbar=True, axis_off=True)

        ###
        # Corrected images
        ###
        or_corrected_CPR[:, :, i] = ORC.CPR(or_corrupted[:, :, i], 'im',
                                            ktraj_cart, field_map)
        or_corrected_fsCPR[:, :, i] = ORC.fs_CPR(or_corrupted[:, :, i], 'im',
                                                 ktraj_cart, field_map, 2)
        or_corrected_MFI[:, :, i] = ORC.MFI(or_corrupted[:, :, i], 'im',
                                            ktraj_cart, field_map, 2)

    ##
    # Plot
    ##
    im_stack = np.stack(
        (np.squeeze(or_corrupted), np.squeeze(or_corrected_CPR),
         np.squeeze(or_corrected_fsCPR), np.squeeze(or_corrected_MFI)))
    cols = ('Corrupted Image', 'CPR Correction', 'fs-CPR Correction',
            'MFI Correction')
    row_names = ('-/+ 1600 Hz', '-/+ 3200 Hz', '-/+ 4800 Hz')
    plot_correction_results(im_stack, cols, row_names)
Example #7
0
def main():
    """Run off-resonance correction on data files named on the command line.

    Positional arguments give a folder and, inside it, the corrupted data,
    the k-space trajectory, the field map (rad/s) and the correction method
    (CPR, fsCPR or MFI). Every slice of every channel is corrected
    independently. The stacked result is written with np.save to
    ``<path>/ORC_result_<ORCmethod>``.

    Raises:
        ValueError: if the data, trajectory and field-map dimensions do not
            agree.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("path", help="Path containing the input data")
    parser.add_argument(
        "datain",
        help=
        "Corrupted data with dimensions [Npoints OR N(Cartesian), Nshots OR N(Cartesian), Nslices, Nchannels]"
    )
    parser.add_argument(
        "ktraj",
        help=
        "K-space trajectory: [Npoints OR N(Cartesian), Nshots OR N(Cartesian)]"
    )
    parser.add_argument("fmap", help="Field map in rad/s")
    parser.add_argument("ORCmethod",
                        choices=['CPR', 'fsCPR', 'MFI'],
                        help="Correction method")
    # The numeric options need type=: without it, command-line values arrive
    # as str. "--cart 1" would then fail choices=[0, 1], and the t_ro/T
    # arithmetic below would operate on strings.
    parser.add_argument("--cart",
                        default=0,
                        type=int,
                        choices=[0, 1],
                        dest="cart_opt",
                        help="0: non-cartesian data, 1: cartesian data")
    parser.add_argument(
        '--Lx',
        default=2,
        type=int,
        dest='Lx',
        help="L(frequency bins) factor with respect to minimum L")
    parser.add_argument('--grad_raster',
                        default=10e-6,
                        type=float,
                        dest='grad_raster',
                        help="Gradient raster time")
    parser.add_argument('--TE',
                        default=0,
                        type=float,
                        dest='TE',
                        help="Echo time")
    parser.add_argument(
        '--dcf',
        default=None,
        dest='dcf',
        help="Density compensation factor for non-cartesian trajectories")

    args = parser.parse_args()
    data_in = get_data_from_file(os.path.join(args.path, args.datain))
    ktraj = get_data_from_file(os.path.join(args.path, args.ktraj))
    # The field map is supplied in rad/s; dividing by 2*pi gives Hz.
    fmap = get_data_from_file(os.path.join(args.path,
                                           args.fmap)) / (2 * math.pi)

    ## Dimensions check
    print('Checking dimensions...')
    if data_in.shape[0] != ktraj.shape[0] or data_in.shape[1] != ktraj.shape[1]:
        raise ValueError(
            'The raw data does not agree with the k-space trajectory')
    if fmap.shape[0] != fmap.shape[1]:
        raise ValueError(
            'Image and field map should have square dimensions (NxN)')
    if len(fmap.shape) > 2:
        if fmap.shape[-1] != data_in.shape[2]:
            raise ValueError(
                'The field map dimensions do not agree with the raw data')
    if len(fmap.shape) == 2:
        # A single 2-D field map is given a trailing slice axis.
        fmap = np.expand_dims(fmap, -1)
        if data_in.shape[2] != 1:
            raise ValueError(
                'The field map dimensions do not agree with the raw data')

    print('OK')

    ## Sequence parameters
    N = fmap.shape[0]
    Npoints = data_in.shape[0]
    Nshots = data_in.shape[1]
    Nslices = data_in.shape[-2]
    Nchannels = data_in.shape[-1]

    t_ro = Npoints * args.grad_raster
    T = np.linspace(args.TE, args.TE + t_ro, Npoints).reshape((Npoints, 1))
    non_cartesian = args.cart_opt == 0
    seq_params = None  # only built and used for non-Cartesian data
    if non_cartesian:
        seq_params = {
            'N': N,
            'Npoints': Npoints,
            'Nshots': Nshots,
            't_vector': T,
            't_readout': t_ro
        }

        if args.dcf is not None:
            dcf = get_data_from_file(os.path.join(args.path,
                                                  args.dcf)).flatten()
            seq_params.update({'dcf': dcf})

    def correct(raw_slice, fmap_slice):
        """Correct one slice of one channel with the selected method."""
        if args.ORCmethod == 'MFI':
            if non_cartesian:
                return ORC.MFI(raw_slice, 'raw', ktraj, fmap_slice, args.Lx,
                               1, seq_params)
            return ORC.MFI(raw_slice, 'raw', ktraj, fmap_slice, args.Lx)
        if args.ORCmethod == 'fsCPR':
            if non_cartesian:
                return ORC.fs_CPR(raw_slice, 'raw', ktraj, fmap_slice,
                                  args.Lx, 1, seq_params)
            return ORC.fs_CPR(raw_slice, 'raw', ktraj, fmap_slice, args.Lx)
        # argparse choices leave 'CPR' as the only remaining method.
        if non_cartesian:
            return ORC.CPR(raw_slice, 'raw', ktraj, fmap_slice, 1,
                           seq_params)
        return ORC.CPR(raw_slice, 'raw', ktraj, fmap_slice)

    ORC_result = np.zeros((N, N, Nslices, Nchannels), dtype=complex)
    for ch in tqdm(range(Nchannels)):
        for sl in range(Nslices):
            ORC_result[:, :, sl, ch] = correct(
                np.squeeze(data_in[:, :, sl, ch]), np.squeeze(fmap[:, :, sl]))

    np.save(os.path.join(args.path, 'ORC_result_' + args.ORCmethod),
            ORC_result)
Example #8
0
        # fs-CPR correction of slice `sl`, channel `ch` using the non-Cartesian
        # call form (nonCart=1 plus seq_params). NOTE(review): `before`, the
        # loop variables and the result arrays are defined above this excerpt
        # — confirm `before = time.time()` is set immediately before this call.
        fsCPR_result[:, :, sl,
                     ch] = ORC.fs_CPR(np.squeeze(rawdata[:, :, sl,
                                                         ch]), 'raw', ktraj,
                                      np.squeeze(fmap[:, :, sl]), Lx, 1,
                                      seq_params)
        # Accumulate wall-clock time spent in fs-CPR across all slices/channels.
        fsCPR_timing += time.time() - before

        print('fsCPR: Done with slice:' + str(sl + 1) + ', channel:' +
              str(ch + 1))
        # The whole fs-CPR result array is re-saved after every slice/channel.
        # The path is built by plain string concatenation, so
        # path_correction_folder presumably ends with a path separator — verify.
        np.save(outputs['path_correction_folder'] + 'fsCPR_Lx' + str(Lx),
                fsCPR_result)

        # MFI correction of the same slice/channel, timed the same way.
        before = time.time()
        MFI_result[:, :, sl,
                   ch] = ORC.MFI(np.squeeze(rawdata[:, :, sl,
                                                    ch]), 'raw', ktraj,
                                 np.squeeze(fmap[:, :, sl]), Lx, 1, seq_params)
        MFI_timing += time.time() - before

        print('MFI: Done with slice:' + str(sl + 1) + ', channel:' +
              str(ch + 1))
        # The whole MFI result array is re-saved after every slice/channel.
        np.save(outputs['path_correction_folder'] + 'MFI_Lx' + str(Lx),
                MFI_result)

##
# Display the results: total wall-clock time spent in each correction method
##
print(f'\nCPR correction took {CPR_timing} seconds.')
print(f'\nFs-CPR correction took {fsCPR_timing} seconds.')
print(f'\nMFI correction took {MFI_timing} seconds.')