Example #1
0
    def test_sigpy_cupy(self):
        """``Device(0)`` and its ``spdevice`` view must equal ``sigpy.Device(0)``."""
        import sigpy as sp

        # A freshly built Device compares equal to the sigpy GPU device.
        assert Device(0) == sp.Device(0)

        # The sigpy-facing view exposed through ``spdevice`` matches too.
        gpu_device = Device(0)
        actual = gpu_device.spdevice
        assert actual == sp.Device(0)
Example #2
0
    def __init__(self,
                 ksp,
                 calib_width=24,
                 thresh=0.01,
                 kernel_width=6,
                 crop=0.8,
                 max_iter=100,
                 device=sp.cpu_device,
                 output_eigenvalue=False,
                 show_pbar=True):
        """Set up ESPIRiT-style sensitivity calibration solved by power iteration.

        Args:
            ksp (array): multi-coil k-space; axis 0 indexes coils and the
                remaining ``ksp.ndim - 1`` axes are spatial.
            calib_width (int): size of the calibration region (obtained with
                ``sp.resize``) along each spatial axis.
            thresh (float): right singular vectors whose singular value is
                not above ``thresh * S.max()`` are discarded.
            kernel_width (int): kernel size along each spatial axis.
            crop (float): stored as ``self.crop``; not used in this method.
            max_iter (int): number of power-method iterations.
            device (Device or int): device for the calibration computation.
            output_eigenvalue (bool): stored as ``self.output_eigenvalue``;
                not used in this method.
            show_pbar (bool): forwarded to ``super().__init__``.
        """
        self.device = sp.Device(device)
        self.output_eigenvalue = output_eigenvalue
        self.crop = crop

        img_ndim = ksp.ndim - 1
        num_coils = len(ksp)
        # Crop on the device that holds ksp, then move only the small
        # calibration block to the computing device.
        with sp.get_device(ksp):
            # Get calibration region
            calib_shape = [num_coils] + [calib_width] * img_ndim
            calib = sp.resize(ksp, calib_shape)
            calib = sp.to_device(calib, device)

        xp = self.device.xp
        with self.device:
            # Get calibration matrix: one row per sliding kernel_shape block
            # (stride 1 along every axis).
            kernel_shape = [num_coils] + [kernel_width] * img_ndim
            kernel_strides = [1] * (img_ndim + 1)
            mat = sp.array_to_blocks(calib, kernel_shape, kernel_strides)
            mat = mat.reshape([-1, sp.prod(kernel_shape)])

            # Perform SVD on calibration matrix; keep right singular vectors
            # whose singular value exceeds the relative threshold.
            _, S, VH = xp.linalg.svd(mat, full_matrices=False)
            VH = VH[S > thresh * S.max(), :]

            # Get kernels
            num_kernels = len(VH)
            kernels = VH.reshape([num_kernels] + kernel_shape)
            img_shape = ksp.shape[1:]

            # Get covariance matrix in image domain: a (num_coils x
            # num_coils) matrix per pixel, with the spatial axes reversed
            # (because of the ``.T`` below) and the coil axes last.
            AHA = xp.zeros(img_shape[::-1] + (num_coils, num_coils),
                           dtype=ksp.dtype)
            for kernel in kernels:
                img_kernel = sp.ifft(sp.resize(kernel, ksp.shape),
                                     axes=range(-img_ndim, 0))
                # Column vector per pixel, then accumulate its outer product
                # with the conjugated row vector.
                aH = xp.expand_dims(img_kernel.T, axis=-1)
                a = xp.conj(aH.swapaxes(-1, -2))
                AHA += aH @ a

            AHA *= (sp.prod(img_shape) / kernel_width**img_ndim)
            # Initial iterate: all-ones column vector per pixel.
            self.mps = xp.ones(ksp.shape[::-1] + (1, ), dtype=ksp.dtype)

            # Power method per pixel; the norm is taken over the coil axis
            # (axis -2) independently at every pixel.
            alg = sp.alg.PowerMethod(
                lambda x: AHA @ x,
                self.mps,
                norm_func=lambda x: xp.sum(
                    xp.abs(x)**2, axis=-2, keepdims=True)**0.5,
                max_iter=max_iter)

        super().__init__(alg, show_pbar=show_pbar)
Example #3
0
def pipe_menon_dcf(coord,
                   device=sp.cpu_device,
                   max_iter=30,
                   n=128,
                   beta=8,
                   width=4,
                   show_pbar=True):
    r"""Estimate density compensation weights with the Pipe-Menon iteration.

    Starting from unit weights, the update

    .. math::

        w \leftarrow \frac{w}{|G^H G w|}

    is applied ``max_iter`` times, where :math:`G` is the gridding operator
    built from ``coord``.

    Args:
        coord (array): k-space coordinates; the last axis holds the spatial
            dimensions.
        device (Device): device on which the iteration runs.
        max_iter (int): number of fixed-point updates.
        n (int): number of samples in the tabulated Kaiser-Bessel kernel.
        beta (float): Kaiser-Bessel shape parameter.
        width (float): Kaiser-Bessel kernel width.
        show_pbar (bool): whether to display a progress bar.

    Returns:
        array: density compensation factor of shape ``coord.shape[:-1]``.

    References:
        Pipe, James G., and Padmanabhan Menon.
        Sampling Density Compensation in MRI:
        Rationale and an Iterative Numerical Solution.
        Magnetic Resonance in Medicine 41, no. 1 (1999): 179–86.
    """
    device = sp.Device(device)
    xp = device.xp

    with device:
        weights = xp.ones(coord.shape[:-1], dtype=coord.dtype)
        img_shape = sp.estimate_shape(coord)

        # Kaiser-Bessel lookup table sampled on [0, 1) and scaled to peak 1.
        samples = xp.arange(n, dtype=coord.dtype) / n
        table = xp.i0(beta * (1 - samples**2)**0.5).astype(coord.dtype)
        table /= table.max()

        gridding = sp.linop.Gridding(img_shape, coord, width, table)
        with tqdm(total=max_iter, disable=not show_pbar) as pbar:
            for _ in range(max_iter):
                ghg_w = gridding.H * gridding * weights
                weights /= xp.abs(ghg_w)
                resid = xp.abs(ghg_w - 1).max().item()

                pbar.set_postfix(resid='{0:.2E}'.format(resid))
                pbar.update()

    return weights
Example #4
0
File: app.py Project: jtamir/sigpy
    def __init__(self, input, output, batch_size, mu,
                 lamda=0,
                 max_iter=100, max_inner_iter=100, device=sp.cpu_device,
                 checkpoint_path=None):
        """Build a minibatch proximal-point solver for ``self.mat``.

        Args:
            input (array): input samples; axis 0 indexes samples.
            output (array): target samples; axis 0 indexes samples and its
                dtype is used for every buffer allocated here.
            batch_size (int): number of samples per minibatch.
            mu (float): proximal-point parameter (the inner solve uses
                ``mu=1 / mu``).
            lamda (float): regularization weight; each inner solve uses
                ``self.lamda / num_batches``.
            max_iter (int): outer proximal-point iterations.
            max_inner_iter (int): iterations of each inner least-squares solve.
            device (Device or int): device holding the buffers.
            checkpoint_path: stored as ``self.checkpoint_path``.
        """
        dtype = output.dtype
        num_batches = len(output) // batch_size

        self.lamda = lamda
        self.batch_size = batch_size
        self.input = input
        self.output = output
        self.checkpoint_path = checkpoint_path
        self.device = sp.Device(device)

        xp = self.device.xp
        with self.device:
            batch_prefix = (batch_size, )
            self.mat = xp.zeros(input.shape[1:] + output.shape[1:],
                                dtype=dtype)
            self.input_t = xp.empty(batch_prefix + input.shape[1:],
                                    dtype=dtype)
            self.output_t = xp.empty(batch_prefix + output.shape[1:],
                                     dtype=dtype)
            self.t_idx = sp.ShuffledNumbers(num_batches)

        self._get_A()

        def proxf(mu, x):
            # Inner regularized least-squares solve anchored at z = x.
            inner = sp.app.LinearLeastSquares(self.A, self.output_t, x=x,
                                              lamda=self.lamda / num_batches,
                                              mu=1 / mu, z=x,
                                              max_iter=max_inner_iter)
            return inner.run()

        outer = sp.alg.ProximalPointMethod(proxf, mu, self.mat,
                                           max_iter=max_iter)
        super().__init__(outer)
Example #5
0
    def __init__(self,
                 y,
                 mps_ker_width=16,
                 ksp_calib_width=24,
                 lamda=0,
                 device=sp.cpu_device,
                 comm=None,
                 weights=None,
                 coord=None,
                 max_iter=10,
                 max_inner_iter=10,
                 normalize=True,
                 show_pbar=True):
        """Configure the JSENSE sensitivity-map estimation app.

        Every option is stored on ``self``. ``_get_data``, ``_get_vars`` and
        ``_get_alg`` are then called in that order, and ``self.alg``
        (presumably produced by ``_get_alg``) is passed to
        ``super().__init__``.

        Args:
            y (array): multi-coil k-space; ``len(y)`` gives the coil count and
                ``y.dtype`` gives ``self.dtype``.
            mps_ker_width (int): sensitivity-map kernel width.
            ksp_calib_width (int): k-space calibration width.
            lamda (float): regularization weight.
            device (Device or int): computing device.
            comm (Communicator or None): when given, only rank 0 shows the
                progress bar.
            weights (array or None): optional k-space weights.
            coord (array or None): optional non-Cartesian coordinates.
            max_iter (int): outer iterations.
            max_inner_iter (int): inner iterations.
            normalize (bool): stored as ``self.normalize``.
            show_pbar (bool): whether to show a progress bar.
        """
        # Measured data and sampling description.
        self.y = y
        self.weights = weights
        self.coord = coord

        # Model and solver settings.
        self.mps_ker_width = mps_ker_width
        self.ksp_calib_width = ksp_calib_width
        self.lamda = lamda
        self.max_iter = max_iter
        self.max_inner_iter = max_inner_iter
        self.normalize = normalize

        # Execution environment and derived properties.
        self.device = sp.Device(device)
        self.comm = comm
        self.dtype = y.dtype
        self.num_coils = len(y)
        if comm is not None:
            # Only the root rank reports progress.
            show_pbar = show_pbar and comm.rank == 0

        for setup_step in (self._get_data, self._get_vars, self._get_alg):
            setup_step()
        super().__init__(self.alg, show_pbar=show_pbar)
Example #6
0
 def spdevice(self):
     """sigpy.Device: The equivalent ``sigpy.Device`` of this device.

     Raises:
         RuntimeError: if ``sigpy`` is not installed, or if the device id is
             non-negative but the device type is not ``"cuda"``.
     """
     if not env.sigpy_available():
         raise RuntimeError("`sigpy` not installed.")
     is_unsupported_accelerator = self.id >= 0 and self.type != "cuda"
     if is_unsupported_accelerator:
         raise RuntimeError(
             f"sigpy.Device does not support type {self.type}")
     return sp.Device(self.id)
    def __init__(self,
                 ksp,
                 coord,
                 dcf,
                 mps,
                 T,
                 lamda,
                 blk_widths=[32, 64, 128],
                 alpha=1,
                 beta=0.5,
                 sgw=None,
                 device=sp.cpu_device,
                 comm=None,
                 seed=0,
                 max_epoch=60,
                 decay_epoch=20,
                 max_power_iter=5,
                 show_pbar=True):
        """Store the multi-scale reconstruction problem and build its operators.

        Args:
            ksp (array): k-space of shape ``(C, num_tr, num_ro)``.
            coord (array): k-space coordinates (stored as ``self.coord``).
            dcf (array): density compensation factors. When ``sgw`` is given,
                a weighted copy is stored; the caller's array is left intact.
            mps (array): sensitivity maps; ``mps.shape[1:]`` is the image shape.
            T (int): number of frames; each frame gets ``num_tr // T`` TRs.
            lamda (float): regularization weight.
            blk_widths (sequence of int): one block width per scale. A copy
                is stored, so the shared default list is never aliased.
            alpha (float): stored as ``self.alpha``.
            beta (float): stored as ``self.beta``.
            sgw (array or None): optional weights multiplied onto ``dcf``
                after adding a trailing axis (presumably one weight per TR —
                confirm with callers).
            device (Device or int): computing device.
            comm (Communicator or None): when given, only rank 0 shows the
                progress bar.
            seed (int): seed for NumPy's global RNG and the device RNG.
            max_epoch (int): stored as ``self.max_epoch``.
            decay_epoch (int): stored as ``self.decay_epoch``.
            max_power_iter (int): stored as ``self.max_power_iter``.
            show_pbar (bool): whether to show a progress bar.
        """
        self.ksp = ksp
        self.coord = coord
        self.dcf = dcf
        self.mps = mps
        self.sgw = sgw
        # Copy so neither the mutable default nor a caller's list is aliased.
        self.blk_widths = list(blk_widths)
        self.T = T
        self.lamda = lamda
        self.alpha = alpha
        self.beta = beta
        self.device = sp.Device(device)
        self.comm = comm
        self.seed = seed
        self.max_epoch = max_epoch
        self.decay_epoch = decay_epoch
        self.max_power_iter = max_power_iter
        self.show_pbar = show_pbar and (comm is None or comm.rank == 0)

        np.random.seed(self.seed)
        self.xp = self.device.xp
        with self.device:
            self.xp.random.seed(self.seed)

        self.dtype = self.ksp.dtype
        self.C, self.num_tr, self.num_ro = self.ksp.shape
        self.tr_per_frame = self.num_tr // self.T
        self.img_shape = self.mps.shape[1:]
        self.D = len(self.img_shape)
        self.J = len(self.blk_widths)
        if self.sgw is not None:
            # Weight a private copy: scaling the caller's dcf in place would
            # compound the weights whenever the same array is reused.
            # In-place multiply on the copy keeps dcf's dtype unchanged.
            self.dcf = self.dcf.copy()
            self.dcf *= np.expand_dims(self.sgw, -1)

        self.B = [self._get_B(j) for j in range(self.J)]
        self.G = [self._get_G(j) for j in range(self.J)]

        self._normalize()
Example #8
0
def jsens_calib(ksp, coord, dcf, ishape, device=sp.Device(-1)):
    """Estimate coil sensitivity maps with JSENSE from non-Cartesian data.

    The samples are gridded one channel at a time with ``nft.nufft_adj``,
    Fourier transformed over axes (1, 2, 3) to obtain Cartesian k-space,
    and then passed to ``mr.app.JsenseRecon``.

    Args:
        ksp: k-space samples of a single phase (wrapped in a list for
            ``nft.nufft_adj``).
        coord: matching k-space coordinates.
        dcf: matching density compensation factors.
        ishape: image shape forwarded to ``nft.nufft_adj``.
        device: computing device, ``sp.Device(-1)`` by default.

    Returns:
        The sensitivity maps returned by ``JsenseRecon.run()``.
    """
    channel_imgs = nft.nufft_adj([ksp], [coord], [dcf],
                                 device=device,
                                 ishape=ishape,
                                 id_channel=True)
    cart_ksp = sp.fft(input=np.asarray(channel_imgs[0]), axes=(1, 2, 3))
    recon = mr.app.JsenseRecon(cart_ksp,
                               mps_ker_width=12,
                               ksp_calib_width=32,
                               lamda=0,
                               device=device,
                               comm=sp.Communicator(),
                               max_iter=10,
                               max_inner_iter=10)
    return recon.run()
Example #9
0
def check_linop_adjoint(A, dtype=float, device=sp.cpu_device):
    """Assert that ``A.H`` is the adjoint of ``A`` using a dot-product test.

    Random ``x`` (shape ``A.ishape``) and ``y`` (shape ``A.oshape``) are drawn,
    and ``vdot(A x, y)`` is compared against ``vdot(x, A^H y)`` with
    ``atol=rtol=1e-5``.

    Args:
        A (Linop): linear operator to check.
        dtype: dtype of the random test vectors. Defaults to the builtin
            ``float``; the former default ``np.float`` was only an alias of
            it and was removed in NumPy 1.24, which made this function fail
            at definition time.
        device (Device or int): device on which the test vectors live.

    Raises:
        AssertionError: if the two inner products are not close.
    """
    device = sp.Device(device)
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)

    xp = device.xp
    with device:
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)

        xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5)
Example #10
0
def nufft1(img,
           traj,
           device=sp.Device(-1),
           smap=None,
           dcf=None,
           batch=40000,
           id_channel=False):
    """Forward NUFFT of a stack of phases, evaluated in batches of TRs.

    For each phase ``p`` the image ``img[p]`` is transformed at ``traj[p]``.
    Trajectory rows are processed ``batch`` TRs at a time to bound device
    memory; results are gathered into a host complex64 buffer.

    Args:
        img: per-phase images. With ``id_channel=True``, ``img[p]`` holds one
            image per channel along axis 0; otherwise it is a single image
            multiplied by each map in ``smap``.
        traj: per-phase coordinates; ``traj[p]`` has shape
            ``(num_tr, num_ro, ndim)``.
        device: device on which the NUFFT is evaluated.
        smap: coil maps (channels on axis 0). Not used when ``id_channel``
            is True; when None a single unit map is used.
        dcf: optional per-phase weights multiplied onto the k-space output.
        batch (int): number of TRs transformed per NUFFT call.
        id_channel (bool): treat axis 0 of ``img[p]`` as the channel axis.

    Returns:
        numpy.ndarray: stacked per-phase complex64 k-space, each phase of
        shape ``(num_coils, num_tr, num_ro)``.
    """
    ksp = []
    for num_ph in range(img.shape[0]):
        img_t = img[num_ph]
        coord_t = traj[num_ph]
        dcf_t = dcf[num_ph] if dcf is not None else None
        num_tr, num_ro, _ = coord_t.shape
        if id_channel is True:
            num_coils = img_t.shape[0]
        elif smap is None:
            # No maps given: use a single unit "map". This rebinding also
            # makes smap non-None for the remaining phases (num_coils = 1).
            smap = np.array([1])
            num_coils = 1
        else:
            num_coils = smap.shape[0]

        with device:
            ksp_t = np.zeros((num_coils, num_tr, num_ro), dtype=np.complex64)
            for c in range(num_coils):
                if id_channel is True:
                    img_tc = sp.to_device(img_t[c, ...], device)
                else:
                    img_tc = sp.to_device(img_t * smap[c, ...], device)
                for start in range(0, num_tr, batch):
                    # Same TR window for coords, weights and output.
                    rows = slice(start, min(start + batch, num_tr))
                    coord_tt = sp.to_device(coord_t[rows, ...], device)
                    ksp_tt = sp.nufft(img_tc, coord_tt)
                    if dcf_t is not None:
                        dcf_tt = sp.to_device(dcf_t[rows, ...], device)
                        ksp_tt = dcf_tt * ksp_tt

                    ksp_t[c, rows, ...] = sp.to_device(ksp_tt)

        ksp.append(ksp_t)

    return np.asarray(ksp)
def process_slice(kspace, args, calib_method='jsense'):
    """Estimate sensitivity maps for one slice and coil-combine it.

    Args:
        kspace: array of shape ``(nky, nkz, nechoes, ncoils)``; only echo 0
            is used for calibration.
        args: namespace providing ``num_emaps``, ``ncalib``, ``crop_value``
            and ``device`` (``-1`` selects the CPU).
        calib_method (str): ``'espirit'`` or ``'jsense'``.

    Returns:
        tuple: ``(image, maps)`` where ``maps`` has shape
        ``(nky, nkz, 1, ncoils, nmaps)``.

    Raises:
        ValueError: if ``calib_method`` is not a supported method.
    """
    # get data dimensions
    nky, nkz, nechoes, ncoils = kspace.shape

    # ESPIRiT parameters
    nmaps = args.num_emaps
    calib_size = args.ncalib
    crop_value = args.crop_value

    # Compare by value: ``is`` tests object identity, which is not
    # guaranteed for ints/strings produced by argparse or config parsing.
    if args.device == -1:
        device = sp.cpu_device
    else:
        device = sp.Device(args.device)

    # compute sensitivity maps (BART)
    #cmd = f'ecalib -d 0 -S -m {nmaps} -c {crop_value} -r {calib_size}'
    #maps = bart.bart(1, cmd, kspace[:,:,0,None,:])
    #maps = np.reshape(maps, (nky, nkz, 1, ncoils, nmaps))

    # compute sensitivity maps (SigPy)
    ksp = np.transpose(kspace[:, :, 0, :], [2, 1, 0])
    if calib_method == 'espirit':
        maps = app.EspiritCalib(ksp,
                                calib_width=calib_size,
                                crop=crop_value,
                                device=device,
                                show_pbar=False).run()
    elif calib_method == 'jsense':
        maps = app.JsenseRecon(ksp,
                               mps_ker_width=6,
                               ksp_calib_width=calib_size,
                               device=device,
                               show_pbar=False).run()
    else:
        raise ValueError('%s calibration method not implemented...' %
                         calib_method)
    # NOTE(review): this reshape assumes the calibration returns exactly
    # nky * nkz * ncoils * nmaps values — confirm for nmaps > 1.
    maps = np.reshape(np.transpose(maps, [2, 1, 0]),
                      (nky, nkz, 1, ncoils, nmaps))

    # Convert everything to tensors
    kspace_tensor = cplx.to_tensor(kspace).unsqueeze(0)
    maps_tensor = cplx.to_tensor(maps).unsqueeze(0)

    # Do coil combination using sensitivity maps (PyTorch)
    A = T.SenseModel(maps_tensor)
    im_tensor = A(kspace_tensor, adjoint=True)

    # Convert tensor back to numpy array
    image = cplx.to_numpy(im_tensor.squeeze(0))

    return image, maps
Example #12
0
def autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1):
    """Automatically estimate the field of view and rescale ``coord``.

    A low-resolution root-sum-of-squares image is gridded from the first
    ``num_ro`` readout points on the grid implied by doubled coordinates.
    It is then thresholded, and the extent of the surviving pixels around
    the grid center sets a per-axis scale applied to ``coord`` in place.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions. Modified in place.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        num_ro (int): number of leading read-out points to use.
        device (Device): computing device.
        thresh (float): threshold between 0 and 1, relative to the image
            maximum (in 3D, the maximum of the central slice along axis 1).
    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        # Keep only the leading readout points.
        ksp_lo = ksp[:, :, :num_ro]
        coord_lo = coord[:, :num_ro, :]
        dcf_lo = dcf[:, :num_ro]
        coord_lo2 = sp.to_device(coord_lo * 2, device)

        lo_shape = np.array(sp.estimate_shape(coord_lo))
        lo2_shape = sp.estimate_shape(coord_lo2)
        lo2_center = [size // 2 for size in lo2_shape]

        coil_imgs = sp.nufft_adjoint(sp.to_device(dcf_lo * ksp_lo, device),
                                     coord_lo2,
                                     [len(ksp_lo)] + lo2_shape)
        rss = xp.sum(xp.abs(coil_imgs)**2, axis=0)**0.5

        # Reference region for the relative threshold.
        reference = rss[:, rss.shape[1] // 2, :] if rss.ndim == 3 else rss
        thresh *= reference.max()

        mask = sp.to_device(rss > thresh)
        support = np.nonzero(mask)
        box_shape = np.array([
            int(np.abs(support[axis] - lo2_center[axis]).max()) * 2
            for axis in range(rss.ndim)
        ])

        coord *= box_shape / lo_shape
Example #13
0
File: app.py Project: jtamir/sigpy
    def _get_params(self):
        """Derive device, dtype and shape attributes from ``self.y`` and ``self.L``.

        ``R_shape`` becomes ``[num_data, num_filters] + spatial``, where each
        spatial size is ``n - filt_width + 1`` in ``'full'`` mode and
        ``n + filt_width - 1`` otherwise (``n`` ranging over the trailing
        ``data_ndim`` sizes of ``self.y``).
        """
        self.device = sp.Device(self.device)
        self.dtype = self.y.dtype
        self.num_data = len(self.y)
        self.filt_width = self.L.shape[-1]
        self.num_filters = self.L.shape[self.multi_channel]
        self.data_ndim = self.y.ndim - self.multi_channel - 1

        spatial = self.y.shape[-self.data_ndim:]
        if self.mode == 'full':
            sizes = [n - self.filt_width + 1 for n in spatial]
        else:
            sizes = [n + self.filt_width - 1 for n in spatial]
        self.R_shape = [self.num_data, self.num_filters] + sizes
Example #14
0
def interp(I, M_field, device=sp.Device(-1), k_id=1, deblur=True):
    """Warp a 2D or 3D image by a field using sigpy interpolation.

    Args:
        I: 2D or 3D image.
        M_field: field with ``M_field.shape[-1]`` equal to the number of
            spatial dims; the integer index grid of ``I`` is added to it
            before sampling (presumably it is a displacement field — confirm).
        device: device on which the interpolation is performed; the result
            is moved back to the device ``I`` came from.
        k_id (int): 0 selects a tabulated cubic B-spline kernel (width 4);
            any other value selects a triangular kernel (width 2) and turns
            deblurring off.
        deblur (bool): whether to run the ``dkernel`` correction step.

    Returns:
        The warped image on the original device of ``I``.
    """
    # b spline interpolation; N table samples per unit distance
    N = 64
    # ``==`` rather than ``is``: identity on ints is an implementation
    # detail and fails e.g. for NumPy integer scalars.
    if k_id == 0:
        kernel = [(3 * (x / N)**3 - 6 * (x / N)**2 + 4) / 6
                  for x in range(0, N)] + [(2 - x / N)**3 / 6
                                           for x in range(N, 2 * N)]
        dkernel = np.array([-.2, 1.4, -.2])

        k_wid = 4
    else:
        kernel = [1 - x / (2 * N) for x in range(0, 2 * N)]
        dkernel = np.array([0, 1, 0])
        deblur = False
        k_wid = 2
    kernel = np.asarray(kernel)

    c_device = sp.get_device(I)
    ndim = M_field.shape[-1]

    # 2d/3d: build the separable correction kernel and the index grid
    if ndim == 3:
        dkernel = (dkernel[:, None, None] * dkernel[None, :, None] *
                   dkernel[None, None, :])
        Nx, Ny, Nz = I.shape
        my, mx, mz = np.meshgrid(np.arange(Ny), np.arange(Nx), np.arange(Nz))
        m = np.stack((mx, my, mz), axis=-1)
    else:
        dkernel = dkernel[:, None] * dkernel[None, :]
        Nx, Ny = I.shape
        my, mx = np.meshgrid(np.arange(Ny), np.arange(Nx))
        # A 2D grid has two coordinates only (``mz`` does not exist here).
        m = np.stack((mx, my), axis=-1)
    M_field = M_field + m
    # TODO remove out of range values

    # image warp
    g_device = device
    I = sp.to_device(input=I, device=g_device)
    I = sp.interp.interpolate(I, k_wid, kernel, M_field.astype(np.float64))
    # deconv
    if deblur is True:
        # NOTE(review): the convolution result is discarded, so this step has
        # no effect on I. Assigning it would likely also change the output
        # shape (sigpy's convolve appears to default to mode='full'), and
        # dkernel is a host array — confirm intent before fixing.
        sp.conv.convolve(I, dkernel)
    I = sp.to_device(input=I, device=c_device)

    return I
    def __init__(self,
                 ksp,
                 coord,
                 dcf,
                 mps,
                 resp,
                 B,
                 lamda=1e-6,
                 alpha=1,
                 beta=0.5,
                 max_power_iter=10,
                 max_iter=300,
                 device=sp.cpu_device,
                 margin=10,
                 coil_batch_size=None,
                 comm=None,
                 show_pbar=True,
                 **kwargs):
        """Bin data by the ``resp`` signal into ``B`` bins on ``device``.

        Bin edges are the percentiles ``linspace(margin, 100 - margin, B + 1)``
        of ``resp``. Bin ``b`` holds the TRs with
        ``edges[b] <= resp < edges[b + 1]``, so TRs outside the outer
        percentiles are dropped.

        Args:
            ksp (array): k-space; axis 0 indexes coils and axis 1 TRs.
            coord (array): coordinates; axis 0 indexes TRs.
            dcf (array): density compensation; axis 0 indexes TRs.
            mps (array): sensitivity maps; ``mps.shape[1:]`` is the image shape.
            resp (array): per-TR binning signal (used as a boolean mask source).
            B (int): number of bins.
            lamda (float): regularization weight.
            alpha (float): stored as ``self.alpha``.
            beta (float): stored as ``self.beta``.
            max_power_iter (int): stored as ``self.max_power_iter``.
            max_iter (int): stored as ``self.max_iter``.
            device (Device or int): computing device.
            margin (float): percentile margin trimmed from both ends of ``resp``.
            coil_batch_size: accepted but not used in this method.
            comm (Communicator or None): when given, only rank 0 shows the
                progress bar.
            show_pbar (bool): whether to show a progress bar.
            **kwargs: accepted and ignored.
        """
        self.B = B
        self.C = len(mps)
        self.mps = mps
        self.device = sp.Device(device)
        # Use the normalized Device: the raw ``device`` may be an int id.
        self.xp = self.device.xp
        self.alpha = alpha
        self.beta = beta
        self.lamda = lamda
        self.max_iter = max_iter
        self.max_power_iter = max_power_iter
        self.comm = comm
        # Always define show_pbar (it was previously unset when comm is None);
        # under MPI only rank 0 reports progress.
        self.show_pbar = show_pbar and (comm is None or comm.rank == 0)

        self.img_shape = list(mps.shape[1:])

        bins = np.percentile(resp, np.linspace(0 + margin, 100 - margin,
                                               B + 1))
        self.bksp = []
        self.bcoord = []
        self.bdcf = []
        for b in range(B):
            idx = (resp >= bins[b]) & (resp < bins[b + 1])
            self.bksp.append(sp.to_device(ksp[:, idx], self.device))
            self.bcoord.append(sp.to_device(coord[idx], self.device))
            self.bdcf.append(sp.to_device(dcf[idx], self.device))

        self._normalize()
Example #16
0
def gridding_recon(ksp, coord, dcf, T=1, device=sp.cpu_device):
    """Gridding reconstruction.

    Frame ``t`` uses TRs ``[t * (num_tr // T), (t + 1) * (num_tr // T))``;
    any TRs left over after ``T`` full frames are not used. Coil images are
    combined by root-sum-of-squares.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        T (int): number of frames.
        device (Device): computing device.

    Returns:
        img (array): NumPy magnitude image of shape (T, N_D, ..., N_1).
    """
    device = sp.Device(device)
    xp = device.xp
    num_coils, num_tr, _ = ksp.shape
    tr_per_frame = num_tr // T
    img_shape = sp.estimate_shape(coord)

    with device:
        img = []
        for t in range(T):
            tr_start = t * tr_per_frame
            tr_end = (t + 1) * tr_per_frame
            coord_t = sp.to_device(coord[tr_start:tr_end], device)
            dcf_t = sp.to_device(dcf[tr_start:tr_end], device)

            # Accumulate |coil image|^2, then take the square root.
            img_t = 0
            for c in range(num_coils):
                # Lazy %-formatting: the message is only built if emitted.
                logging.info('Reconstructing time %s, coil %s', t, c)
                ksp_tc = sp.to_device(ksp[c, tr_start:tr_end, :], device)

                img_t += xp.abs(
                    sp.nufft_adjoint(ksp_tc * dcf_t, coord_t, img_shape))**2

            img_t = img_t**0.5
            img.append(sp.to_device(img_t))

    img = np.stack(img)
    return img
Example #17
0
File: app.py Project: jtamir/sigpy
    def _get_params(self):
        """Resolve device/dtype and derive batch, filter and code shapes.

        Creates ``checkpoint_path`` (with parents) when it is set, caps
        ``batch_size`` at the number of samples, and sets ``num_batches``,
        ``L_shape`` and ``R_t_shape``.
        """
        self.device = sp.Device(self.device)
        self.dtype = self.y.dtype
        self.data_ndim = self.y.ndim - self.multi_channel - 1
        if self.checkpoint_path is not None:
            self.checkpoint_path = pathlib.Path(self.checkpoint_path)
            self.checkpoint_path.mkdir(parents=True, exist_ok=True)

        num_samples = len(self.y)
        self.batch_size = min(num_samples, self.batch_size)
        self.num_batches = num_samples // self.batch_size

        filter_dims = [self.filt_width] * self.data_ndim
        self.L_shape = [self.num_filters] + filter_dims
        if self.multi_channel:
            # Prepend the channel count taken from axis 1 of y.
            self.L_shape = [self.y.shape[1]] + self.L_shape

        spatial = self.y.shape[-self.data_ndim:]
        if self.mode == 'full':
            sizes = [n - self.filt_width + 1 for n in spatial]
        else:
            sizes = [n + self.filt_width - 1 for n in spatial]
        self.R_t_shape = [self.batch_size, self.num_filters] + sizes
Example #18
0
def init_bloch_vector(shape=None, dtype=complex, device=sp.cpu_device):
    """Initialize magnetization in Bloch vector representation.

    Args:
        shape (tuple): batch shape; ``None`` means no batch axes.
        dtype (Dtype): data type. Defaults to the builtin ``complex``; the
            former default ``np.complex`` was only an alias of it and was
            removed in NumPy 1.24, which broke this function at import.
        device (Device): device.

    Returns:
        array: zeros of shape ``list(shape) + [3]`` with component index 2
        set to 1.

    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        if shape is None:
            shape = []

        m = xp.zeros(list(shape) + [3], dtype=dtype)
        m[..., 2] = 1
        return m
Example #19
0
def init_density_matrix(shape=None, dtype=complex, device=sp.cpu_device):
    """Initialize magnetization in density matrix representation.

    Args:
        shape (tuple): batch shape; ``None`` means no batch axes.
        dtype (Dtype): data type. Defaults to the builtin ``complex``; the
            former default ``np.complex`` was only an alias of it and was
            removed in NumPy 1.24, which broke this function at import.
        device (Device): device.

    Returns:
        array: zeros of shape ``list(shape) + [2, 2]`` with entry ``[0, 0]``
        of every 2x2 matrix set to 1.

    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        if shape is None:
            shape = []

        p = xp.zeros(list(shape) + [2, 2], dtype=dtype)
        p[..., 0, 0] = 1
        return p
Example #20
0
def NFTs(ishape, coord, device=sp.Device(-1)):
    """Build a block-diagonal multi-channel NUFFT operator.

    A single ``sp.linop.NUFFT`` over ``ishape[1:]`` is wrapped with ``DLD``
    (``ToDevice`` operators around it for ``device``). That wrapper is
    repeated once per channel (``ishape[0]``) inside ``Diags``.

    Args:
        ishape: input shape ``[n_channels, *image_shape]``.
        coord: k-space coordinates; the last axis holds spatial dimensions.
        device: device passed to ``DLD`` for every channel block.

    Returns:
        A ``Diags`` operator with output shape
        ``[n_channels] + list(coord.shape[:-1])`` and input shape ``ishape``.
    """
    num_channels = ishape[0]
    oshape = [num_channels] + list(coord.shape[:-1])

    single_nufft = sp.linop.NUFFT(ishape[1:], coord=coord)
    channel_ops = [DLD(single_nufft, device=device)
                   for _ in range(num_channels)]
    return Diags(channel_ops, oshape, ishape)
Example #21
0
def DLD(Linop, device=sp.Device(-1)):
    """Sandwich ``Linop`` between ``ToDevice`` operators for ``device``.

    Returns ``B2.H * Linop * B1``. ``B1`` is a ``ToDevice`` on
    ``Linop.ishape`` and ``B2`` is a ``ToDevice`` on ``Linop.oshape``; both
    go from ``sp.Device(-1)`` to ``device``.
    """
    host = sp.Device(-1)
    into_device = sp.linop.ToDevice(Linop.ishape, idevice=host,
                                    odevice=device)
    out_of_device = sp.linop.ToDevice(Linop.oshape, idevice=host,
                                      odevice=device)
    return out_of_device.H * Linop * into_device
Example #22
0
    def __init__(self, rdr, tbl, mps, psf, phi, spr='W', cft=True, \
                    lmb=1e-5, mit=30, alp=0.25, tol=1e-3, dev=-1):
        """Set up an L1-regularized accelerated gradient reconstruction.

        All arrays are moved to device ``dev`` and handled as 8-D arrays.
        ``mps``, ``psf`` and ``phi`` pass through ``_broadcast_check`` first
        (presumably padding them to ``self.max_dims`` axes — confirm). The
        gradient is ``AHA(x) - AHb`` with
        ``AHA = S E^H R^H Fx^H Psf^H Fyz^H K Fyz Psf Fx R E S^H``, and the
        proximal step is ``sp.prox.L1Reg`` with weight ``lmb``.

        Args:
            rdr: index array, cast to int32; ``rdr.shape[0]`` is the
                denominator of ``net_acceleration``.
            tbl: table, normalized by its maximum absolute value.
            mps: maps; axes 3, 4, 5, 6, 7 give ``md``, ``nc``, ``sz``, ``sy``,
                ``sx``. Only ``md == 1`` is supported (asserted).
            psf: point-spread function; axis 7 gives ``wx``.
            phi: axes 1 and 2 give ``tk`` and ``tf``.
            spr (str): ``'W'`` for a wavelet over spatial axes 5-7 with size
                > 1, ``'T'`` for finite differences over axes 5-7, anything
                else for the identity.
            cft (bool): use centered FFTs.
            lmb (float): L1 regularization weight.
            mit (int): maximum number of iterations.
            alp (float): step size, divided by the maximum eigenvalue of AHA.
            tol (float): convergence tolerance.
            dev (int): device id passed to ``sp.Device`` and ``MaxEig``.
        """
        self.cpu = -1
        self.max_dims = 8
        self.device = sp.Device(dev)
        self.xp = self.device.xp
        self.center = cft

        with self.device:
            self.rdr = self.xp.array(rdr).astype(self.xp.int32)
            self.tbl = self.xp.array(tbl)
            self.tbl = self.tbl / self.xp.max(self.xp.abs(self.tbl))
            self.mps = self.xp.array(self._broadcast_check(mps))
            self.psf = self.xp.array(self._broadcast_check(psf))
            self.phi = self.xp.array(self._broadcast_check(phi))
            self.lmb = lmb  # Lambda.
            self.mit = mit  # Max-Iter
            self.alp = alp  # Step size.
            self.tol = tol  # Tolerance.

            # Problem sizes read from fixed axes of the 8-D arrays.
            self.wx = self.psf.shape[7]
            self.sx = self.mps.shape[7]
            self.sy = self.mps.shape[6]
            self.sz = self.mps.shape[5]
            self.nc = self.mps.shape[4]
            self.md = self.mps.shape[3]
            self.tf = self.phi.shape[2]
            self.tk = self.phi.shape[1]

            self.net_acceleration = (self.tf * self.sy *
                                     self.sz) / self.rdr.shape[0]

            assert (self.md == 1
                    )  # Until multiple ESPIRiT maps is implemented.

            # Sparsifying transform S, selected by ``spr``.
            self.S = None
            if (spr == 'W'):
                wavelet_axes = tuple(
                    [k for k in range(5, 8) if self.mps.shape[k] > 1])
                self.S = sp.linop.Wavelet(
                    [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx],
                    axes=wavelet_axes)
            elif (spr == 'T'):
                self.S = sp.linop.FiniteDifference(
                    [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx],
                    axes=(5, 6, 7))
            else:
                self.S = sp.linop.Identity(
                    [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx])

            # Forward-model factors: map multiply, readout resize (sx -> wx),
            # FFT along axis 7, PSF multiply, FFT along axes 5-6.
            self.E = sp.linop.Multiply(
                [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx],
                self.mps)
            self.R      = sp.linop.Resize([1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], \
                                          [1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.sx])
            self.Fx     = sp.linop.FFT([1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], axes=(7,), \
                            center=self.center)
            self.Psf = sp.linop.Multiply(
                [1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx],
                self.psf)
            self.Fyz    = sp.linop.FFT([1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], axes=(5, 6), \
                            center=self.center)

            # K: multiply by ``self.kernel`` (presumably built by
            # _construct_kernel; output is tk x tk along axes 0-1), sum over
            # axis 1, then reshape back to the 8-D layout.
            self.K = None
            self._construct_kernel()
            self.K    = sp.linop.Reshape( [      1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx],              \
                                          [         self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx])            * \
                        sp.linop.Sum(     [self.tk, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], axes=(1,)) * \
                        sp.linop.Multiply([      1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], self.kernel)

            # Presumably sets self.AHb; the iterate starts at zeros like it.
            self._construct_AHb()
            self.res = 0 * self.AHb
            self.AHA   = self.S * self.E.H * self.R.H * self.Fx.H * self.Psf.H * self.Fyz.H * self.K * \
                         self.Fyz * self.Psf * self.Fx * self.R * self.E * self.S.H
            # Normalize the step size by the largest eigenvalue of AHA.
            self.mxevl = sp.app.MaxEig(self.AHA, self.xp.complex64,
                                       device=dev).run()
            self.alp = self.alp / self.mxevl
            self.gradf = lambda x: self.AHA(x) - self.AHb
            self.proxg = sp.prox.L1Reg(self.res.shape, lmb)

            alg = sp.alg.GradientMethod(self.gradf,
                                        self.res,
                                        self.alp,
                                        proxg=self.proxg,
                                        accelerate=True,
                                        tol=self.tol,
                                        max_iter=self.mit)
            super().__init__(alg)
Example #23
0
 def use_device(self, device):
     """Point ``self.device`` at ``device``, normalized through ``sp.Device``."""
     target = sp.Device(device)
     self.device = target
Example #24
0
def nufft_adj1(data,
               traj,
               dcf,
               device=sp.Device(-1),
               smap=None,
               batch=40000,
               id_channel=False,
               ishape=None):
    """Adjoint NUFFT of a stack of phases, evaluated in batches of TRs.

    Args:
        data: per-phase k-space; ``data[p]`` has shape
            ``(num_coils, num_tr, num_ro)``.
        traj: per-phase coordinates matching ``data``.
        dcf: per-phase density compensation, multiplied onto the samples
            before the adjoint.
        device: device on which the adjoint NUFFT is evaluated.
        smap: optional coil maps. When given, the image shape is
            ``smap.shape[-3:]`` and coils are combined as
            ``sum_c img_c * conj(smap[c])``.
        batch (int): number of TRs transformed per call.
        id_channel (bool): if True, keep one image per coil.
        ishape: image shape used when ``smap`` is None; if both are None
            the shape is estimated from the coordinates.

    Returns:
        numpy.ndarray: stacked per-phase results. With ``id_channel`` each
        entry is complex64 of shape ``(num_coils, *img_shape)``. Otherwise,
        without ``smap``, it is the sum of squared coil magnitudes (no
        square root is taken).
    """
    xp = device.xp
    img = []
    for num_ph in range(data.shape[0]):
        ksp_t = data[num_ph]
        coord_t = traj[num_ph]
        dcf_t = dcf[num_ph]

        if smap is not None:
            img_shape = list(smap.shape[-3:])
        elif ishape is not None:
            img_shape = ishape
        else:
            img_shape = sp.estimate_shape(coord_t)

        num_coils, num_tr, _ = ksp_t.shape
        if id_channel is True:
            img_t = np.zeros((num_coils, ) + tuple(img_shape),
                             dtype=np.complex64)
        else:
            img_t = 0

        with device:
            for c in range(num_coils):
                img_tt = 0
                for start in range(0, num_tr, batch):
                    # Same TR window for samples, coordinates and weights.
                    rows = slice(start, min(start + batch, num_tr))
                    ksp_ttc = sp.to_device(ksp_t[c, rows, ...], device)
                    coord_tt = sp.to_device(coord_t[rows, ...], device)
                    dcf_tt = sp.to_device(dcf_t[rows, ...], device)
                    img_tt += sp.nufft_adjoint(ksp_ttc * dcf_tt, coord_tt,
                                               img_shape)
                # TODO smap
                if id_channel is True:
                    img_t[c, ...] = sp.to_device(img_tt)
                elif smap is None:
                    img_t += xp.abs(img_tt)**2
                else:
                    img_t += sp.to_device(
                        img_tt *
                        xp.conj(sp.to_device(smap[c, ...], device)))

            img_t = sp.to_device(img_t)
        img.append(img_t)

    return np.asarray(img)
Example #25
0
def circulant_precond(mps,
                      weights=None,
                      coord=None,
                      lamda=0,
                      device=sp.cpu_device):
    r"""Compute a circulant preconditioner for the SENSE normal operator.

    Approximately solves

    .. math::
        \min_P \| A^H A - F P F^H  \|_2^2

    where A is the Sense operator and F is a unitary Fourier transform
    operator. Coil contributions are accumulated one coil at a time on
    ``device``.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim]. If None,
            the point-spread function is formed directly on a grid twice
            the image size along every axis.
        lamda (float): regularization added before inversion.
        device (Device or int): device the computation runs on; ``coord``
            and ``weights`` are moved there first.

    Returns:
        array: circulant preconditioner of image shape, cast to mps.dtype.

    """
    # Move the optional inputs before ``device`` is normalized below.
    if coord is not None:
        coord = sp.to_device(coord, device)
    if weights is not None:
        weights = sp.to_device(weights, device)

    out_dtype = mps.dtype
    device = sp.Device(device)
    xp = device.xp

    img_shape = list(mps.shape)[1:]
    ndim = len(img_shape)
    # Grid twice as large as the image along every axis.
    over_shape = [2 * n for n in img_shape]
    scale = sp.prod(over_shape)**1.5 / sp.prod(img_shape)**2
    # Every other sample along each axis: maps the 2x grid back to img_shape.
    every_other = (slice(None, None, 2), ) * ndim

    with device:
        if coord is None:
            grid = xp.zeros(over_shape, dtype=out_dtype)
            grid[every_other] = 1 if weights is None else weights**0.5
            psf = sp.ifft(grid)
        else:
            samples = xp.ones(coord.shape[:-1], dtype=out_dtype)
            if weights is not None:
                # In-place so the samples keep ``out_dtype``.
                samples *= weights**0.5
            psf = sp.nufft_adjoint(samples, coord * 2, over_shape)

        # Starts as a scalar; the first ``+=`` rebinds it to an array.
        p_inv = 0
        for coil_map in mps:
            coil_map = sp.to_device(coil_map, device)
            autocorr = sp.ifft(
                xp.abs(sp.fft(xp.conj(coil_map), over_shape))**2)
            autocorr *= psf
            term = sp.fft(autocorr)[every_other]
            # In-place scaling keeps the FFT output dtype.
            term *= scale
            if weights is not None:
                term *= weights**0.5
            p_inv += term

        p_inv += lamda
        # Avoid division by zero: exact zeros are replaced by 1.
        p_inv[p_inv == 0] = 1
        return (1 / p_inv).astype(out_dtype)
Example #26
0
              np.int(np.max(traj[..., 1]) - np.min(traj[..., 1])),
              np.int(np.max(traj[..., 2]) - np.min(traj[..., 2])))

    ## calibration
    print('Calibration...')
    ksp = np.reshape(np.transpose(data, (2, 1, 0, 3, 4, 5)),
                     (nCoil, nphase * npe, nfe))
    dcf2 = np.reshape(np.transpose(dcf**2, (2, 1, 0, 3, 4, 5)),
                      (nphase * npe, nfe))
    coord = np.reshape(np.transpose(traj, (2, 1, 0, 3, 4, 5)),
                       (nphase * npe, nfe, 3))

    mps = ext.jsens_calib(ksp,
                          coord,
                          dcf2,
                          device=sp.Device(device),
                          ishape=tshape)
    S = sp.linop.Multiply(tshape, mps)

    imgL = cfl.read_cfl(fname + '_mrL')
    imgL = np.squeeze(imgL)

    ## registration
    print('Registration...')
    M_fields = []
    iM_fields = []
    if reg_flag is 1:
        for i in range(nphase):
            M_field, iM_field = reg.ANTsReg(np.abs(imgL[n_ref]),
                                            np.abs(imgL[i]))
            M_fields.append(M_field)
Example #27
0
    traj = traj[...,:nf_e,:]
    data = data[...,:nf_e,:]
    dcf = dcf[...,:nf_e,:]

    nphase,nEcalib,nCoil,npe,nfe,_ = data.shape
    tshape = (np.int(np.max(traj[...,0])-np.min(traj[...,0]))
              ,np.int(np.max(traj[...,1])-np.min(traj[...,1]))
              ,np.int(np.max(traj[...,2])-np.min(traj[...,2])))

    ## calibration
    print('Calibration...')
    ksp = np.reshape(np.transpose(data,(2,1,0,3,4,5)),(nCoil,nphase*npe,nfe))
    dcf2 = np.reshape(np.transpose(dcf**2,(2,1,0,3,4,5)),(nphase*npe,nfe))
    coord = np.reshape(np.transpose(traj,(2,1,0,3,4,5)),(nphase*npe,nfe,3))

    mps = ext.jsens_calib(ksp,coord,dcf2,device = sp.Device(device),ishape = tshape)
    S = sp.linop.Multiply(tshape, mps)

    imgL = cfl.read_cfl(fname+'_mrL')
    imgL = np.squeeze(imgL)

    ## registration
    print('Registration...')
    M_fields = []
    iM_fields = []
    if reg_flag is 1:
        for i in range(nphase):
            M_field, iM_field = reg.ANTsReg(np.abs(imgL[n_ref]), np.abs(imgL[i]))
            M_fields.append(M_field)
            iM_fields.append(iM_field)
        M_fields = np.asarray(M_fields)
 def use_device(self, device):
     """Place every array in ``self.L`` and ``self.R`` on ``device``.

     Args:
         device: any value accepted by ``sp.Device``; the resulting
             ``sp.Device`` is stored on ``self.device``.
     """
     target = sp.Device(device)
     self.device = target

     def _move(arrays):
         # New list holding each array placed on ``target``.
         return [sp.to_device(arr, target) for arr in arrays]

     self.L = _move(self.L)
     self.R = _move(self.R)
Example #29
0
import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import matplotlib.pyplot as plt
import numpy as np

# Set parameters and load dataset.
# Reconstruction settings; presumably consumed later in the script (not
# visible in this excerpt).
max_iter = 30
max_cg_iter = 5
lamda = 0.001

ksp_file = 'data/liver/ksp.npy'
coord_file = 'data/liver/coord.npy'

# sigpy device id -1 is the CPU; ``xp`` is the array module matching the
# device, so the arrays loaded below live on that device.
device = sp.Device(-1)

xp = device.xp
device.use()

# Load datasets.
ksp = xp.load(ksp_file)
coord = xp.load(coord_file)

print(f'K-space shape: {ksp.shape}')
print(f'K-space dtype: {ksp.dtype}')
print(f'K-space (min, max): ({np.abs(ksp).min()}, {np.abs(ksp).max()})')
print(f'Coord shape: {coord.shape}')  # (na, ns, 2)
print(f'Coord dtype: {coord.dtype}')
print(f'Coord (min, max): ({coord.min()}, {coord.max()})')

plt.ion()
Example #30
0
def Demons(If,
           Im,
           level,
           device=-1,
           rho=0.7,
           sigmas_f=(2, 2, 2, 3),
           sigmas_e=(2, 2, 2, 2),
           sigmas_s=(.5, .5, 1, 1),
           iters=(40, 40, 40, 20, 20)):
    """Estimate a 3-component displacement field between two images with a
    coarse-to-fine, Demons-style iteration.

    Both images are converted to magnitudes and divided by ``max(|Im|)``.
    Level ``k`` (``k = 0 .. level-1``) downsamples both images by
    ``2**(level - k - 1)`` (so the last level is full resolution), smooths
    them, and refines the field for ``iters[k]`` iterations. Each iteration
    resamples ``Im`` with ``+Mt`` and ``If`` with ``-Mt`` via ``interp``,
    forms an update from the intensity difference and the gradient of the
    mean resampled image, clips and smooths it, and blends it with the
    previous iteration's update using ``rho``.

    Args:
        If: first image (magnitude is used).
        Im: second image; ``max(|Im|)`` sets the intensity normalization.
        level (int): number of resolution levels. Must not exceed the length
            of any per-level sequence below (otherwise IndexError).
        device: device id or sigpy Device, forwarded to ``interp``.
        rho (float): weight of the new update; ``1 - rho`` weights the
            previous one.
        sigmas_f: per-level Gaussian sigma applied to each update component.
        sigmas_e: per-level Gaussian sigma applied to each field component.
        sigmas_s: per-level Gaussian sigma for pre-smoothing the images.
        iters: per-level iteration counts.

    Returns:
        ndarray: the field rescaled to ``Im.shape`` by ``M_scale`` and
        multiplied by 2 (see TODO at the end); presumably of shape
        ``Im.shape + (3,)``.
    """
    # Built once; the original rebuilt it twice per inner iteration.
    sp_device = sp.Device(device)

    # Intensity normalization shared by both images.
    Im = np.abs(Im)
    m_scale = np.max(Im)
    Im = Im / m_scale
    If = np.abs(If)
    If = If / m_scale

    # Running field; last axis holds the three displacement components.
    Mt = np.zeros(Im.shape + (3, ))
    for k in range(level):
        print('Demons Level:{}'.format(k))
        # Per-level hyperparameters.
        scale = 2**(level - k - 1)
        sigma_f = sigmas_f[k]
        sigma_e = sigmas_e[k]
        sigma_s = sigmas_s[k]
        iter_each_level = iters[k]

        # Downsample and pre-smooth both images for this level.
        Ift = ndimage.zoom(If, zoom=1 / scale, order=2)
        Ift = ndimage.gaussian_filter(Ift, sigma=sigma_s, truncate=2.0)
        Imt = ndimage.zoom(Im, zoom=1 / scale, order=2)
        Imt = ndimage.gaussian_filter(Imt, sigma=sigma_s, truncate=2.0)
        Imask = pmask(Imt + Ift, 1e-2)

        # Resample the field onto this level's grid; the previous-update
        # buffer used for ``rho`` blending restarts at zero every level.
        Mt = M_scale(Mt, Ift.shape)
        uo = np.zeros_like(Mt)
        for _ in range(iter_each_level):
            Imm = interp(Imt, Mt, device=sp_device, k_id=1)
            Ifm = interp(Ift, -Mt, device=sp_device, k_id=1)
            dI = Ifm - Imm
            Is = (Ifm + Imm) / 2

            gIx, gIy, gIz = imgrad3d(Is)
            gI = np.sqrt(np.abs(gIx**2 + gIy**2 + gIz**2) + 1e-6)
            discriminator = gI**2 + np.abs(dI)**2
            # Gain applied to the numerator only (after the denominator).
            dI = dI * 3.0

            # NOTE(review): gI >= sqrt(1e-6) = 1e-3 because of the +1e-6
            # above, so ``gI < 1e-4`` never holds; only Imask masks here.
            mask = (gI < 1e-4) | (~Imask)

            # The three components are independent, so each is processed
            # start to finish in turn (same arithmetic as per-axis code).
            for axis, grad in enumerate((gIx, gIy, gIz)):
                u = -dI * grad / discriminator
                u[np.isnan(u) | mask] = 0
                u = np.maximum(np.minimum(u, 1), -1)
                u = ndimage.gaussian_filter(u, sigma=sigma_f)

                Mt[..., axis] = (Mt[..., axis] + rho * u +
                                 (1 - rho) * uo[..., axis])
                uo[..., axis] = u

                Mt[..., axis] = ndimage.gaussian_filter(Mt[..., axis],
                                                        sigma=sigma_e)

    # TODO inverse combination (right now just double)
    return M_scale(Mt * 2, Im.shape)