Ejemplo n.º 1
0
    def compute(self):
        """Compute sample density compensation (SDC) weights for k-space coordinates.

        Reads the 'crds' port (last axis = coordinate dimensionality, 1-3) and the
        optional 'wates' port (initial weights; all ones when absent). The
        'Dims per Set' axes directly before the coordinate axis form one set of
        points; every remaining leading axis indexes an independent set, processed
        separately with the matching ``core.gridding.sdc`` routine. When the
        'computenow' widget is set, the weights are written to the 'sdc' port with
        shape ``crds.shape[:-1]``.

        Returns:
            0 on success, 1 when the coordinate dimensionality is not 1, 2 or 3.
        """
        import numpy as np

        dps = self.getVal('Dims per Set')
        mtx_xy = self.getVal('Effective MTX XY')
        mtx_z = self.getVal('Effective MTX Z')
        numiter = self.getVal('Iterations')
        taper = self.getVal('Taper')
        kradscale = self.getVal('krad Scale')
        crds = self.getData('crds')
        inwates = self.getData('wates')

        # Gridding matrix handed to the SDC routines: twice the effective matrix
        # plus a 6-point margin.
        mtxsz_xy = (2. * mtx_xy) + 6
        mtxsz_z = (2. * mtx_z) + 6

        sdshape = crds[..., 0].shape
        maxi = crds.ndim - 2  # index of the last non-coordinate axis in sdshape
        npts = 1
        for i in range(dps):
            npts *= sdshape[maxi - i]
        nsets = crds[..., 0].size // npts
        ndim = crds.shape[-1]

        # Reshape crds to be 3 dimensional - # sets, # pts/set, and dimensionality (1-3)
        crds = np.reshape(crds.astype(np.float64), (nsets, npts, ndim))

        if self.getVal('computenow'):

            if inwates is not None:
                # astype() always allocates a new array, so the input port's data
                # is never modified by the SDC routines.
                wates = inwates.reshape(nsets, npts).astype(np.float64)
            else:
                wates = np.ones((nsets, npts), dtype=np.float64)

            # The matrix size only depends on the dimensionality, so build it once.
            if ndim == 1:
                cmtxdim = np.array([mtxsz_xy], dtype=np.int64)
            elif ndim == 2:
                cmtxdim = np.array([mtxsz_xy, mtxsz_xy], dtype=np.int64)
            elif ndim == 3:
                cmtxdim = np.array([mtxsz_xy, mtxsz_xy, mtxsz_z], dtype=np.int64)
            else:
                self.log.warn("Coordinates must be 1, 2 or 3 dimensional, got "
                              + str(ndim))
                return 1

            # import in thread to save namespace
            import core.gridding.sdc as sd

            sdc_sets = []
            for iset in range(nsets):
                if ndim == 1:
                    sdcset = sd.oned_sdc(crds[iset, :], wates[iset, :], cmtxdim,
                                         numiter, taper)
                elif ndim == 2:
                    sdcset = sd.twod_sdc(crds[iset, :], wates[iset, :], cmtxdim,
                                         numiter, taper)
                else:
                    sdcset = sd.threed_sdc(crds[iset, :], wates[iset, :], cmtxdim,
                                           numiter, taper, kradscale)
                sdc_sets.append(sdcset)

            # Stack once at the end; growing an array with np.append re-copies all
            # previous sets on every iteration.
            sdc = np.stack(sdc_sets, axis=0)

            # Reshape sdc weights to match that of incoming coordinates
            self.setData('sdc', np.reshape(sdc, sdshape))

        return 0
    def compute(self):
        """Compress multi-coil k-space data into ``m`` noise-decorrelated virtual channels.

        Outline of what this method does:

        1. Reads the widgets and the 'data', 'noise', 'sensitivity map' and
           'params_in' ports, and brings 'data' to the 5-D layout
           [coils, extra_dim2, extra_dim1, arms, points].
        2. Obtains a coil sensitivity map (csm):
           - from the 'sensitivity map' port when it is connected;
           - otherwise, when only a crop / 'image ceiling' widget changed, from the
             'masked and normalized sense map' and 'sum of square image' ports
             written by an earlier run;
           - otherwise by gridding the central part of k-space (the 'coords' port
             is then required) into a low-resolution, RMS-normalized, masked csm.
        3. Draws the root-sum-of-squares csm image, with the crop lines, into the
           'image' widget.
        4. When the 'compute' widget is set: estimates the noise covariance Psi
           from 'noise', builds T = inv(chol(Psi)), accumulates
           P = sum over cropped pixels of s @ pinv(s) with s = T @ csm(pixel),
           takes the SVD of P, and uses the first ``m`` rows of U^H T as the
           compression matrix A, which is applied to every k-space sample.

        Returns:
            0 on success; 1 when the data rank is unsupported, a required input
            ('coords' or 'noise') is missing, stored csm state is missing, or more
            virtual channels than coils are requested.
        """
        import numpy as np

        self.log.node("Virtual Channels node running compute()")
        twoD_or_threeD = 2

        # GETTING WIDGET INFO
        mtx_xy = self.getVal('mtx')
        image_ceiling = self.getVal('image ceiling')
        crop_left = self.getVal('crop left')
        crop_right = self.getVal('crop right')
        crop_top = self.getVal('crop top')
        crop_bottom = self.getVal('crop bottom')
        reset_compute_button = self.getVal('reset compute button each time')
        compute = self.getVal('compute')
        # number of virtual channels m
        m = self.getVal('virtual channels')
        numiter = self.getVal('SDC Iterations')

        # GETTING PORT INFO
        data = self.getData('data').astype(np.complex64, copy=False)
        noise = self.getData('noise')
        sensitivity_map_uncropped = self.getData('sensitivity map')
        param = self.getData('params_in')

        # set dimensions: bring data to [coils, extra_dim2, extra_dim1, arms, points].
        # reshape() is used instead of assigning .shape so the array owned by the
        # input port (astype(copy=False) may return it unchanged) is not modified.
        if data.ndim < 3 or data.ndim > 5:
            self.log.warn("Not implemented yet")
            return 1
        nr_points = data.shape[-1]
        nr_arms = data.shape[-2]
        nr_coils = data.shape[0]
        extra_dim1 = data.shape[-3] if data.ndim >= 4 else 1
        extra_dim2 = data.shape[-4] if data.ndim == 5 else 1
        data = data.reshape([nr_coils, extra_dim2, extra_dim1, nr_arms, nr_points])

        print("data shape: " + str(data.shape))

        if sensitivity_map_uncropped is None:
            # if cropping or image scaling sliders were changed then use the
            # previously stored csm instead of calculating a new one
            has_csm_been_calculated = self.getData('sum of square image')
            display_widget_changed = any(
                name in self.widgetEvents()
                for name in ('crop left', 'crop right', 'crop top',
                             'crop bottom', 'image ceiling'))
            if (has_csm_been_calculated is not None) and display_widget_changed:
                csm = self.getData('masked and normalized sense map')
                image = self.getData('sum of square image').copy()
                if csm is None:
                    self.log.warn("This should not happen.")
                    return 1
                csm_mtx = csm.shape[-1]
            else:
                # calculate auto-calibration B1 maps
                coords = self.getData('coords')
                if coords is None:
                    self.log.warn("Either a sensitiviy map or coords to calculate one is required")
                    return 1
                coords = coords.astype(np.float32, copy=False)

                import bni.gridding.Kaiser2D_utils as kaiser2D
                # parameters from UI
                UI_width = self.getVal('Autocalibration Width (%)')
                UI_taper = self.getVal('Autocalibration Taper (%)')
                UI_mask_floor = self.getVal('Mask Floor (% of max mag)')
                UI_average_csm = self.getVal('Dynamic data - average all dynamics for csm')
                # np.int was removed in NumPy 1.24; the builtin int truncates the same way.
                csm_mtx = int(0.01 * UI_width * mtx_xy)

                # More than 100 interleaves is treated as golden-angle data unless
                # the scan parameters explicitly say otherwise.
                is_GoldenAngle_data = coords.shape[-3] > 100
                if param is not None and 'spDYN_GOLDANGLE_ON' in param:
                    is_GoldenAngle_data = int(param['spDYN_GOLDANGLE_ON'][0]) == 1
                print("is GoldenAngle: " + str(is_GoldenAngle_data))
                if is_GoldenAngle_data:
                    nr_arms_cms = self.getVal('# golden angle dynamics for csm')
                else:
                    nr_arms_cms = nr_arms
                self.log.debug("nr_arms_cms: " + str(nr_arms_cms))
                csm_data = data[..., 0:nr_arms_cms, :]

                # oversampling: Oversample at the beginning and crop at the end
                oversampling_ratio = 2.  # 1.375
                mtx = int(csm_mtx * oversampling_ratio)
                if mtx % 2:
                    mtx += 1
                if oversampling_ratio > 1:
                    mtx_min = int((mtx - csm_mtx) / 2)
                    mtx_max = mtx_min + csm_mtx
                else:
                    mtx_min = 0
                    mtx_max = mtx

                # average dynamics or cardiac phases for csm
                if (extra_dim1 > 1) and UI_average_csm:
                    csm_data = np.sum(csm_data, axis=2)
                    extra_dim1_csm = 1
                    csm_data.shape = [nr_coils, extra_dim2, extra_dim1_csm, nr_arms_cms, nr_points]
                else:
                    extra_dim1_csm = extra_dim1
                self.log.debug("csm_data shape: " + str(csm_data.shape))
                self.log.debug("coords shape: " + str(coords.shape))

                # coords dimensions: add a leading dimension, as coords may carry an
                # extra one for golden angle dynamics
                if coords.ndim == 3:
                    coords = coords.reshape([1, nr_arms, nr_points, twoD_or_threeD])

                # create low resolution csm
                # cropping the data will make gridding and FFT much faster.
                # k-space radius of every sample of the first interleave; computed in
                # float32 like the coords, compared in float64.
                magnitude_one_interleave = np.sqrt(
                    coords[0, 0, :, 0]**2 + coords[0, 0, :, 1]**2).astype(np.float64)
                within_csm_width_radius = magnitude_one_interleave < (0.01 * UI_width * 0.5)  # for BNI spirals should be 0.45 instead of 0.5
                nr_points_csm_width = within_csm_width_radius.sum()

                # now set the dimension lists
                out_dims_grid = [nr_coils, extra_dim2, extra_dim1_csm, mtx, nr_arms_cms, nr_points_csm_width]
                out_dims_fft = [nr_coils, extra_dim2, extra_dim1_csm, mtx, mtx]

                csm_data = csm_data[..., 0:nr_points_csm_width]
                csm_coords = 1. / (0.01 * UI_width) * coords[..., 0:nr_arms_cms, 0:nr_points_csm_width, :]

                # generate SDC based on number of arms and nr of points being used for csm
                import core.gridding.sdc as sd
                cmtxdim = np.array([mtx, mtx], dtype=np.int64)
                wates = np.ones((nr_arms_cms * nr_points_csm_width), dtype=np.float64)
                coords_for_sdc = csm_coords.astype(np.float64)
                coords_for_sdc.shape = [nr_arms_cms * nr_points_csm_width, twoD_or_threeD]
                csm_weights = sd.twod_sdc(coords_for_sdc, wates, cmtxdim, numiter, 0.01 * UI_taper)
                csm_weights.shape = [1, nr_arms_cms, nr_points_csm_width]

                # pre-calculate Kaiser-Bessel kernel
                kernel_table_size = 800
                kernel = kaiser2D.kaiserbessel_kernel(kernel_table_size, oversampling_ratio)

                # pre-calculate the rolloff for the spatial domain
                roll = kaiser2D.rolloff2D(mtx, kernel)
                # Grid
                gridded_kspace = kaiser2D.grid2D(csm_data, csm_coords, csm_weights.astype(np.float32), kernel, out_dims_grid)
                self.setData('debug', gridded_kspace)
                # No extra k-space window here: the SDC taper above filters the data.
                # FFT
                image_domain = kaiser2D.fft2D(gridded_kspace, dir=0, out_dims_fft=out_dims_fft)
                # rolloff
                image_domain *= roll
                # crop to original matrix size
                csm = image_domain[..., mtx_min:mtx_max, mtx_min:mtx_max]
                # normalize by rms (better would be to use a whole body coil image)
                csm_rms = np.sqrt(np.sum(np.abs(csm)**2, axis=0))
                # Divide only where the RMS is non-zero: empty pixels become 0 rather
                # than NaN (NaN * mask stays NaN and would break the SVD below).
                csm = np.divide(csm, csm_rms, out=np.zeros_like(csm), where=csm_rms > 0)
                # zero out points that are below mask threshold
                thresh = 0.01 * UI_mask_floor * csm_rms.max()
                csm *= csm_rms > thresh
                # for ROI selection use csm_rms, which still has some contrast
                image = csm_rms
                image.shape = [csm_mtx, csm_mtx]
                image_sos = image.copy()
                self.setData('sum of square image', image_sos)
        else:
            csm = sensitivity_map_uncropped
            csm_mtx = csm.shape[-1]
            # create sum-of-squares of sensitivity map to allow selection of ROI
            image = np.sqrt(np.sum(np.abs(csm)**2, axis=0))
            image.shape = [csm_mtx, csm_mtx]
        self.setData('masked and normalized sense map', csm)

        # display sum-of-squares of sensitivity map to allow selection of ROI
        data_max = image.max()
        data_min = image.min()

        # draw the crop rectangle into the preview at full brightness
        image[:, crop_left-1] = data_max
        image[:, crop_right-1] = data_max
        image[crop_top-1, :] = data_max
        image[crop_bottom-1, :] = data_max

        # clip at 'image ceiling' percent of the range and scale to 0..255
        data_range = data_max - data_min
        new_max = data_range * 0.01 * image_ceiling + data_min
        dmask = np.ones(image.shape)
        image = np.minimum(image, new_max*dmask)
        if new_max > data_min:
            image = 255.*(image - data_min)/(new_max-data_min)
        red = green = blue = np.uint8(image)
        alpha = 255. * np.ones(blue.shape)
        h, w = red.shape[:2]
        image1 = np.zeros((h, w, 4), dtype=np.uint8)
        image1[:, :, 0] = red
        image1[:, :, 1] = green
        image1[:, :, 2] = blue
        image1[:, :, 3] = alpha

        format_ = QtGui.QImage.Format_RGB32

        image2 = QtGui.QImage(image1.data, w, h, format_)
        # NOTE(review): presumably keeps the pixel buffer referenced for as long as
        # the QImage that wraps it is alive.
        image2.ndarry = image1
        self.setAttr('image', val=image2)

        # crop sensitivity map
        # NOTE(review): assigning .shape reshapes csm in place; if setData keeps a
        # reference, this also changes the array published on 'masked and
        # normalized sense map' (or the 'sensitivity map' input) — confirm intent.
        csm.shape = [nr_coils, csm_mtx, csm_mtx]
        sensitivity_map = csm[:, crop_top-1:crop_bottom, crop_left-1:crop_right]

        # get sizes
        # number of channels n
        n = sensitivity_map.shape[-3]
        x_size = sensitivity_map.shape[-1]
        y_size = sensitivity_map.shape[-2]
        nr_pixels = x_size * y_size

        if compute:
            if noise is None:
                self.log.warn("A noise input is required to compute virtual channels")
                return 1
            if m > n:
                self.log.warn("Cannot create more virtual channels than there are coils")
                return 1

            # noise covariance matrix Psi
            noise_cv_matrix = np.cov(noise)

            # Cholesky decomposition to determine T, where T Psi T_H = 1
            L = np.linalg.cholesky(noise_cv_matrix)
            T = np.linalg.inv(L)

            # decorrelated sensitivity map S_hat: row (x * y_size + y) holds
            # T @ sensitivity_map[:, y, x]
            S_hat = np.einsum('ij,jyx->xyi', T, sensitivity_map).reshape(nr_pixels, n).astype(np.complex64)
            self.log.debug("after S_hat")

            # P = sum of S_hat S_hat_pseudo_inverse over all pixels
            P = np.zeros([n, n], dtype=np.complex64)
            S_hat_matrix = np.zeros([n, 1], dtype=np.complex64)
            for index in range(nr_pixels):
                # pseudo inverse of S_hat
                S_hat_matrix[:, 0] = S_hat[index, :]
                S_hat_pinv = np.linalg.pinv(S_hat_matrix)
                P = P + np.dot(S_hat_matrix, S_hat_pinv)
            self.log.debug("after S_hat_pinv")

            # singular value decomposition of P
            # if P is symmetric and positive definite, the SVD is P = U d U.H instead of P = U d V.H
            U, d, V = np.linalg.svd(P)
            self.log.debug("after svd")

            # The transformation matrix is A = C U.H T with C diagonal, 1 on the first
            # m rows and 0 elsewhere. Keeping only the first m rows of C U.H T is the
            # same as taking the first m rows of U.H T directly.
            A = np.dot(U.T.conjugate(), T)[0:m, :]
            self.log.debug("after A")

            # Compress the data: out[:, e2, e1, arm, pt] = A @ data[:, e2, e1, arm, pt].
            # A single tensordot replaces a Python loop over every k-space sample.
            out = np.tensordot(A, data, axes=(1, 0)).astype(data.dtype, copy=False)

            # SETTING PORT INFO
            self.setData('compressed data', np.squeeze(out))
            self.setData('A', A)
            self.setData('noise covariance', noise_cv_matrix)

            # end of compute
            if reset_compute_button:
                self.setAttr('compute', val=False)

        return 0
Ejemplo n.º 3
0
    def compute(self):
        """Reconstruct 2-D non-Cartesian multi-coil data with conjugate-gradient SENSE.

        Reads 'data' (coils first, then optional extra dimensions, arms, points),
        'coords', 'weights' and an optional coil sensitivity map. Images are
        computed on an oversampled matrix (mtx * 'oversampling ratio', rounded and
        made even) and cropped back to 'mtx' for the outputs.

        The csm comes from:
        - the stored 'oversampled CSM' when single-stepping ('step' set and the
          'd' port holds a previous CG state);
        - with 'Golden Angle - combine dynamics before gridding': a low-resolution
          map gridded from the first '# golden angle dynamics for csm' interleaves
          that contain data;
        - otherwise the 'coil sensitivity' port (re-gridded to the oversampled
          matrix if needed) or ``kaiser2D.autocalibrationB1Maps2D``.

        In single-step mode exactly one further CG iteration is run from the stored
        'd', 'r', 'x'; otherwise 'iterations' CG iterations are run.

        Outputs: 'd', 'r', 'x' (CG state), 'out' (final cropped image),
        'x iterations' (every cropped iterate), 'oversampled CSM', 'cropped CSM'.

        Returns:
            0 on success; 1 for an unsupported data rank or when no k-space sample
            falls inside the autocalibration width.
        """
        import bni.gridding.Kaiser2D_utils as kaiser2D

        self.log.debug("Start CG SENSE 2D")
        # get port and widget inputs
        data = self.getData('data').astype(np.complex64, copy=False)
        coords = self.getData('coords').astype(np.float32, copy=False)
        weights = self.getData('weights').astype(np.float32, copy=False)

        mtx_original = self.getVal('mtx')
        iterations = self.getVal('iterations')
        step = self.getVal('step')
        oversampling_ratio = self.getVal('oversampling ratio')
        number_threads = self.getVal('number of threads')
        GA = self.getVal('Golden Angle - combine dynamics before gridding')

        # for a single iteration step use the csm stored in the out port
        if step and (self.getData('oversampled CSM') is not None):
            csm = self.getData('oversampled CSM')
        else:
            csm = self.getData('coil sensitivity')
        if csm is not None:
            csm = csm.astype(np.complex64, copy=False)

        # oversampling: Oversample at the beginning and crop at the end.
        # (np.int was removed in NumPy 1.24; the builtin int truncates the same way.)
        mtx = int(np.around(mtx_original * oversampling_ratio))
        if mtx % 2:
            mtx += 1
        if oversampling_ratio > 1:
            mtx_min = int(np.around((mtx - mtx_original) / 2))
            mtx_max = mtx_min + mtx_original
        else:
            mtx_min = 0
            mtx_max = mtx

        # data dimensions: bring data to [coils, extra_dim2, extra_dim1, arms, points].
        # reshape() instead of assigning .shape, so arrays owned by the input ports
        # (astype(copy=False) may return them unchanged) are not modified.
        if data.ndim < 3 or data.ndim > 5:
            self.log.warn("Not implemented yet")
            return 1
        nr_points = data.shape[-1]
        nr_arms = data.shape[-2]
        nr_coils = data.shape[0]
        extra_dim1 = data.shape[-3] if data.ndim >= 4 else 1
        extra_dim2 = data.shape[-4] if data.ndim == 5 else 1
        data = data.reshape([nr_coils, extra_dim2, extra_dim1, nr_arms, nr_points])
        out_dims_grid = [nr_coils, extra_dim2, extra_dim1, mtx, nr_arms, nr_points]
        out_dims_degrid = [nr_coils, extra_dim2, extra_dim1, nr_arms, nr_points]
        out_dims_fft = [nr_coils, extra_dim2, extra_dim1, mtx, mtx]
        iterations_shape = [extra_dim2, extra_dim1, mtx, mtx]

        # coords dimensions: add a leading dimension, as coords may carry an extra
        # one for golden angle dynamics
        if coords.ndim == 3:
            coords = coords.reshape([1, nr_arms, nr_points, 2])
            weights = weights.reshape([1, nr_arms, nr_points])

        # output including all iterations
        x_iterations = np.zeros([iterations, extra_dim2, extra_dim1, mtx_original, mtx_original], dtype=np.complex64)
        if step and (iterations > 1):
            previous_iterations = self.getData('x iterations')
            x_iterations[:-1, :, :, :, :] = np.reshape(
                previous_iterations,
                [iterations-1, extra_dim2, extra_dim1, mtx_original, mtx_original])

        # pre-calculate Kaiser-Bessel kernel
        self.log.debug("Calculate kernel")
        kernel_table_size = 800
        kernel = kaiser2D.kaiserbessel_kernel(kernel_table_size, oversampling_ratio)

        # pre-calculate the rolloff for the spatial domain
        roll = kaiser2D.rolloff2D(mtx, kernel)

        # for a single iteration step use the oversampled csm and intermediate results stored in outports
        if step and (self.getData('d') is not None):
            self.log.debug("Save some time and use the previously determined csm stored in the cropped CSM outport.")
        elif GA:  # combine data from GA dynamics before gridding, use code from VirtualChannels_GPI.py
            # grid images for each phase - needs to be done at some point, not really here for csm though.
            self.log.debug("Grid undersampled data")
            gridded_kspace = kaiser2D.grid2D(data, coords, weights, kernel, out_dims_grid, number_threads=number_threads)
            # FFT
            image_domain = kaiser2D.fft2D(gridded_kspace, dir=0, out_dims_fft=out_dims_fft)
            # rolloff
            image_domain *= roll

            twoD_or_threeD = coords.shape[-1]
            # parameters from UI
            UI_width = self.getVal('Autocalibration Width (%)')
            UI_taper = self.getVal('Autocalibration Taper (%)')
            UI_mask_floor = self.getVal('Mask Floor (% of max mag)')
            mask_dilation = self.getVal('Autocalibration mask dilation [pixels]')
            numiter = self.getVal('Autocalibration SDC Iterations')

            nr_arms_csm = self.getVal('# golden angle dynamics for csm')
            nr_all_arms_csm = extra_dim1 * nr_arms
            extra_dim1_csm = 1

            # create low resolution csm
            # cropping the data will make gridding and FFT much faster.
            # k-space radius of every sample of the first interleave; computed in
            # float32 like the coords, compared in float64.
            magnitude_one_interleave = np.sqrt(
                coords[0, 0, :, 0]**2 + coords[0, 0, :, 1]**2).astype(np.float64)
            within_csm_width_radius = magnitude_one_interleave < (0.01 * UI_width * 0.5)  # for BNI spirals should be 0.45 instead of 0.5
            nr_points_csm_width = within_csm_width_radius.sum()
            # In case of a radial trajectory the interleave doesn't start at zero:
            # use the first sample inside the radius up to (excluding) the first
            # sample after it that lies outside again.
            inside = np.flatnonzero(within_csm_width_radius)
            if inside.size == 0:
                self.log.warn("No k-space samples lie within the autocalibration width.")
                return 1
            start_point = int(inside[0])
            outside_after_start = np.flatnonzero(~within_csm_width_radius[start_point:])
            if outside_after_start.size:
                end_point = start_point + int(outside_after_start[0])
            else:
                end_point = nr_points
            self.log.node("Start and end points in interleave are: "+str(start_point)+" and "+str(end_point)+" leading to "+str(nr_points_csm_width)+" points for csm.")

            def arms_with_data():
                """Yield (dynamic, arm) indices of interleaves holding data, in order,
                stopping once nr_arms_csm of them have been found."""
                found = 0
                for dyn in range(extra_dim1):
                    for arm in range(nr_arms):
                        if found >= nr_arms_csm:
                            return
                        # first and last x-coordinate are only equal when there was
                        # no data in this interleave during resorting
                        if coords[dyn, arm, 0, 0] != coords[dyn, arm, -1, 0]:
                            found += 1
                            yield dyn, arm

            selected_arms = list(arms_with_data())
            arm_with_data_counter = len(selected_arms)
            self.log.node("Found "+str(arm_with_data_counter)+" arms, and was looking for "+str(nr_arms_csm)+" from a total of "+str(nr_all_arms_csm)+" arms.")

            csm_data = np.zeros([nr_coils, extra_dim2, extra_dim1_csm, arm_with_data_counter, nr_points_csm_width], dtype=data.dtype)
            csm_coords = np.zeros([1, arm_with_data_counter, nr_points_csm_width, twoD_or_threeD], dtype=coords.dtype)
            for k, (dyn, arm) in enumerate(selected_arms):
                csm_data[:, :, 0, k, :] = data[:, :, dyn, arm, start_point:end_point]
                csm_coords[0, k, :, :] = coords[dyn, arm, start_point:end_point, :]

            # now set the dimension lists
            out_dims_grid_csm = [nr_coils, extra_dim2, extra_dim1_csm, mtx, arm_with_data_counter, nr_points_csm_width]
            out_dims_fft_csm = [nr_coils, extra_dim2, extra_dim1_csm, mtx, mtx]

            # generate SDC based on number of arms and nr of points being used for csm
            import core.gridding.sdc as sd
            cmtxdim = np.array([mtx, mtx], dtype=np.int64)
            wates = np.ones((arm_with_data_counter * nr_points_csm_width), dtype=np.float64)
            coords_for_sdc = csm_coords.astype(np.float64)
            coords_for_sdc.shape = [arm_with_data_counter * nr_points_csm_width, twoD_or_threeD]
            csm_weights = sd.twod_sdc(coords_for_sdc, wates, cmtxdim, numiter, 0.01 * UI_taper)
            csm_weights.shape = [1, arm_with_data_counter, nr_points_csm_width]

            # Grid
            gridded_kspace_csm = kaiser2D.grid2D(csm_data, csm_coords, csm_weights.astype(np.float32), kernel, out_dims_grid_csm, number_threads=number_threads)
            image_domain_csm = kaiser2D.fft2D(gridded_kspace_csm, dir=0, out_dims_fft=out_dims_fft_csm)
            # rolloff
            image_domain_csm *= roll
            # normalize by rms (better would be to use a whole body coil image).
            # Divide only where the RMS is non-zero: empty pixels become 0 rather
            # than NaN (NaN * mask would stay NaN).
            csm_rms = np.sqrt(np.sum(np.abs(image_domain_csm)**2, axis=0))
            image_domain_csm = np.divide(image_domain_csm, csm_rms,
                                         out=np.zeros_like(image_domain_csm),
                                         where=csm_rms > 0)
            # zero out points that are below mask threshold
            thresh = 0.01 * UI_mask_floor * csm_rms.max()
            mask = (csm_rms > thresh).reshape([mtx, mtx])
            # use scipy library to grow mask and fill holes.
            from scipy import ndimage
            # binary_dilation repeats until convergence when iterations < 1, so a
            # dilation of 0 pixels must skip the call instead of growing the mask
            # over the whole image.
            if mask_dilation > 0:
                mask = ndimage.binary_dilation(mask, iterations=mask_dilation)
            mask = ndimage.binary_fill_holes(mask)

            image_domain_csm *= mask
            if extra_dim1 > 1:
                # use the same csm for every entry of extra_dim1
                csm = np.repeat(image_domain_csm, extra_dim1, axis=2)
            else:
                csm = image_domain_csm
            self.setData('oversampled CSM', csm)
            self.setData('cropped CSM', csm[..., mtx_min:mtx_max, mtx_min:mtx_max])

        else:  # this is the normal path (not single iteration step)
            # grid to create images that are corrupted by
            # aliasing due to undersampling.  If the k-space data have an
            # auto-calibration region, then this can be used to generate B1 maps.
            self.log.debug("Grid undersampled data")
            gridded_kspace = kaiser2D.grid2D(data, coords, weights, kernel, out_dims_grid, number_threads=number_threads)
            # FFT
            image_domain = kaiser2D.fft2D(gridded_kspace, dir=0, out_dims_fft=out_dims_fft)
            # rolloff
            image_domain *= roll

            # calculate auto-calibration B1 maps
            if csm is None:
                self.log.debug("Generating autocalibrated B1 maps...")
                # parameters from UI
                UI_width = self.getVal('Autocalibration Width (%)')
                UI_taper = self.getVal('Autocalibration Taper (%)')
                UI_mask_floor = self.getVal('Mask Floor (% of max mag)')
                UI_average_csm = self.getVal('Dynamic data - average all dynamics for csm')
                csm = kaiser2D.autocalibrationB1Maps2D(image_domain, taper=UI_taper, width=UI_width, mask_floor=UI_mask_floor, average_csm=UI_average_csm)
            else:
                # make sure input csm and data are the same mtx size.
                # Assuming the FOV was the same: zero-fill in k-space
                if csm.ndim != 5:
                    self.log.debug("Reshape imported csm")
                    csm = csm.reshape([nr_coils, extra_dim2, extra_dim1, csm.shape[-2], csm.shape[-1]])
                if csm.shape[-1] != mtx:
                    self.log.debug("Interpolate csm to oversampled matrix size")
                    csm_oversampled_mtx = int(csm.shape[-1] * oversampling_ratio)
                    if csm_oversampled_mtx % 2:
                        csm_oversampled_mtx += 1
                    out_dims_oversampled_image_domain = [nr_coils, extra_dim2, extra_dim1, csm_oversampled_mtx, csm_oversampled_mtx]
                    csm = kaiser2D.fft2D(csm, dir=1, out_dims_fft=out_dims_oversampled_image_domain)
                    csm = kaiser2D.fft2D(csm, dir=0, out_dims_fft=out_dims_fft)
            self.setData('oversampled CSM', csm)
            self.setData('cropped CSM', csm[..., mtx_min:mtx_max, mtx_min:mtx_max])

        # keep a conjugate csm set on hand
        csm_conj = np.conj(csm)

        def apply_A(vec):
            """Apply the CG system operator to an image estimate ``vec``.

            Multiplies by the coil maps, pre-rolloffs, transforms with
            fft2D(dir=1), degrids onto the trajectory, regrids with the density
            weights, transforms back with fft2D(dir=0), applies the rolloff and
            combines the coils with the conjugate maps.
            """
            result = csm * vec  # add coil phase
            result *= roll  # pre-rolloff for degrid convolution
            result = kaiser2D.fft2D(result, dir=1)
            result = kaiser2D.degrid2D(result, coords, kernel, out_dims_degrid, number_threads=number_threads, oversampling_ratio=oversampling_ratio)
            result = kaiser2D.grid2D(result, coords, weights, kernel, out_dims_grid, number_threads=number_threads)
            result = kaiser2D.fft2D(result, dir=0)
            result *= roll
            result = csm_conj * result  # broadcast multiply to remove coil phase
            return result.sum(axis=0)  # assume the coil dim is the first

        ## Iteration 1:
        if step and (self.getData('d') is not None):
            self.log.debug("\tSENSE Iteration: " + str(iterations))
            # make sure the loop doesn't start if only one step is needed
            iterations = 0

            # Get the data from the last execution of this node for an
            # additional single iteration.
            d = self.getData('d').copy()
            r = self.getData('r').copy()
            x = self.getData('x').copy()
            Ad = apply_A(d)
        else:
            self.log.debug("\tSENSE Iteration: 1")
            # calculate initial conditions
            # d_0
            d_0 = csm_conj * image_domain  # broadcast multiply to remove coil phase
            d_0 = d_0.sum(axis=0)  # assume the coil dim is the first

            # use the initial conditions for the first iter
            r = d = d_0
            x = np.zeros_like(d)
            Ad = apply_A(d_0)

        # CG - iter 1 or step
        d_last, r_last, x_last = self.do_cg(d, r, x, Ad)

        current_iteration = x_last.reshape(iterations_shape).copy()
        first_slot = -1 if step else 0
        x_iterations[first_slot, :, :, :, :] = current_iteration[..., mtx_min:mtx_max, mtx_min:mtx_max]

        ## Iterations >1:
        for i in range(iterations - 1):
            self.log.debug("\tSENSE Iteration: " + str(i+2))

            # continue from the result of the last iteration
            Ad = apply_A(d_last)
            d_last, r_last, x_last = self.do_cg(d_last, r_last, x_last, Ad)

            current_iteration = x_last.reshape(iterations_shape).copy()
            x_iterations[i+1, :, :, :, :] = current_iteration[..., mtx_min:mtx_max, mtx_min:mtx_max]

        # return the final image
        self.setData('d', d_last)
        self.setData('r', r_last)
        self.setData('x', x_last)
        self.setData('out', np.squeeze(current_iteration[..., mtx_min:mtx_max, mtx_min:mtx_max]))
        self.setData('x iterations', np.squeeze(x_iterations))

        return 0