Example #1
 def calculate(self, data, method=None):
     '''
     Calculates coil sensitivity maps from coil images or sorted 
     acquisitions.
     data  : either AcquisitionData or CoilImages
     method: either SRSS (Square Root of the Sum of Squares, default) or 
             Inati
     '''
     if isinstance(data, AcquisitionData):
         if not data.is_sorted():
             print('WARNING: acquisitions may be in the wrong order')
     if self.handle is not None:
         pyiutil.deleteDataHandle(self.handle)
     self.handle = pygadgetron.cGT_CoilSensitivities('')
     check_status(self.handle)
     if method is not None:
         method_name, parm_list = name_and_parameters(method)
         parm = parse_arglist(parm_list)
     else:
         method_name = 'SRSS'
         parm = {}
     if isinstance(data, AcquisitionData):
         assert data.handle is not None
         _set_int_par(self.handle, 'coil_sensitivity', 'smoothness',
                      self.smoothness)
         try_calling(pygadgetron.cGT_computeCoilSensitivities(
             self.handle, data.handle))
     elif isinstance(data, CoilImageData):
         assert data.handle is not None
         if method_name == 'Inati':
             try:
                 from ismrmrdtools import coils
             except ImportError:
                 raise error('Inati method requires ismrmrd-python-tools')
             nz = data.number()
             for z in range(nz):
                 ci = numpy.squeeze(data.as_array(z))
                 (csm, rho) = coils.calculate_csm_inati_iter(ci)
                 self.append(csm.astype(numpy.complex64))
         elif method_name == 'SRSS':
             if 'niter' in parm:
                 nit = int(parm['niter'])
                 _set_int_par(self.handle, 'coil_sensitivity',
                              'smoothness', nit)
             try_calling(pygadgetron.cGT_computeCSMsFromCIs(
                 self.handle, data.handle))
         else:
             raise error('Unknown method %s' % method_name)
     else:
         raise error('Cannot calculate coil sensitivities from %s'
                     % repr(type(data)))
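
# Hypothetical usage sketch for the method above.  The owning class name
# (CoilSensitivityData) and the data objects (acq_data, coil_images) are
# assumptions, not shown in the snippet; the method strings follow the parsing
# done by name_and_parameters/parse_arglist in the code above.
csm = CoilSensitivityData()                          # assumed container class
csm.smoothness = 10                                  # used by the SRSS branch above
csm.calculate(acq_data)                              # default: SRSS from sorted AcquisitionData
csm.calculate(coil_images, method='Inati')           # requires ismrmrd-python-tools
csm.calculate(coil_images, method='SRSS(niter=3)')   # SRSS with explicit smoothing iterations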
Example #2
def view(image,
         load_opts=None,
         is_raw=None,
         is_line=None,
         prep=None,
         fft=False,
         fft_axes=None,
         fftshift=None,
         avg_axis=None,
         coil_combine_axis=None,
         coil_combine_method='walsh',
         coil_combine_opts=None,
         is_imspace=False,
         mag=None,
         phase=False,
         log=False,
         imshow_opts=None,
         montage_axis=None,
         montage_opts=None,
         movie_axis=None,
         movie_interval=50,
         movie_repeat=True,
         save_npy=False,
         debug_level=logging.DEBUG,
         test_run=False):
    '''Image viewer to quickly inspect data.

    Parameters
    ----------
    image : str or array_like
        Name of the file including the file extension or numpy array.
    load_opts : dict, optional
        Options to pass to data loader.
    is_raw : bool, optional
        Inform if data is raw. Will attempt to guess from extension.
    is_line : bool, optional
        Whether or not this is a line plot (as opposed to image).
    prep : callable, optional
        Lambda function to process the data before it's displayed.
    fft : bool, optional
        Whether or not to perform n-dimensional FFT of data.
    fft_axes : tuple, optional
        Axis to perform FFT over, determines dimension of n-dim FFT.
    fftshift : bool, optional
        Whether or not to perform fftshift. Defaults to True if fft.
    avg_axis : int, optional
        Axis to average the data over.
    coil_combine_axis : int, optional
        Which axis to perform coil combination over.
    coil_combine_method : {'walsh', 'inati', 'pca'}, optional
        Method to use to combine coils.
    coil_combine_opts : dict, optional
        Options to pass to the coil combine method.
    is_imspace : bool, optional
        Whether or not the data is in image space. For coil combine.
    mag : bool, optional
        View magnitude image. Defaults to True if data is complex.
    phase : bool, optional
        View phase image.
    log : bool, optional
        View log of magnitude data. Defaults to False.
    imshow_opts : dict, optional
        Options to pass to imshow. Defaults to {'cmap': 'gray'}.
    montage_axis : int, optional
        Axis along which images are tiled into a montage.
    montage_opts : dict, optional
        Additional options to pass to skimage.util.montage.
    movie_axis : int, optional
        Axis along which the movie frames are indexed.
    movie_interval : int, optional
        Delay between animation frames (in milliseconds).
    movie_repeat : bool, optional
        Whether or not to put movie on endless loop.
    save_npy : bool, optional
        Whether or not to save the output as npy file.
    debug_level : logging_level, optional
        Level of verbosity. See logging module.
    test_run : bool, optional
        Doesn't show figure, returns debug object. Mostly for testing.

    Returns
    -------
    data : array_like
        Image data shown in plot.
    dict, optional
        All local variables when test_run=True.

    Raises
    ------
    Exception
        When file type is not in ['dat', 'npy', 'mat', 'h5'].
    ValueError
        When coil combine requested, but fft_axes not set.
    AssertionError
        When Walsh coil combine is requested but len(fft_axes) != 2.
    ValueError
        When there are too many dimensions to display.
    '''

    # Set up logging...
    logging.basicConfig(format='%(levelname)s: %(message)s', level=debug_level)

    # Add some default params for arguments that were not supplied
    if load_opts is None:
        load_opts = dict()
    if coil_combine_opts is None:
        coil_combine_opts = dict()
    if imshow_opts is None:
        imshow_opts = {'cmap': 'gray'}
    if montage_opts is None:
        montage_opts = {'padding_width': 2}

    # If the user wants to look at numpy matrix, recognize that
    # filename is the matrix:
    if isinstance(image, np.ndarray):
        logging.info('Image is a numpy array!')
        data = image
    elif isinstance(image, list):
        # If user sends a list, try casting to numpy array
        logging.info('Image is a list, trying to cast as numpy array...')
        data = np.array(image)
    else:
        # Find the file extension
        ext = pathlib.Path(image).suffix

        # If the user says data is raw, then trust the user
        if is_raw or (ext == '.dat'):
            data = load_raw(image, **load_opts)
        elif ext == '.npy':
            data = np.load(image, **load_opts)
        elif ext == '.mat':
            # Help out the user a little bit...  If only one
            # nontrivial key is found then go ahead and assume it's
            # that one
            data = None
            if not list(load_opts):
                keys = mat_keys(image, no_print=True)
                if len(keys) == 1:
                    logging.info(('No key supplied, but one key for'
                                  ' mat dictionary found (%s), using'
                                  ' it...'), keys[0])
                    data = load_mat(image, key=keys[0])

            # If we can't help the user out, just load it as normal
            if data is None:
                data = load_mat(image, **load_opts)
        elif ext == '.h5':
            data = load_ismrmrd(image, **load_opts)
        else:
            raise Exception('File type %s not understood!' % ext)

    # Right off the bat, remove singleton dimensions
    if 1 in data.shape:
        logging.info('Current shape %s: Removing singleton dimensions...',
                     str(data.shape))
        data = data.squeeze()
        logging.info('New shape: %s', str(data.shape))

    # Average out over any axis specified
    if avg_axis is not None:
        data = np.mean(data, axis=avg_axis)

    # Let's collapse the coil dimension using the specified algorithm
    if coil_combine_axis is not None:

        # We'll need to know the fft_axes if the data is in kspace
        if not is_imspace and fft_axes is None:
            msg = ('fft_axes required to do coil combination of '
                   'k-space data!')
            raise ValueError(msg)

        if coil_combine_method == 'walsh':
            msg = 'Walsh only works with 2D images!'
            assert len(fft_axes) == 2, msg
            logging.info('Performing Walsh 2d coil combine across axis %d...',
                         list(range(data.ndim))[coil_combine_axis])

            # We need to do this in the image domain...
            if not is_imspace:
                fft_data = np.fft.ifftshift(np.fft.ifftn(data, axes=fft_axes),
                                            axes=fft_axes)
            else:
                fft_data = data

            # walsh expects (coil,y,x)
            fft_data = np.moveaxis(fft_data, coil_combine_axis, 0)
            csm_walsh, _ = calculate_csm_walsh(fft_data, **coil_combine_opts)
            fft_data = np.sum(csm_walsh * np.conj(fft_data),
                              axis=0,
                              keepdims=True)

            # Sum kept the axis where the coil dim used to be, so we can rely
            # on fft_axes being correct when we do the FT back to k-space
            fft_data = np.moveaxis(fft_data, 0, coil_combine_axis)

            # Now move back to kspace and squeeze the dangling axis
            if not is_imspace:
                data = np.fft.fftn(np.fft.fftshift(fft_data, axes=fft_axes),
                                   axes=fft_axes).squeeze()
            else:
                data = fft_data.squeeze()

        elif coil_combine_method == 'inati':

            logging.info('Performing Inati coil combine across axis %d...',
                         list(range(data.ndim))[coil_combine_axis])

            # Put things into image space if we need to
            if not is_imspace:
                fft_data = np.fft.ifftshift(np.fft.ifftn(data, axes=fft_axes),
                                            axes=fft_axes)
            else:
                fft_data = data

            # inati expects (coil,z,y,x)
            fft_data = np.moveaxis(fft_data, coil_combine_axis, 0)
            _, fft_data = calculate_csm_inati_iter(fft_data,
                                                   **coil_combine_opts)

            # calculate_csm_inati_iter got rid of the axis, so we
            # need to add it back in so we can use the same fft_axes
            fft_data = np.expand_dims(fft_data, coil_combine_axis)

            # Now move back to kspace and squeeze the dangling axis
            if not is_imspace:
                data = np.fft.fftn(np.fft.fftshift(fft_data, axes=fft_axes),
                                   axes=fft_axes).squeeze()
            else:
                data = fft_data.squeeze()

        elif coil_combine_method == 'pca':
            logging.info('Performing PCA coil combine across axis %d...',
                         list(range(data.ndim))[coil_combine_axis])

            # We don't actually care whether we do this in k-space
            # or image space
            if not is_imspace:
                logging.info(('PCA doesn\'t care that the image might not be'
                              ' in image space.'))

            if 'n_components' not in coil_combine_opts:
                n_components = int(data.shape[coil_combine_axis] / 2)
                logging.info('Deciding to use %d components.', n_components)
                coil_combine_opts['n_components'] = n_components

            data = coil_pca(data,
                            coil_dim=coil_combine_axis,
                            **coil_combine_opts)

        else:
            logging.error('Coil combination method "%s" not supported!',
                          coil_combine_method)
            logging.warning('Attempting to skip coil combination!')

    # Show the image.  Let's also try to help the user out again.  If
    # we have 3 dimensions, one of them is probably a montage or a
    # movie.  If the user didn't tell us anything, it's going to
    # crash anyway, so let's try guessing what's going on...
    if (data.ndim > 2) and (movie_axis is None) and (montage_axis is None):
        logging.info('Data has %d dimensions!', data.ndim)

        # We will always assume that inplane resolution is larger
        # than the movie/montage dimensions

        # If only 3 dims, then one must be montage/movie dimension
        if data.ndim == 3:
            # assume inplane resolution larger than movie/montage dim
            min_axis = np.argmin(data.shape)

            # Assume 10 is the most we'll want to montage
            if data.shape[min_axis] < 10:
                logging.info('Guessing axis %d is montage...', min_axis)
                montage_axis = min_axis
            else:
                logging.info('Guessing axis %d is movie...', min_axis)
                movie_axis = min_axis

        # If 4 dims, guess the smaller dim is the montage and the larger
        # dim is the movie
        elif data.ndim == 4:
            montage_axis = np.argmin(data.shape)

            # Consider the 4th dimension as the color channel in
            # skimontage
            montage_opts['multichannel'] = True

            # Montage will go through skimontage which will remove the
            # montage_axis dimension, so find the movie dimension
            #  without the montage dimension:
            tmp = np.delete(data.shape[:], montage_axis)
            movie_axis = np.argmin(tmp)

            logging.info(('Guessing axis %d is montage, axis %d will be '
                          'movie...'), montage_axis, movie_axis)

    # fft and fftshift will require fft_axes.  If the user didn't
    # give us axes, let's try to guess them:
    if (fft or (fftshift is not False)) and (fft_axes is None):
        all_axes = list(range(data.ndim))

        if (montage_axis is not None) and (movie_axis is not None):
            fft_axes = np.delete(
                all_axes, [all_axes[montage_axis], all_axes[movie_axis]])
        elif montage_axis is not None:
            fft_axes = np.delete(all_axes, all_axes[montage_axis])
        elif movie_axis is not None:
            fft_axes = np.delete(all_axes, all_axes[movie_axis])
        else:
            fft_axes = all_axes

        logging.info('User did not supply fft_axes, guessing %s...',
                     str(fft_axes))

    # Perform n-dim FFT across fft_axes if desired
    if fft:
        data = np.fft.fftn(data, axes=fft_axes)

    # Perform fftshift if desired.  If the user does not specify
    # fftshift, if fft is performed, then fftshift will also be
    # performed.  To override this behavior, simply supply
    # fftshift=False in the arguments.  Similarly, to force fftshift
    # even if no fft was performed, supply fftshift=True.
    if fft and (fftshift is None):
        fftshift = True
    elif fftshift is None:
        fftshift = False

    if fftshift:
        data = np.fft.fftshift(data, axes=fft_axes)

    # Take absolute value to view if necessary, must take abs before
    # log
    if np.iscomplexobj(data) or (mag is True) or (log is True):
        data = np.abs(data)

        if log:
            # Don't take log of 0!
            data[data == 0] = np.nan
            data = np.log(data)

    # If we asked for phase, let's work out how we'll do that
    if phase and ((mag is None) or (mag is True)):
        # TODO: figure out which axis to concatenate the phase onto
        data = np.concatenate((data, np.angle(data)), axis=fft_axes[-1])
    elif phase and (mag is False):
        data = np.angle(data)

    # Run any processing before imshow
    if callable(prep):
        data = prep(data)

    # If it's just a line plot, skip all the montage, movie stuff
    if is_line:
        montage_axis = None
        movie_axis = None

    if montage_axis is not None:
        # We can deal with 4 dimensions if we allow multichannel
        if data.ndim == 4 and 'multichannel' not in montage_opts:
            montage_opts['multichannel'] = True

            # When we move the movie_axis to the end, we will need to
            # adjust the montage axis in case we displace it.  We
            # need to move it to the end so skimontage will consider
            # it the multichannel
            data = np.moveaxis(data, movie_axis, -1)
            if movie_axis < montage_axis:
                montage_axis -= 1

        # Put the montage axis in front
        data = np.moveaxis(data, montage_axis, 0)
        try:
            data = skimontage(data, **montage_opts)
        except ValueError:
            # Multichannel might be erroneously set
            montage_opts['multichannel'] = False
            data = skimontage(data, **montage_opts)

        if data.ndim == 3:
            # If we had 4 dimensions, we just lost one, so now we
            # need to know where the movie dimension went off to...
            if movie_axis > montage_axis:
                movie_axis -= 1
            # Move the movie axis back, it's no longer the color
            # channel
            data = np.moveaxis(data, -1, movie_axis)

    if movie_axis is not None:
        fig = plt.figure()
        data = np.moveaxis(data, movie_axis, -1)
        im = plt.imshow(data[..., 0], **imshow_opts)

        def updatefig(frame):
            '''Animation function for figure.'''
            im.set_array(data[..., frame])
            return im,  # pylint: disable=R1707

        _ani = animation.FuncAnimation(fig,
                                       updatefig,
                                       frames=data.shape[-1],
                                       interval=movie_interval,
                                       blit=True,
                                       repeat=movie_repeat)

        if not test_run:
            plt.show()
    else:
        if data.ndim == 1 or is_line:
            plt.plot(data)
        elif data.ndim == 2:
            # Just a regular old 2d image...
            plt.imshow(np.nan_to_num(data), **imshow_opts)
        else:
            raise ValueError('%d is too many dimensions!' % data.ndim)

        if not test_run:
            plt.show()

    # Save what we looked at if desired
    if save_npy:
        if isinstance(image, str) and ext:
            filename = image
        else:
            filename = 'view-output'
        np.save(filename, data)

    # If we're testing, return all the local vars
    if test_run:
        return locals()
    return data
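

# Usage sketch for view() above; the file names are hypothetical and the
# loaders (load_raw, load_mat, load_ismrmrd) come from the rest of the
# mr_utils package.
view('scan.npy')                          # quick look at a 2D image
view('kspace.npy', fft=True, log=True)    # FFT + fftshift, log magnitude
view('multislice.npy', montage_axis=0)    # tile slices along axis 0 into a montage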
Example #3
def reconstruct_epi(filename, datasetname, noise, gre):
    
    # Read the epi data
    dset = ismrmrd.Dataset(filename,datasetname)

    ##############################
    # Scan Parameters and Layout #
    ##############################
    header = ismrmrd.xsd.CreateFromDocument(dset.read_xml_header())
    enc = header.encoding[0]
    nkx = enc.encodedSpace.matrixSize.x
    nky = enc.encodedSpace.matrixSize.y
    ncoils = header.acquisitionSystemInformation.receiverChannels
    epi_noise_bw = header.acquisitionSystemInformation.relativeReceiverNoiseBandwidth
    acc_factor = enc.parallelImaging.accelerationFactor.kspace_encoding_step_1
    
    # Number of Slices
    if enc.encodingLimits.slice is not None:
        nslices = enc.encodingLimits.slice.maximum + 1
    else:
        nslices = 1

    # Loop through the acquisitions ignoring the noise scans and the
    # parallel imaging calibration scans which are EPI based
    firstscan = 0
    while True:
        acq = dset.read_acquisition(firstscan)
        if acq.isFlagSet(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) or acq.isFlagSet(ismrmrd.ACQ_IS_PARALLEL_CALIBRATION):
            firstscan += 1
        else:
            break

    #print('First imaging scan at:', firstscan)
    nsamp = acq.number_of_samples
    ncoils = acq.active_channels
    sampletime = acq.sample_time_us

    # The lines are labeled with flags as follows:
    # - Noise or Imaging using ACQ_IS_NOISE_MEASUREMENT
    # - Parallel calibration using ACQ_IS_PARALLEL_CALIBRATION
    # - Forward or Reverse using the ACQ_IS_REVERSE flag
    # - EPI navigator using ACQ_IS_PHASECORR_DATA
    # - First or last in a slice using ACQ_FIRST_IN_SLICE and ACQ_LAST_IN_SLICE
    # - The first navigator in a shot is labeled as first in slice
    # - The first imaging line in a shot is labeled as first in slice
    # - The last imaging line in a shot is labeled as last in slice
    # for n in range(firstscan-1,firstscan+60):
    #   acq = dset.read_acquisition(n)
    #   print(acq.idx.kspace_encode_step_1)
    #   if acq.isFlagSet(ismrmrd.ACQ_FIRST_IN_SLICE):
    #       print('First')
    #   elif acq.isFlagSet(ismrmrd.ACQ_LAST_IN_SLICE):
    #       print('Last')
    #   else:
    #       print('Middle')
    #   if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
    #       print('Reverse')
    #   else:
    #       print('Forward')
    #   if acq.isFlagSet(ismrmrd.ACQ_IS_PHASECORR_DATA):
    #       print('Navigator')

    # The EPI trajectory is described in the XML header
    # for o in enc.trajectoryDescription.userParameterLong:
    #    print(o.name, o.value_)
    #    
    # for o in enc.trajectoryDescription.userParameterDouble:
    #     print(o.name, o.value_)
    tup = tdown = tflat = tdelay = nsamp = nnav = etl = 0
    for o in enc.trajectoryDescription.userParameterLong:
        if o.name == 'rampUpTime':
            tup = o.value_
        if o.name == 'rampDownTime':
            tdown = o.value_
        if o.name == 'flatTopTime':
            tflat = o.value_
        if o.name == 'acqDelayTime':
            tdelay = o.value_
        if o.name == 'numSamples':
            nsamp = o.value_
        if o.name == 'numberOfNavigators':
            nnav = o.value_
        if o.name == 'etl':
            etl = o.value_

    #print(tup, tdown, tflat, tdelay, nsamp, nnav, etl)

    ####################################
    # Calculate the gridding operators #
    ####################################
    nkx = enc.encodedSpace.matrixSize.x
    nx = enc.reconSpace.matrixSize.x
    t = tdelay + sampletime*np.arange(nsamp)
    x = np.arange(nx)/nx-0.5
    up = t<=tup
    flat = (t>tup)*(t<(tup+tflat))
    down = t>=(tup+tflat)

    #Integral of trajectory (Gmax=1.0)
    k = np.zeros(nsamp)
    k[up] = 0.5/tup*t[up]**2
    k[flat] = 0.5*tup + (t[flat] - tup)
    k[down] = 0.5*tup + tflat + 0.5*tdown-0.5/tdown*(tup+tflat+tdown-t[down])**2
    #Scale to match resolution
    k *= nkx/(k[-1]-k[0])
    #Center
    k -= k[nsamp//2]
    kpos = k
    kneg = -1.0*k
    #Corresponding even range
    keven = np.arange(nkx)
    keven -= keven[nkx//2]
    #Forward model
    Qpos = np.zeros([nsamp,nkx])
    Qneg = np.zeros([nsamp,nkx])
    for p in range(nsamp):
        Qpos[p,:] = np.sinc(kpos[p]-keven)
        Qneg[p,:] = np.sinc(kneg[p]-keven)
    #Inverse
    Rpos = np.linalg.pinv(Qpos)
    Rneg = np.linalg.pinv(Qneg)
    #Take transpose because we apply from the right
    Rpos = Rpos.transpose()
    Rneg = Rneg.transpose()

    #################################
    # Calculate the kspace filter   #
    # Hanning filter after gridding #
    #################################
    from scipy.signal.windows import hann
    kfiltx = hann(nkx)
    kfilty = hann(nky)
    Rpos = np.dot(Rpos, np.diag(kfiltx))
    Rneg = np.dot(Rneg, np.diag(kfiltx))

    ####################################
    # Calculate SENSE unmixing weights #
    ####################################
    # Some basic checks
    if gre.shape[0] != nslices:
        raise ValueError('Calibration and EPI data have different number of slices')
    if gre.shape[1] != ncoils:
        raise ValueError('Calibration and EPI data have different number of coils')

    # Estimate coil sensitivities from the GRE data
    csm_orig = np.zeros(gre.shape,dtype=complex)
    for z in range(nslices):
        (csmtmp, actmp, rhotmp) = coils.calculate_csm_inati_iter(gre[z,:,:,:])
        weight = rhotmp**2 / (rhotmp**2 + .01*np.median(rhotmp.ravel())**2)
        csm_orig[z,:,:,:] = csmtmp*weight
 
    # Deal with difference in resolution
    # Up/down sample the coil sensitivities to the resolution of the EPI
    xcsm = np.arange(gre.shape[3])/gre.shape[3]
    ycsm = np.arange(gre.shape[2])/gre.shape[2]
    xepi = np.arange(nx)/nx
    yepi = np.arange(nky)/nky
    csm = np.zeros([nslices,ncoils,nky,nx],dtype=complex)
    for z in range(nslices):
        for c in range(ncoils):
            # interpolate the real part and imaginary part separately
            i_real = interp.RectBivariateSpline(ycsm,xcsm,np.real(csm_orig[z,c,:,:]))
            i_imag = interp.RectBivariateSpline(ycsm,xcsm,np.imag(csm_orig[z,c,:,:]))
            csm[z,c,:,:] = i_real(yepi,xepi) + 1j*i_imag(yepi,xepi)

    # SENSE weights
    unmix = np.zeros(csm.shape,dtype=complex)
    for z in range(nslices):
        unmix[z,:,:,:] = sense.calculate_sense_unmixing(acc_factor, csm[z,:,:,:])[0]
    
    ###############
    # Reconstruct #
    ###############
    # Initialize the array for a volume's worth of data
    H = np.zeros([nslices, ncoils, nky, nx],dtype=complex)
    # Loop over the slices
    scan = firstscan
    for z in range(nslices):
        #print('Slice %d starts at scan %d.'%(z,scan))
        # Navigator 1
        acq = dset.read_acquisition(scan)
        #print(scan,acq.idx.slice,acq.idx.kspace_encode_step_1,acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE))
        currslice = acq.idx.slice # keep track of the slice number
        data = coils.apply_prewhitening(acq.data,noise.preWMtx)
        if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
            rnav1 = transform.transform_kspace_to_image(np.dot(data, Rneg),dim=[1])
            sgn = -1.0
        else:
            rnav1 = transform.transform_kspace_to_image(np.dot(data, Rpos),dim=[1])
            sgn = 1.0
        scan += 1

        # Navigator 2
        acq = dset.read_acquisition(scan)
        #print(scan,acq.idx.slice,acq.idx.kspace_encode_step_1,acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE))
        data = coils.apply_prewhitening(acq.data,noise.preWMtx)
        if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
            rnav2 = transform.transform_kspace_to_image(np.dot(data, Rneg),dim=[1])
        else:
            rnav2 = transform.transform_kspace_to_image(np.dot(data, Rpos),dim=[1])
        scan += 1

        # Navigator 3
        acq = dset.read_acquisition(scan)
        #print(scan,acq.idx.slice,acq.idx.kspace_encode_step_1,acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE))
        data = coils.apply_prewhitening(acq.data,noise.preWMtx)
        if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
            rnav3 = transform.transform_kspace_to_image(np.dot(data, Rneg),dim=[1])
        else:
            rnav3 = transform.transform_kspace_to_image(np.dot(data, Rpos),dim=[1])
        scan += 1

        # Phase correction
        delta = np.conj(rnav1+rnav3) * rnav2
        fdelta = np.tile(np.mean(delta,axis=0),[ncoils,1])
        corr = np.exp(sgn*1j*np.angle(np.sqrt(fdelta)))

        for j in range(nky):
            acq = dset.read_acquisition(scan)
            slice = acq.idx.slice              
            if slice != currslice:
                # end of this slice
                break

            ky = acq.idx.kspace_encode_step_1
            data = coils.apply_prewhitening(acq.data,noise.preWMtx)
            if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
                rho = transform.transform_kspace_to_image(np.dot(data, Rneg),dim=[1])
                H[slice,:,ky,:] = kfilty[ky]*np.conj(corr)*rho
            else:
                rho = transform.transform_kspace_to_image(np.dot(data, Rpos),dim=[1])
                H[slice,:,ky,:] = kfilty[ky]*corr*rho        
            scan += 1

    # Close the data set
    dset.close()
    
    # Recon along y
    H = transform.transform_kspace_to_image(H,dim=[2])
    
    # Combine with SENSE weights
    epi_im = np.abs(np.squeeze(np.sum(H*unmix,axis=1)))
    
    return epi_im
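

# Minimal, self-contained sketch of the ramp-sampling regridding idea used in
# the "gridding operators" section of reconstruct_epi above: samples acquired
# uniformly in time on a trapezoidal readout gradient are non-uniform in
# k-space, so a sinc interpolation matrix Q maps an even k-space grid to the
# measured sample positions and its pseudo-inverse regrids the data.  All
# timing numbers here are made up for illustration.
import numpy as np

tup, tflat, tdown, tdelay = 100.0, 300.0, 100.0, 10.0   # ramp/flat times (a.u.)
nsamp, nkx = 64, 64
dt = (tup + tflat + tdown - 2 * tdelay) / nsamp
t = tdelay + dt * np.arange(nsamp)

up = t <= tup
flat = (t > tup) & (t < tup + tflat)
down = t >= tup + tflat

k = np.zeros(nsamp)                                     # integral of the gradient
k[up] = 0.5 / tup * t[up]**2
k[flat] = 0.5 * tup + (t[flat] - tup)
k[down] = (0.5 * tup + tflat + 0.5 * tdown
           - 0.5 / tdown * (tup + tflat + tdown - t[down])**2)
k *= nkx / (k[-1] - k[0])                               # scale to matrix size
k -= k[nsamp // 2]                                      # center

keven = np.arange(nkx) - nkx // 2                       # target uniform grid
Q = np.sinc(k[:, None] - keven[None, :])                # forward model
R = np.linalg.pinv(Q)                                   # regridding operator
regridded = R @ np.random.randn(nsamp)                  # apply to one (toy) readout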
Example #4
import numpy as np

from mr_utils import view
from ismrmrdtools.coils import calculate_csm_inati_iter, calculate_csm_walsh

if __name__ == '__main__':

    im0 = np.load('data/20190401_GASP_PHANTOM/set2_gre_tr34_te2_87.npy')
    im0 = np.mean(im0, axis=2)
    im0 = np.moveaxis(im0, -1, 0)

    im1 = np.load('data/20190401_GASP_PHANTOM/set2_gre_tr4_te5_74.npy')
    im1 = np.mean(im1, axis=2)
    im1 = np.moveaxis(im1, -1, 0)

    # Make a field map coil by coil
    fm0 = dual_echo_gre(im0, im1, 2.87e-3, 5.74e-3)
    np.save('data/20190401_GASP_PHANTOM/coil_fm_gre.npy', fm0)
    view(fm0)
    fm0 = np.mean(fm0, axis=0)

    # Coil combine im0 and im1 then get field map
    _, im0cc0 = calculate_csm_inati_iter(im0)
    _, im1cc0 = calculate_csm_inati_iter(im1)
    csm, _ = calculate_csm_walsh(im0)
    im0cc1 = np.sum(np.conj(im0) * csm, axis=0)
    csm, _ = calculate_csm_walsh(im1)
    im1cc1 = np.sum(np.conj(im1) * csm, axis=0)
    fm1 = dual_echo_gre(im0cc0, im1cc0, 2.87e-3, 5.74e-3)
    fm2 = dual_echo_gre(im0cc1, im1cc1, 2.87e-3, 5.74e-3)

    # Compare
    view(np.stack((fm0, fm1, fm2)))
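
# dual_echo_gre is part of the surrounding project and its import is not shown
# in this snippet; a minimal sketch of the usual dual-echo GRE field-map
# estimate it presumably implements (off-resonance in Hz from the phase
# accrued between the two echo times):
def dual_echo_gre_sketch(im_te1, im_te2, te1, te2):
    '''Estimate off-resonance (Hz) from two complex GRE images.'''
    dphi = np.angle(im_te2 * np.conj(im_te1))    # phase difference between echoes
    return dphi / (2 * np.pi * (te2 - te1))      # convert to Hz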
Example #5
        f.write(struct.pack('d' * len(datawrite), *datawrite))
        f.close()
        print "duration: ", (time.time() - t0)
#================================================================================
#
#
#================================================================================
# Coil combination
#================================================================================
if args.calcCOILS:
    if eNz > 1:
        if navg > 1:
            coil_images = np.sum(all_data, 1)
            coil_images = transform.fftn(
                np.squeeze(coil_images[1, 1, 1, 1, 1, 1, 1, :, :, :, :]))
            (csm, rho) = coils.calculate_csm_inati_iter(coil_images)
    else:
        coil_images = transform.transform_kspace_to_image(
            np.squeeze(np.mean(all_data, 0)), (1, 2))
        (csm, rho) = coils.calculate_csm_inati_iter(coil_images)

    outname = './' + args.output + '_csm'
    sio.savemat(outname, {'csm': csm})
    outname = './' + args.output + '_rho'
    sio.savemat(outname, {'rho': rho})
#================================================================================
#
#
#================================================================================
# SOS Reconstruction
#================================================================================
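
# The SOS reconstruction section is cut off in this fragment; a minimal
# root-sum-of-squares combine of the coil images computed above would be
# (coil axis assumed to be the first axis, as in the Inati call):
sos_image = np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=0))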
Example #6
def comparison_numerical_phantom(SNR=None):
    '''Compare coil by coil, Walsh method, and Inati iterative method.'''

    true_im = get_true_im_numerical_phantom()
    csms = get_coil_sensitivity_maps()
    params = get_numerical_phantom_params(SNR=SNR)
    pc_vals = params['pc_vals']
    dim = params['dim']
    noise_std = params['noise_std']
    coil_nums = params['coil_nums']

    # We want to solve gs_recon for each coil we have in the pc set
    err = np.zeros((5, len(csms)))
    rip = err.copy()
    for ii, csm in enumerate(csms):

        # I have coil sensitivities, now I need images to apply them to.
        # coil_ims: (pc,coil,x,y)
        coil_ims = np.zeros((len(pc_vals), csm.shape[0], dim, dim),
                            dtype='complex')
        for jj, pc in enumerate(pc_vals):
            im = bssfp_2d_cylinder(dims=(dim, dim), phase_cyc=pc)
            im += 1j * im
            coil_ims[jj, ...] = im * csm
            coil_ims[jj, ...] += np.random.normal(0, noise_std, coil_ims[
                jj, ...].shape) / 2 + 1j * np.random.normal(
                    0, noise_std, coil_ims[jj, ...].shape) / 2

        # Solve the gs_recon coil by coil
        coil_ims_gs = np.zeros((csm.shape[0], dim, dim), dtype='complex')
        for kk in range(csm.shape[0]):
            coil_ims_gs[kk, ...] = gs_recon(*[
                x.squeeze()
                for x in np.split(coil_ims[:, kk, ...], len(pc_vals))
            ])
        coil_ims_gs[np.isnan(coil_ims_gs)] = 0

        # Easy way out: combine all the coils using sos
        im_est_sos = sos(coil_ims_gs)
        # view(im_est_sos)

        # Take coil by coil solution and do Walsh on it to collapse coil dim
        # walsh
        csm_walsh, _ = calculate_csm_walsh(coil_ims_gs)
        im_est_recon_then_walsh = np.sum(csm_walsh * np.conj(coil_ims_gs),
                                         axis=0)
        im_est_recon_then_walsh[np.isnan(im_est_recon_then_walsh)] = 0
        # view(im_est_recon_then_walsh)

        # inati
        csm_inati, im_est_recon_then_inati = calculate_csm_inati_iter(
            coil_ims_gs)

        # Collapse the coil dimension of each phase-cycle using Walsh,Inati
        pc_est_walsh = np.zeros((len(pc_vals), dim, dim), dtype='complex')
        pc_est_inati = np.zeros((len(pc_vals), dim, dim), dtype='complex')
        for jj in range(len(pc_vals)):
            ## Walsh
            csm_walsh, _ = calculate_csm_walsh(coil_ims[jj, ...])
            pc_est_walsh[jj,
                         ...] = np.sum(csm_walsh * np.conj(coil_ims[jj, ...]),
                                       axis=0)
            # view(csm_walsh)
            # view(pc_est_walsh)

            ## Inati
            csm_inati, pc_est_inati[jj, ...] = calculate_csm_inati_iter(
                coil_ims[jj, ...], smoothing=1)
            # pc_est_inati[jj,...] = np.sum(csm_inati*np.conj(coil_ims[jj,...]),axis=0)
            # view(csm_inati)

        # Now solve the gs_recon using collapsed coils
        im_est_walsh = gs_recon(
            *[x.squeeze() for x in np.split(pc_est_walsh, len(pc_vals))])
        im_est_inati = gs_recon(
            *[x.squeeze() for x in np.split(pc_est_inati, len(pc_vals))])

        # view(im_est_walsh)
        # view(im_est_recon_then_walsh)

        # Compute error metrics
        err[0, ii] = compare_nrmse(im_est_sos, true_im)
        err[1, ii] = compare_nrmse(im_est_recon_then_walsh, true_im)
        err[2, ii] = compare_nrmse(im_est_recon_then_inati, true_im)
        err[3, ii] = compare_nrmse(im_est_walsh, true_im)
        err[4, ii] = compare_nrmse(im_est_inati, true_im)

        im_est_sos[np.isnan(im_est_sos)] = 0
        im_est_recon_then_walsh[np.isnan(im_est_recon_then_walsh)] = 0
        im_est_recon_then_inati[np.isnan(im_est_recon_then_inati)] = 0
        im_est_walsh[np.isnan(im_est_walsh)] = 0
        im_est_inati[np.isnan(im_est_inati)] = 0

        rip[0, ii] = ripple_normal(im_est_sos)
        rip[1, ii] = ripple_normal(im_est_recon_then_walsh)
        rip[2, ii] = ripple_normal(im_est_recon_then_inati)
        rip[3, ii] = ripple_normal(im_est_walsh)
        rip[4, ii] = ripple_normal(im_est_inati)

        # view(im_est_inati)

        # # SOS of the gs solution on each individual coil gives us low periodic
        # # ripple accross the phantom, similar to Walsh method:
        # plt.plot(np.abs(true_im[int(dim/2),:]),'--',label='True Im')
        # plt.plot(np.abs(im_est_sos[int(dim/2),:]),'-.',label='SOS')
        # plt.plot(np.abs(im_est_recon_then_walsh[int(dim/2),:]),label='Recon then Walsh')
        # plt.plot(np.abs(im_est_walsh[int(dim/2),:]),label='Walsh then Recon')
        # # plt.plot(np.abs(im_est_inati[int(dim/2),:]),label='Inati')
        # plt.legend()
        # plt.show()

    # # Let's show some stuff
    # plt.plot(coil_nums,err[0,:],'*-',label='SOS')
    # plt.plot(coil_nums,err[1,:],label='Recon then Walsh')
    # plt.plot(coil_nums,err[2,:],label='Walsh then Recon')
    # # plt.plot(coil_nums,err[3,:],label='Inati')
    # plt.legend()
    # plt.show()

    print('SOS RMSE:', np.mean(err[0, :]))
    print('recon then walsh RMSE:', np.mean(err[1, :]))
    print('recon then inati RMSE:', np.mean(err[2, :]))
    print('walsh then recon RMSE:', np.mean(err[3, :]))
    print('inati then recon RMSE:', np.mean(err[4, :]))

    print('SOS ripple:', np.mean(rip[0, :]))
    print('recon then walsh ripple:', np.mean(rip[1, :]))
    print('recon then inati ripple:', np.mean(rip[2, :]))
    print('walsh then recon ripple:', np.mean(rip[3, :]))
    print('inati then recon ripple:', np.mean(rip[4, :]))

    view(im_est_recon_then_walsh[int(dim / 2), :])
    view(im_est_recon_then_inati[int(dim / 2), :])
    view(im_est_walsh[int(dim / 2), :])
    view(im_est_inati[int(dim / 2), :])
    # view(im_est_inati)

    # view(np.stack((im_est_recon_then_walsh,im_est_recon_then_inati,im_est_walsh,im_est_inati)))

    return (err)
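

# Note on the combine convention used above: sum(csm * conj(im)) and
# sum(conj(csm) * im) are complex conjugates of each other, so the magnitude
# images are identical and only the sign of the combined phase differs.
# Quick numerical check with random data:
_csm = np.random.randn(4, 8, 8) + 1j * np.random.randn(4, 8, 8)
_im = np.random.randn(4, 8, 8) + 1j * np.random.randn(4, 8, 8)
_a = np.sum(_csm * np.conj(_im), axis=0)
_b = np.sum(np.conj(_csm) * _im, axis=0)
assert np.allclose(np.abs(_a), np.abs(_b))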
Example #7
def comparison_knee():
    '''Coil by coil, Walsh method, and Inati iterative method for knee data.'''

    # Load the knee data
    dir = '/home/nicholas/Documents/rawdata/SSFP_SPECTRA_dphiOffset_08022018/'
    files = [
        'meas_MID362_TRUFI_STW_TE3_FID29379',
        'meas_MID363_TRUFI_STW_TE3_dphi_45_FID29380',
        'meas_MID364_TRUFI_STW_TE3_dphi_90_FID29381',
        'meas_MID365_TRUFI_STW_TE3_dphi_135_FID29382',
        'meas_MID366_TRUFI_STW_TE3_dphi_180_FID29383',
        'meas_MID367_TRUFI_STW_TE3_dphi_225_FID29384',
        'meas_MID368_TRUFI_STW_TE3_dphi_270_FID29385',
        'meas_MID369_TRUFI_STW_TE3_dphi_315_FID29386'
    ]
    pc_vals = [0, 45, 90, 135, 180, 225, 270, 315]
    dims = (512, 256)
    num_coils = 4
    num_avgs = 16

    # # Load in raw once, then save as npy with collapsed avg dimension
    # pcs = np.zeros((len(files),dims[0],dims[1],num_coils),dtype='complex')
    # for ii,file in enumerate(files):
    #     pcs[ii,...] = np.mean(load_raw('%s/%s.dat' % (dir,file),use='s2i'),axis=-1)
    # np.save('%s/te3.npy' % dir,pcs)

    # pcs looks like (pc,x,y,coil)
    pcs = np.load('%s/te3.npy' % dir)
    pcs = np.fft.fftshift(np.fft.fft2(pcs, axes=(1, 2)), axes=(1, 2))
    # print(pcs.shape)
    # view(pcs,fft=True,montage_axis=0,movie_axis=3)

    # Do recon then coil combine
    coils0 = np.zeros((pcs.shape[-1], pcs.shape[1], pcs.shape[2]),
                      dtype='complex')
    coils1 = coils0.copy()
    for ii in range(pcs.shape[-1]):
        # We have two sets: 0,90,180,270 and 45,135,225,315
        idx0 = [0, 2, 4, 6]
        idx1 = [1, 3, 5, 7]
        coils0[ii, ...] = gs_recon(*[x.squeeze() for x in pcs[idx0, :, :, ii]])
        coils1[ii, ...] = gs_recon(*[x.squeeze() for x in pcs[idx1, :, :, ii]])
    # Then do the coil combine
    csm_walsh, _ = calculate_csm_walsh(coils0)
    im_est_recon_then_walsh0 = np.sum(csm_walsh * np.conj(coils0), axis=0)
    # view(im_est_recon_then_walsh0)

    csm_walsh, _ = calculate_csm_walsh(coils1)
    im_est_recon_then_walsh1 = np.sum(csm_walsh * np.conj(coils1), axis=0)
    # view(im_est_recon_then_walsh1)

    rip0 = ripple(im_est_recon_then_walsh0)
    rip1 = ripple(im_est_recon_then_walsh1)
    print('recon then walsh: ', np.mean([rip0, rip1]))

    # Now try inati
    csm_inati, im_est_recon_then_inati0 = calculate_csm_inati_iter(coils0,
                                                                   smoothing=5)
    csm_inati, im_est_recon_then_inati1 = calculate_csm_inati_iter(coils1,
                                                                   smoothing=5)
    rip0 = ripple(im_est_recon_then_inati0)
    rip1 = ripple(im_est_recon_then_inati1)
    print('recon then inati: ', np.mean([rip0, rip1]))

    # Now try sos
    im_est_recon_then_sos0 = sos(coils0, axes=0)
    im_est_recon_then_sos1 = sos(coils1, axes=0)
    rip0 = ripple(im_est_recon_then_sos0)
    rip1 = ripple(im_est_recon_then_sos1)
    print('recon then sos: ', np.mean([rip0, rip1]))
    # view(im_est_recon_then_sos)

    ## Now the other way, combine then recon
    pcs0 = np.zeros((2, pcs.shape[0], pcs.shape[1], pcs.shape[2]),
                    dtype='complex')
    pcs1 = pcs0.copy()
    for ii in range(pcs.shape[0]):
        # Walsh it up
        csm_walsh, _ = calculate_csm_walsh(pcs[ii, ...].transpose((2, 0, 1)))
        pcs0[0, ii, ...] = np.sum(csm_walsh * np.conj(pcs[ii, ...].transpose(
            (2, 0, 1))),
                                  axis=0)
        # view(pcs0[ii,...])

        # Inati it up
        csm_inati, pcs0[1, ii,
                        ...] = calculate_csm_inati_iter(pcs[ii, ...].transpose(
                            (2, 0, 1)),
                                                        smoothing=5)

    ## Now perform gs_recon on each coil combined set
    # Walsh
    im_est_walsh_then_recon0 = gs_recon(
        *[x.squeeze() for x in pcs0[0, idx0, ...]])
    im_est_walsh_then_recon1 = gs_recon(
        *[x.squeeze() for x in pcs0[0, idx1, ...]])
    # Inati
    im_est_inati_then_recon0 = gs_recon(
        *[x.squeeze() for x in pcs0[1, idx0, ...]])
    im_est_inati_then_recon1 = gs_recon(
        *[x.squeeze() for x in pcs0[1, idx1, ...]])

    # view(im_est_walsh_then_recon0)
    # view(im_est_walsh_then_recon1)
    view(im_est_inati_then_recon0)
    view(im_est_inati_then_recon1)

    rip0 = ripple(im_est_walsh_then_recon0)
    rip1 = ripple(im_est_walsh_then_recon1)
    print('walsh then recon: ', np.mean([rip0, rip1]))

    rip0 = ripple(im_est_inati_then_recon0)
    rip1 = ripple(im_est_inati_then_recon1)
    print('inati then recon: ', np.mean([rip0, rip1]))
Example #8
def reconstruct_epi(filename, datasetname, noise, gre):

    # Read the epi data
    dset = ismrmrd.Dataset(filename, datasetname)

    ##############################
    # Scan Parameters and Layout #
    ##############################
    header = ismrmrd.xsd.CreateFromDocument(dset.read_xml_header())
    enc = header.encoding[0]
    nkx = enc.encodedSpace.matrixSize.x
    nky = enc.encodedSpace.matrixSize.y
    ncoils = header.acquisitionSystemInformation.receiverChannels
    epi_noise_bw = header.acquisitionSystemInformation.relativeReceiverNoiseBandwidth
    acc_factor = enc.parallelImaging.accelerationFactor.kspace_encoding_step_1

    # Number of Slices
    if enc.encodingLimits.slice is not None:
        nslices = enc.encodingLimits.slice.maximum + 1
    else:
        nslices = 1

    # Loop through the acquisitions ignoring the noise scans and the
    # parallel imaging calibration scans which are EPI based
    firstscan = 0
    while True:
        acq = dset.read_acquisition(firstscan)
        if acq.isFlagSet(ismrmrd.ACQ_IS_NOISE_MEASUREMENT) or acq.isFlagSet(
                ismrmrd.ACQ_IS_PARALLEL_CALIBRATION):
            firstscan += 1
        else:
            break

    #print('First imaging scan at:', firstscan)
    nsamp = acq.number_of_samples
    ncoils = acq.active_channels
    sampletime = acq.sample_time_us

    # The lines are labeled with flags as follows:
    # - Noise or Imaging using ACQ_IS_NOISE_MEASUREMENT
    # - Parallel calibration using ACQ_IS_PARALLEL_CALIBRATION
    # - Forward or Reverse using the ACQ_IS_REVERSE flag
    # - EPI navigator using ACQ_IS_PHASECORR_DATA
    # - First or last in a slice using ACQ_FIRST_IN_SLICE and ACQ_LAST_IN_SLICE
    # - The first navigator in a shot is labeled as first in slice
    # - The first imaging line in a shot is labeled as first in slice
    # - The last imaging line in a shot is labeled as last in slice
    # for n in range(firstscan-1,firstscan+60):
    #   acq = dset.read_acquisition(n)
    #   print(acq.idx.kspace_encode_step_1)
    #   if acq.isFlagSet(ismrmrd.ACQ_FIRST_IN_SLICE):
    #       print('First')
    #   elif acq.isFlagSet(ismrmrd.ACQ_LAST_IN_SLICE):
    #       print('Last')
    #   else:
    #       print('Middle')
    #   if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
    #       print('Reverse')
    #   else:
    #       print('Forward')
    #   if acq.isFlagSet(ismrmrd.ACQ_IS_PHASECORR_DATA):
    #       print('Navigator')

    # The EPI trajectory is described in the XML header
    # for o in enc.trajectoryDescription.userParameterLong:
    #    print(o.name, o.value_)
    #
    # for o in enc.trajectoryDescription.userParameterDouble:
    #     print(o.name, o.value_)
    tup = tdown = tflat = tdelay = nsamp = nnav = etl = 0
    for o in enc.trajectoryDescription.userParameterLong:
        if o.name == 'rampUpTime':
            tup = o.value_
        if o.name == 'rampDownTime':
            tdown = o.value_
        if o.name == 'flatTopTime':
            tflat = o.value_
        if o.name == 'acqDelayTime':
            tdelay = o.value_
        if o.name == 'numSamples':
            nsamp = o.value_
        if o.name == 'numberOfNavigators':
            nnav = o.value_
        if o.name == 'etl':
            etl = o.value_

    #print(tup, tdown, tflat, tdelay, nsamp, nnav, etl)

    ####################################
    # Calculate the gridding operators #
    ####################################
    nkx = enc.encodedSpace.matrixSize.x
    nx = enc.reconSpace.matrixSize.x
    t = tdelay + sampletime * np.arange(nsamp)
    x = np.arange(nx) / nx - 0.5
    up = t <= tup
    flat = (t > tup) * (t < (tup + tflat))
    down = t >= (tup + tflat)

    #Integral of trajectory (Gmax=1.0)
    k = np.zeros(nsamp)
    k[up] = 0.5 / tup * t[up]**2
    k[flat] = 0.5 * tup + (t[flat] - tup)
    k[down] = 0.5 * tup + tflat + 0.5 * tdown - 0.5 / tdown * (
        tup + tflat + tdown - t[down])**2
    #Scale to match resolution
    k *= nkx / (k[-1] - k[0])
    #Center
    k -= k[nsamp // 2]
    kpos = k
    kneg = -1.0 * k
    #Corresponding even range
    keven = np.arange(nkx)
    keven -= keven[nkx // 2]
    #Forward model
    Qpos = np.zeros([nsamp, nkx])
    Qneg = np.zeros([nsamp, nkx])
    for p in range(nsamp):
        Qpos[p, :] = np.sinc(kpos[p] - keven)
        Qneg[p, :] = np.sinc(kneg[p] - keven)
    #Inverse
    Rpos = np.linalg.pinv(Qpos)
    Rneg = np.linalg.pinv(Qneg)
    #Take transpose because we apply from the right
    Rpos = Rpos.transpose()
    Rneg = Rneg.transpose()

    #################################
    # Calculate the kspace filter   #
    # Hanning filter after gridding #
    #################################
    from scipy.signal.windows import hann
    kfiltx = hann(nkx)
    kfilty = hann(nky)
    Rpos = np.dot(Rpos, np.diag(kfiltx))
    Rneg = np.dot(Rneg, np.diag(kfiltx))

    ####################################
    # Calculate SENSE unmixing weights #
    ####################################
    # Some basic checks
    if gre.shape[0] != nslices:
        raise ValueError(
            'Calibration and EPI data have different number of slices')
    if gre.shape[1] != ncoils:
        raise ValueError(
            'Calibration and EPI data have different number of coils')

    # Estimate coil sensitivities from the GRE data
    csm_orig = np.zeros(gre.shape, dtype=complex)
    for z in range(nslices):
        (csmtmp, actmp,
         rhotmp) = coils.calculate_csm_inati_iter(gre[z, :, :, :])
        weight = rhotmp**2 / (rhotmp**2 + .01 * np.median(rhotmp.ravel())**2)
        csm_orig[z, :, :, :] = csmtmp * weight

    # Deal with difference in resolution
    # Up/down sample the coil sensitivities to the resolution of the EPI
    xcsm = np.arange(gre.shape[3]) / gre.shape[3]
    ycsm = np.arange(gre.shape[2]) / gre.shape[2]
    xepi = np.arange(nx) / nx
    yepi = np.arange(nky) / nky
    csm = np.zeros([nslices, ncoils, nky, nx], dtype=complex)
    for z in range(nslices):
        for c in range(ncoils):
            # interpolate the real part and imaginary part separately
            i_real = interp.RectBivariateSpline(ycsm, xcsm,
                                                np.real(csm_orig[z, c, :, :]))
            i_imag = interp.RectBivariateSpline(ycsm, xcsm,
                                                np.imag(csm_orig[z, c, :, :]))
            csm[z, c, :, :] = i_real(yepi, xepi) + 1j * i_imag(yepi, xepi)

    # SENSE weights
    unmix = np.zeros(csm.shape, dtype=complex)
    for z in range(nslices):
        unmix[z, :, :, :] = sense.calculate_sense_unmixing(
            acc_factor, csm[z, :, :, :])[0]

    ###############
    # Reconstruct #
    ###############
    # Initialize the array for a volume's worth of data
    H = np.zeros([nslices, ncoils, nky, nx], dtype=complex)
    # Loop over the slices
    scan = firstscan
    for z in range(nslices):
        #print('Slice %d starts at scan %d.'%(z,scan))
        # Navigator 1
        acq = dset.read_acquisition(scan)
        #print(scan,acq.idx.slice,acq.idx.kspace_encode_step_1,acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE))
        currslice = acq.idx.slice  # keep track of the slice number
        data = coils.apply_prewhitening(acq.data, noise.preWMtx)
        if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
            rnav1 = transform.transform_kspace_to_image(np.dot(data, Rneg),
                                                        dim=[1])
            sgn = -1.0
        else:
            rnav1 = transform.transform_kspace_to_image(np.dot(data, Rpos),
                                                        dim=[1])
            sgn = 1.0
        scan += 1

        # Navigator 2
        acq = dset.read_acquisition(scan)
        #print(scan,acq.idx.slice,acq.idx.kspace_encode_step_1,acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE))
        data = coils.apply_prewhitening(acq.data, noise.preWMtx)
        if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
            rnav2 = transform.transform_kspace_to_image(np.dot(data, Rneg),
                                                        dim=[1])
        else:
            rnav2 = transform.transform_kspace_to_image(np.dot(data, Rpos),
                                                        dim=[1])
        scan += 1

        # Navigator 3
        acq = dset.read_acquisition(scan)
        #print(scan,acq.idx.slice,acq.idx.kspace_encode_step_1,acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE))
        data = coils.apply_prewhitening(acq.data, noise.preWMtx)
        if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
            rnav3 = transform.transform_kspace_to_image(np.dot(data, Rneg),
                                                        dim=[1])
        else:
            rnav3 = transform.transform_kspace_to_image(np.dot(data, Rpos),
                                                        dim=[1])
        scan += 1

        # Phase correction
        delta = np.conj(rnav1 + rnav3) * rnav2
        fdelta = np.tile(np.mean(delta, axis=0), [ncoils, 1])
        corr = np.exp(sgn * 1j * np.angle(np.sqrt(fdelta)))

        for j in range(nky):
            acq = dset.read_acquisition(scan)
            slice = acq.idx.slice
            if slice != currslice:
                # end of this slice
                break

            ky = acq.idx.kspace_encode_step_1
            data = coils.apply_prewhitening(acq.data, noise.preWMtx)
            if acq.isFlagSet(ismrmrd.ACQ_IS_REVERSE):
                rho = transform.transform_kspace_to_image(np.dot(data, Rneg),
                                                          dim=[1])
                H[slice, :, ky, :] = kfilty[ky] * np.conj(corr) * rho
            else:
                rho = transform.transform_kspace_to_image(np.dot(data, Rpos),
                                                          dim=[1])
                H[slice, :, ky, :] = kfilty[ky] * corr * rho
            scan += 1

    # Close the data set
    dset.close()

    # Recon along y
    H = transform.transform_kspace_to_image(H, dim=[2])

    # Combine with SENSE weights
    epi_im = np.abs(np.squeeze(np.sum(H * unmix, axis=1)))

    return epi_im
# -*- coding: utf-8 -*-

#%%
#Basic setup
import time
import numpy as np
from ismrmrdtools import simulation, coils, show

matrix_size = 256
csm = simulation.generate_birdcage_sensitivities(matrix_size)
phan = simulation.phantom(matrix_size)
coil_images = phan[np.newaxis, :, :] * csm
show.imshow(abs(coil_images), tile_shape=(4, 2))

tstart = time.time()
(csm_est, rho) = coils.calculate_csm_walsh(coil_images)
print("Walsh coil estimation duration: {}s".format(time.time() - tstart))
combined_image = np.sum(csm_est * coil_images, axis=0)

show.imshow(abs(csm_est), tile_shape=(4, 2), scale=(0, 1))
show.imshow(abs(combined_image), scale=(0, 1))

tstart = time.time()
(csm_est2, rho2) = coils.calculate_csm_inati_iter(coil_images)
print("Inati coil estimation duration: {}s".format(time.time() - tstart))
combined_image2 = np.sum(csm_est2 * coil_images, axis=0)

show.imshow(abs(csm_est2), tile_shape=(4, 2), scale=(0, 1))
show.imshow(abs(combined_image2), scale=(0, 1))
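
# For reference (not part of the original demo), a root-sum-of-squares combine
# of the same simulated coil images gives a baseline that needs no coil
# sensitivity estimate at all:
combined_image_sos = np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=0))
show.imshow(abs(combined_image_sos))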