Ejemplo n.º 1
0
    def __init__(self,
                 maps,
                 mask,
                 l2lam=False,
                 img_shape=None,
                 use_sigpy=False,
                 noncart=False,
                 num_spatial_dims=2):
        """Store the configuration of a multi-channel MRI operator.

        Args:
            maps: coil sensitivity maps; ``maps.shape[1] == 1`` marks the
                operator as single-channel.
            mask: sampling mask.
            l2lam: l2 regularization setting, stored unchanged.
            img_shape: image shape. With ``use_sigpy`` its last entry is
                dropped (real/imaginary pair folded into complex values).
            use_sigpy (bool): convert ``maps``/``mask`` to SigPy arrays on
                the device of ``maps`` and build the SigPy model.
            noncart (bool): non-Cartesian sampling; requires ``use_sigpy``.
            num_spatial_dims (int): number of spatial dimensions.
        """
        super(MultiChannelMRI, self).__init__()
        self.maps = maps
        self.mask = mask
        self.l2lam = l2lam
        self.img_shape = img_shape
        self.noncart = noncart
        self._normal = None
        self.num_spatial_dims = num_spatial_dims

        self.single_channel = self.maps.shape[1] == 1

        # The NUFFT needed for non-Cartesian data only exists in SigPy.
        assert use_sigpy or not self.noncart, 'Must use SigPy for NUFFT!'

        if not use_sigpy:
            return

        # FIXME: Not yet Implemented for 3D
        from sigpy import from_pytorch, to_device, Device
        target = Device(self.maps.device.index)
        self.maps = to_device(from_pytorch(self.maps, iscomplex=True),
                              device=target)
        self.mask = to_device(from_pytorch(self.mask, iscomplex=False),
                              device=target)
        self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
        self._build_model_sigpy()
Ejemplo n.º 2
0
    def __init__(self,
                 y,
                 mps,
                 lamda=0,
                 weights=None,
                 tseg=None,
                 coord=None,
                 device=sp.cpu_device,
                 coil_batch_size=None,
                 comm=None,
                 show_pbar=True,
                 transp_nufft=False,
                 **kwargs):
        """Set up a (weighted) SENSE least-squares reconstruction.

        When weights are available (given or estimated), the measurements
        are multiplied by ``sqrt(weights)`` before being moved to ``device``;
        the same weights are handed to the SENSE operator. In distributed
        runs only rank 0 shows a progress bar.
        """
        weights = _estimate_weights(y, weights, coord)
        measured = y if weights is None else y * weights**0.5
        y = sp.to_device(measured, device=device)

        sense_op = linop.Sense(mps,
                               coord=coord,
                               weights=weights,
                               tseg=tseg,
                               coil_batch_size=coil_batch_size,
                               comm=comm,
                               transp_nufft=transp_nufft)

        if comm is not None:
            show_pbar = show_pbar and comm.rank == 0

        super().__init__(sense_op, y, lamda=lamda, show_pbar=show_pbar,
                         **kwargs)
Ejemplo n.º 3
0
    def __init__(self,
                 maps,
                 mask,
                 l2lam=False,
                 img_shape=None,
                 use_sigpy=False,
                 noncart=False):
        """Store the configuration of a multi-channel MRI operator.

        Args:
            maps: coil sensitivity maps.
            mask: sampling mask.
            l2lam: l2 regularization setting, stored unchanged.
            img_shape: image shape. With ``use_sigpy`` its last entry is
                dropped (real/imaginary pair folded into complex values).
            use_sigpy (bool): convert ``maps``/``mask`` to SigPy arrays on
                the device of ``maps`` and build the SigPy model.
            noncart (bool): non-Cartesian sampling; requires ``use_sigpy``.
        """
        super(MultiChannelMRI, self).__init__()
        self.maps = maps
        self.mask = mask
        self.l2lam = l2lam
        self.img_shape = img_shape
        self.noncart = noncart
        self._normal = None

        # The NUFFT needed for non-Cartesian data only exists in SigPy.
        assert use_sigpy or not self.noncart, 'Must use SigPy for NUFFT!'

        if not use_sigpy:
            return

        from sigpy import from_pytorch, to_device, Device
        target = Device(self.maps.device.index)
        self.maps = to_device(from_pytorch(self.maps, iscomplex=True),
                              device=target)
        self.mask = to_device(from_pytorch(self.mask, iscomplex=False),
                              device=target)
        self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
        self._build_model_sigpy()
Ejemplo n.º 4
0
Archivo: app.py Proyecto: jtamir/sigpy
    def __init__(self,
                 y,
                 mps,
                 lamda,
                 weights=None,
                 coord=None,
                 device=sp.cpu_device,
                 coil_batch_size=None,
                 **kwargs):
        """Set up a total-variation regularized SENSE reconstruction.

        The data term uses the (weighted) SENSE operator ``A``. The
        regularizer is an l1 penalty with weight ``lamda`` on the output of
        the image gradient ``G``, supplied both as a proximal operator and
        as the objective term ``g`` (presumably evaluated on ``G x`` by the
        parent class).
        """
        weights = _estimate_weights(y, weights, coord)
        # Pre-weight the measurements by sqrt(weights) when weights exist.
        weighted_y = y if weights is None else y * weights**0.5
        y = sp.to_device(weighted_y, device=device)

        A = linop.Sense(mps, coord=coord, weights=weights,
                        coil_batch_size=coil_batch_size)

        G = sp.linop.Gradient(A.ishape)
        proxg = sp.prox.L1Reg(G.oshape, lamda)

        def g(x):
            dev = sp.get_device(x)
            with dev:
                return lamda * dev.xp.sum(dev.xp.abs(x))

        super().__init__(A, y, proxg=proxg, g=g, G=G, **kwargs)
Ejemplo n.º 5
0
    def __init__(self, y, mps, lamda,
                 weights=None, coord=None,
                 wave_name='db4', device=sp.cpu_device,
                 coil_batch_size=None, comm=None, show_pbar=True, **kwargs):
        """Set up an l1-wavelet regularized SENSE reconstruction.

        The regularizer is ``lamda * sum(|W x|)`` with ``W`` the wavelet
        transform named by ``wave_name``. Its proximal operator is the l1
        prox applied in the wavelet domain via ``UnitaryTransform``. In
        distributed runs only rank 0 shows a progress bar.
        """
        weights = _estimate_weights(y, weights, coord)
        weighted_y = y if weights is None else y * weights**0.5
        y = sp.to_device(weighted_y, device=device)

        A = linop.Sense(mps, coord=coord, weights=weights,
                        comm=comm, coil_batch_size=coil_batch_size)
        W = sp.linop.Wavelet(mps.shape[1:], wave_name=wave_name)
        l1 = sp.prox.L1Reg(W.oshape, lamda)
        proxg = sp.prox.UnitaryTransform(l1, W)

        def g(input):
            dev = sp.get_device(input)
            with dev:
                return lamda * dev.xp.sum(dev.xp.abs(W(input)))

        if comm is not None:
            show_pbar = show_pbar and comm.rank == 0

        super().__init__(A, y, proxg=proxg, g=g, show_pbar=show_pbar, **kwargs)
    def _AHyH_L(self, t):
        """Update ``self.R[j][t]`` for every scale ``j`` from frame ``t``.

        Takes the TRs ``[t * tr_per_frame, (t + 1) * tr_per_frame)``. It
        computes the density-compensated adjoint NUFFT of each coil, combines
        the coils with the conjugate sensitivity maps (summed across
        ``self.comm`` when set), and then, for each ``j``, applies ``B[j].H``
        and projects onto ``conj(L[j])`` by summing over the last ``self.D``
        axes (dimensions kept).

        Args:
            t (int): frame index.
        """
        # Download k-space arrays.
        tr_start = t * self.tr_per_frame
        tr_end = (t + 1) * self.tr_per_frame
        coord_t = sp.to_device(self.coord[tr_start:tr_end], self.device)
        dcf_t = sp.to_device(self.dcf[tr_start:tr_end], self.device)
        ksp_t = sp.to_device(self.ksp[:, tr_start:tr_end], self.device)

        # A^H(y_t): coil-combined, density-compensated adjoint NUFFT.
        AHy_t = 0
        for c in range(self.C):
            mps_c = sp.to_device(self.mps[c], self.device)
            AHy_tc = sp.nufft_adjoint(dcf_t * ksp_t[c],
                                      coord_t,
                                      oshape=self.img_shape)
            AHy_tc *= self.xp.conj(mps_c)
            AHy_t += AHy_tc

        # Combine partial sums across processes (result used in place).
        if self.comm is not None:
            self.comm.allreduce(AHy_t)

        # NumPy reductions accept an int or a tuple of ints for ``axis``;
        # passing a bare ``range`` raises TypeError, so build a tuple.
        reduce_axes = tuple(range(-self.D, 0))
        for j in range(self.J):
            AHy_tj = self.B[j].H(AHy_t)
            self.R[j][t] = self.xp.sum(AHy_tj * self.xp.conj(self.L[j]),
                                       axis=reduce_axes,
                                       keepdims=True)
Ejemplo n.º 7
0
    def __init__(self,
                 y,
                 L,
                 lamda=0.005,
                 mode='full',
                 multi_channel=False,
                 device=sp.cpu_device,
                 **kwargs):
        """Set up an l1-regularized solve for convolutional coefficients R.

        The forward model convolves the coefficients with the filters ``L``
        (``ConvolveInput`` in ``mode``). The coefficients get an l1 penalty
        of weight ``lamda``. ``y`` and ``L`` are moved to ``device`` first.
        """
        self.y = sp.to_device(y, device)
        self.L = sp.to_device(L, device)
        self.lamda = lamda
        self.mode = mode
        self.multi_channel = multi_channel
        self.device = device

        # Expected to define self.R_shape, which is used below.
        self._get_params()

        self.A_R = sp.linop.ConvolveInput(self.R_shape, self.L,
                                          mode=self.mode,
                                          input_multi_channel=True,
                                          output_multi_channel=self.multi_channel)
        l1_prox = sp.prox.L1Reg(self.R_shape, lamda)

        super().__init__(self.A_R, self.y, proxg=l1_prox, **kwargs)
Ejemplo n.º 8
0
    def __init__(self, y, mps, lamda,
                 weights=None, coord=None, device=sp.cpu_device,
                 coil_batch_size=None, comm=None, show_pbar=True, **kwargs):
        """Set up a total-variation regularized SENSE reconstruction.

        The regularizer is an l1 penalty with weight ``lamda`` on the output
        of the finite-difference operator ``G``. It is passed both as a
        proximal operator and as the objective term ``g``. In distributed
        runs only rank 0 shows a progress bar.
        """
        weights = _estimate_weights(y, weights, coord)
        weighted_y = y if weights is None else y * weights**0.5
        y = sp.to_device(weighted_y, device=device)

        A = linop.Sense(mps, coord=coord, weights=weights,
                        comm=comm, coil_batch_size=coil_batch_size)

        G = sp.linop.FiniteDifference(A.ishape)
        proxg = sp.prox.L1Reg(G.oshape, lamda)

        def g(x):
            dev = sp.get_device(x)
            with dev:
                return lamda * dev.xp.sum(dev.xp.abs(x))

        if comm is not None:
            show_pbar = show_pbar and comm.rank == 0

        super().__init__(A, y, proxg=proxg, g=g, G=G, show_pbar=show_pbar,
                         **kwargs)
Ejemplo n.º 9
0
Archivo: app.py Proyecto: jtamir/sigpy
    def __init__(self,
                 y,
                 mps,
                 eps,
                 wave_name='db4',
                 weights=None,
                 coord=None,
                 device=sp.cpu_device,
                 coil_batch_size=None,
                 **kwargs):
        """Set up an l1-wavelet constrained SENSE reconstruction.

        The proximal operator is a unit-weight l1 prox applied in the
        wavelet domain. ``eps`` is forwarded to the parent constructor,
        presumably as the data-consistency bound.
        """
        weights = _estimate_weights(y, weights, coord)
        weighted_y = y if weights is None else y * weights**0.5
        y = sp.to_device(weighted_y, device=device)

        A = linop.Sense(mps, coord=coord, weights=weights,
                        coil_batch_size=coil_batch_size)
        W = sp.linop.Wavelet(mps.shape[1:], wave_name=wave_name)
        proxg = sp.prox.UnitaryTransform(sp.prox.L1Reg(W.oshape, 1), W)

        super().__init__(A, y, proxg, eps, **kwargs)
Ejemplo n.º 10
0
    def fftshift_and_pad_to(sparse_data: Sparse4DData, pad_to_frame_dimensions) -> Sparse4DData:
        """FFT-shift the sparse frame indices and re-index them into a padded frame.

        Runs ``fftshift_pad_kernel`` on device 0. Entries equal to the old
        dtype's maximum are treated as "no count" and mapped to the new
        dtype's maximum. ``sparse_data`` is updated in place (indices and
        frame_dimensions) and returned.

        Args:
            sparse_data (Sparse4DData): data to shift; modified in place.
            pad_to_frame_dimensions: target frame dimensions.

        Returns:
            Sparse4DData: ``sparse_data`` itself.
        """
        indices = sparse_data.indices
        scan_dimensions = sparse_data.scan_dimensions
        frame_dimensions = sparse_data.frame_dimensions
        center_frame = frame_dimensions / 2

        threadsperblock = (16, 16)
        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        blockspergrid = tuple(np.ceil(np.array(indices.shape[:2]) / threadsperblock).astype(int))

        no_count_indicator_old = np.iinfo(indices.dtype).max
        center_frame = sp.to_device(center_frame, 0)
        xp = sp.backend.get_array_module(center_frame)

        # Choose the narrowest integer type that can hold every index of the
        # padded frame (presumably 0 .. num_pixels - 1) while keeping its
        # maximum value free for the "no count" sentinel. (The previous
        # ladder repeated the 2**15 test, so int32 was unreachable, and let
        # the largest index collide with the sentinel at the boundaries.)
        num_pixels = np.prod(pad_to_frame_dimensions)
        if num_pixels >= 2**31:
            dtype = xp.int64
        elif num_pixels >= 2**15:
            dtype = xp.int32
        elif num_pixels >= 2**8:
            dtype = xp.int16
        else:
            dtype = xp.uint8

        no_count_indicator_new = xp.iinfo(dtype).max

        inds = sp.to_device(indices, 0).astype(dtype)
        pad_to_frame_dimensions = sp.to_device(pad_to_frame_dimensions, 0)

        scan_dimensions = sp.to_device(scan_dimensions, 0)

        fftshift_pad_kernel[blockspergrid, threadsperblock](inds, center_frame, scan_dimensions, pad_to_frame_dimensions,
                                                            no_count_indicator_old, no_count_indicator_new)
        sparse_data.indices = inds.get()
        sparse_data.frame_dimensions = pad_to_frame_dimensions.get()
        return sparse_data
Ejemplo n.º 11
0
    def _get_data(self):
        """Restrict the data to the calibration region and move it to the device.

        Cartesian data (``self.coord is None``): centre-crop ``y`` (and
        ``weights``) to ``ksp_calib_width`` along every image axis.
        Non-Cartesian data: keep only samples whose coordinates are all
        strictly below ``ksp_calib_width / 2`` in magnitude. ``img_shape`` is
        recorded before cropping. ``y`` is then divided by its peak magnitude
        (and multiplied by ``sqrt(weights)`` when weights exist). Everything
        is moved to ``self.device``, and finally the weights are passed
        through ``_estimate_weights``.
        """
        if self.coord is None:
            self.img_shape = list(self.y.shape[1:])
            calib_dims = len(self.img_shape) * [self.ksp_calib_width]

            self.y = sp.resize(self.y, [self.num_coils] + calib_dims)
            if self.weights is not None:
                self.weights = sp.resize(self.weights, list(calib_dims))
        else:
            self.img_shape = sp.estimate_shape(self.coord)
            half_width = self.ksp_calib_width / 2
            inside = np.amax(np.abs(self.coord), axis=-1) < half_width

            self.coord = self.coord[inside]
            self.y = self.y[:, inside]
            if self.weights is not None:
                self.weights = self.weights[inside]

        peak = np.abs(self.y).max()
        if self.weights is None:
            normalized = self.y / peak
        else:
            normalized = self.weights**0.5 * self.y / peak
        self.y = sp.to_device(normalized, self.device)

        if self.coord is not None:
            self.coord = sp.to_device(self.coord, self.device)
        if self.weights is not None:
            self.weights = sp.to_device(self.weights, self.device)

        self.weights = _estimate_weights(self.y, self.weights, self.coord)
Ejemplo n.º 12
0
    def __init__(self,
                 y,
                 mps,
                 eps,
                 weights=None,
                 coord=None,
                 device=sp.cpu_device,
                 coil_batch_size=None,
                 comm=None,
                 show_pbar=True,
                 **kwargs):
        """Set up a total-variation constrained SENSE reconstruction.

        Uses a unit-weight l1 prox on the finite-difference output ``G``.
        ``eps`` is forwarded to the parent constructor, presumably as the
        data-consistency bound. Only rank 0 shows progress when ``comm`` is
        given.
        """
        weights = _estimate_weights(y, weights, coord)
        weighted_y = y if weights is None else y * weights**0.5
        y = sp.to_device(weighted_y, device=device)

        A = linop.Sense(mps, coord=coord, weights=weights, comm=comm,
                        coil_batch_size=coil_batch_size)
        G = sp.linop.FiniteDifference(A.ishape)
        proxg = sp.prox.L1Reg(G.oshape, 1)

        if comm is not None:
            show_pbar = show_pbar and comm.rank == 0

        super().__init__(A, y, proxg, eps, G=G, show_pbar=show_pbar, **kwargs)
Ejemplo n.º 13
0
Archivo: plot.py Proyecto: jtamir/sigpy
    def update_data(self):
        """Redraw the scatter plot for the current slice of ``self.data``.

        Axis ``self.z`` is kept whole (stepped by ``self.flips[z]``), and
        every other axis is fixed at ``self.slices[i]``. The slice is
        converted according to ``self.mode`` ('m' magnitude, 'p' phase,
        'r' real, 'i' imaginary, 'l' log-magnitude). It is then used either
        to create the scatter artist or to update the existing one.
        """
        idx = []
        for i in range(self.ndim):
            if i == self.z:
                idx.append(slice(None, None, self.flips[i]))
            else:
                idx.append(self.slices[i])

        idx = tuple(idx)
        if idx:
            datav = sp.to_device(self.data[idx])
        else:
            datav = sp.to_device(self.data)

        coordv = sp.to_device(self.coord)

        if self.mode == 'm':
            datav = np.abs(datav)
        elif self.mode == 'p':
            datav = np.angle(datav)
        elif self.mode == 'r':
            datav = np.real(datav)
        elif self.mode == 'i':
            datav = np.imag(datav)
        elif self.mode == 'l':
            eps = 1e-31
            datav = np.log(np.abs(datav) + eps)

        datav = datav.ravel()
        if self.vmin is None:
            if datav.min() == datav.max():
                self.vmin = 0
            else:
                self.vmin = datav.min()

        if self.vmax is None:
            self.vmax = datav.max()

        # Point positions come from the first two coordinate components.
        xs = coordv[..., 0].ravel()
        ys = coordv[..., 1].ravel()

        if self.axsc is None:
            self.axsc = self.ax.scatter(
                xs,
                ys,
                c=datav,
                s=1,
                linewidths=0,
                cmap='gray',
                vmin=self.vmin,
                vmax=self.vmax,
            )

        else:
            # set_offsets expects an (N, 2) array of (x, y) rows, built the
            # same way as the points passed to scatter() above.
            # (coordv.T.reshape([-1, 2]) interleaved components of different
            # samples.)
            self.axsc.set_offsets(np.column_stack([xs, ys]))
            # Scalar values drive the colormap through set_array; set_color
            # expects colour specifications, not data values.
            self.axsc.set_array(datav)
Ejemplo n.º 14
0
def nufft1(img,
           traj,
           device=sp.Device(-1),
           smap=None,
           dcf=None,
           batch=40000,
           id_channel=False):
    """Forward NUFFT of a multi-phase image, batched over TRs.

    Args:
        img (array): per-phase images indexed by phase along axis 0. When
            ``id_channel`` is true, each ``img[p]`` is indexed by coil along
            its own axis 0.
        traj (array): per-phase k-space coordinates; ``traj[p]`` has shape
            (num_tr, num_ro, ndim).
        device (Device): device on which the NUFFTs run.
        smap (array or None): coil sensitivity maps indexed by coil along
            axis 0. Ignored when ``id_channel`` is true. None means a single
            unit map.
        dcf (array or None): per-phase weights multiplied into the k-space
            samples. None skips weighting.
        batch (int): number of TRs transformed per NUFFT call.
        id_channel (bool): treat ``img`` as already coil-resolved.

    Returns:
        np.ndarray: complex64 k-space, one (num_coils, num_tr, num_ro) array
        per phase, combined with ``np.asarray``.
    """
    # A single unit "map" leaves the image unchanged; use a local name
    # instead of rebinding the smap argument.
    maps = np.array([1]) if smap is None else smap

    ksp = []
    for num_ph in range(img.shape[0]):
        img_t = img[num_ph]
        coord_t = traj[num_ph]
        num_tr, num_ro, _ = coord_t.shape
        num_coils = img_t.shape[0] if id_channel else maps.shape[0]

        with device:
            ksp_t = np.zeros((num_coils, num_tr, num_ro), dtype=np.complex64)
            for c in range(num_coils):
                if id_channel:
                    img_tc = sp.to_device(img_t[c, ...], device)
                else:
                    img_tc = sp.to_device(img_t * maps[c, ...], device)
                # Transform at most `batch` TRs at a time to bound memory use.
                for start in range(0, num_tr, batch):
                    stop = min(start + batch, num_tr)
                    coord_tt = sp.to_device(coord_t[start:stop, ...], device)
                    ksp_tt = sp.nufft(img_tc, coord_tt)
                    if dcf is not None:
                        dcf_tt = sp.to_device(dcf[num_ph][start:stop, ...],
                                              device)
                        ksp_tt = dcf_tt * ksp_tt

                    ksp_t[c, start:stop, ...] = sp.to_device(ksp_tt)

        ksp.append(ksp_t)

    return np.asarray(ksp)
Ejemplo n.º 15
0
 def _determine_center_and_radius(data : Sparse4DData, manual=False, size=25, threshold=3e-1):
     """Estimate the probe centre and radius from a corner of the scan.

     Densifies the first ``size`` x ``size`` scan positions on device 0
     (via ``sparse_to_dense_datacube_crop`` with ``bin=2``) and sums them
     over the scan axes. The sum is smoothed with ``gaussian(..., 2)``,
     thresholded at ``threshold`` times the unsmoothed peak, and passed to
     ``get_probe_size``. Both results are multiplied by 2, presumably to
     undo the binning. ``manual`` is currently unused.

     Returns:
         tuple: (``2 * [y0, x0]`` as an array, ``2 * r``).
     """
     sh = np.concatenate([data.scan_dimensions, data.frame_dimensions])
     c = np.zeros((2,))
     c[:] = (sh[-1] // 2, sh[-2] // 2)
     radius = np.ones((1,)) * sh[-1] // 2
     inds = sp.to_device(data.indices[:size, :size].astype(np.uint32), 0)
     cts = sp.to_device(data.counts[:size, :size].astype(np.uint32), 0)
     xp = sp.backend.get_array_module(inds)
     dc_subset = sparse_to_dense_datacube_crop(inds, cts, (size, size), data.frame_dimensions, c, radius, bin=2)
     dcs = xp.sum(dc_subset, (0, 1))
     m1 = dcs.get()
     # np.float was removed in NumPy 1.24; it was an alias of float64.
     m = (gaussian(m1.astype(np.float32), 2) > m1.max() * threshold).astype(np.float64)
     r, y0, x0 = get_probe_size(m)
     return 2 * np.array([y0, x0]), r * 2
Ejemplo n.º 16
0
def autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1):
    """Automatic estimation of FOV.

    FOV is estimated by thresholding a low resolution gridded image.
    coord will be modified in-place.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        num_ro (int): number of leading read-out points used to build the
            low-resolution image.
        device (Device): computing device.
        thresh (float): threshold between 0 and 1, relative to the maximum
            of the coil-combined low-resolution image (for 3D images, the
            maximum of its central slice along axis 1).

    Raises:
        ValueError: if no voxel of the low-resolution image exceeds the
            threshold.

    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        kspc = ksp[:, :, :num_ro]
        coordc = coord[:, :num_ro, :]
        dcfc = dcf[:, :num_ro]
        # Doubling the coordinates doubles the estimated grid (2x FOV).
        coordc2 = sp.to_device(coordc * 2, device)

        num_coils = len(kspc)
        imgc_shape = np.array(sp.estimate_shape(coordc))
        imgc2_shape = sp.estimate_shape(coordc2)
        imgc2_center = [i // 2 for i in imgc2_shape]
        imgc2 = sp.nufft_adjoint(sp.to_device(dcfc * kspc, device), coordc2,
                                 [num_coils] + imgc2_shape)
        # Root-sum-of-squares coil combination.
        imgc2 = xp.sum(xp.abs(imgc2)**2, axis=0)**0.5
        if imgc2.ndim == 3:
            imgc2_cor = imgc2[:, imgc2.shape[1] // 2, :]
            thresh *= imgc2_cor.max()
        else:
            thresh *= imgc2.max()

        boxc = imgc2 > thresh

        boxc = sp.to_device(boxc)
        boxc_idx = np.nonzero(boxc)
        if boxc_idx[0].size == 0:
            raise ValueError(
                'autofov: no voxel exceeds the threshold; lower thresh.')
        # Box size per axis: twice the largest distance from the centre.
        boxc_shape = np.array([
            int(np.abs(boxc_idx[i] - imgc2_center[i]).max()) * 2
            for i in range(imgc2.ndim)
        ])

        img_scale = boxc_shape / imgc_shape
        coord *= img_scale
Ejemplo n.º 17
0
def interp(I, M_field, device=sp.Device(-1), k_id=1, deblur=True):
    """Warp image ``I`` by the displacement field ``M_field``.

    Args:
        I (array): 2D or 3D image.
        M_field (array): displacement field. Its last axis holds one entry
            per image dimension, and the pixel grid of ``I`` is added to it
            to obtain sampling positions.
        device (Device): device on which interpolation runs.
        k_id (int): 0 selects the cubic B-spline kernel (with optional
            deblurring). Any other value selects a linear kernel with
            deblurring disabled.
        deblur (bool): apply a small sharpening filter after interpolation.

    Returns:
        array: the warped image, moved back to the device ``I`` came from.
    """
    # b spline interpolation; kernel tables hold N samples per unit.
    N = 64
    if k_id == 0:
        # Cubic B-spline table plus a sharpening filter that partially
        # undoes its smoothing.
        kernel = [(3 * (x / N)**3 - 6 * (x / N)**2 + 4) / 6
                  for x in range(0, N)] + [(2 - x / N)**3 / 6
                                           for x in range(N, 2 * N)]
        dkernel = np.array([-.2, 1.4, -.2])

        k_wid = 4
    else:
        # Linearly decreasing table; deblurring is not used.
        kernel = [1 - x / (2 * N) for x in range(0, 2 * N)]
        dkernel = np.array([0, 1, 0])
        deblur = False
        k_wid = 2
    kernel = np.asarray(kernel)

    c_device = sp.get_device(I)
    ndim = M_field.shape[-1]

    # 2d/3d: build the separable filter and the absolute sampling grid.
    if ndim == 3:
        dkernel = (dkernel[:, None, None] * dkernel[None, :, None] *
                   dkernel[None, None, :])
        Nx, Ny, Nz = I.shape
        my, mx, mz = np.meshgrid(np.arange(Ny), np.arange(Nx), np.arange(Nz))
        m = np.stack((mx, my, mz), axis=-1)
    else:
        dkernel = dkernel[:, None] * dkernel[None, :]
        Nx, Ny = I.shape
        my, mx = np.meshgrid(np.arange(Ny), np.arange(Nx))
        # The 2D grid has only x and y components (mz is undefined here).
        m = np.stack((mx, my), axis=-1)
    M_field = M_field + m
    # TODO remove out of range values

    # image warp
    g_device = device
    I = sp.to_device(input=I, device=g_device)
    I = sp.interp.interpolate(I, k_wid, kernel, M_field.astype(np.float64))
    # deconv
    if deblur:
        # convolve() returns a new, larger array ('full' mode grows each
        # axis by len(filter) - 1). Centre-crop it back to the image shape
        # so the sharpened result actually replaces I; the filter is moved
        # to the same device and dtype as the image.
        dk = sp.to_device(dkernel.astype(I.dtype), g_device)
        I = sp.resize(sp.conv.convolve(I, dk), I.shape)
    I = sp.to_device(input=I, device=c_device)

    return I
    def gradf(self, mrimg):
        """Gradient of the objective with respect to the binned images.

        Data term: for every bin ``b`` and coil ``c``, the adjoint NUFFT of
        the density-weighted residual, combined with ``conj(mps[c])``. It is
        summed across ``self.comm`` when set. Temporal term: for each
        neighbouring bin, adds ``lamda * d / (|d| + eps)`` with
        ``d = mrimg[b] - mrimg[neighbour]``.
        """
        xp = self.xp
        grad = xp.zeros_like(mrimg)
        for b in range(self.B):
            coord_b = self.bcoord[b]
            dcf_b = self.bdcf[b]
            for c in range(self.C):
                mps_c = sp.to_device(self.mps[c], self.device)
                residual = sp.nufft(mrimg[b] * mps_c, coord_b) - self.bksp[b][c]
                grad[b] += sp.nufft_adjoint(dcf_b * residual,
                                            coord_b,
                                            oshape=mrimg.shape[1:]) * xp.conj(mps_c)

        if self.comm is not None:
            self.comm.allreduce(grad)

        eps = 1e-31
        for b in range(self.B):
            # Previous bin first, then next bin, skipping the ends.
            for neighbour in (b - 1, b + 1):
                if 0 <= neighbour < self.B:
                    diff = mrimg[b] - mrimg[neighbour]
                    sp.axpy(grad[b], self.lamda, diff / (xp.abs(diff) + eps))

        return grad
    def _normalize(self):
        """Rescale ``self.alpha`` and ``self.lamda`` using the first bin.

        Computes the coil-combined adjoint NUFFT of bin 0's density-weighted
        k-space. Then it computes the maximum eigenvalue of ``F^H W F`` for
        bin 0 via ``sp.app.MaxEig``. ``alpha`` is divided by that eigenvalue,
        and ``lamda`` is multiplied by the eigenvalue times the peak
        magnitude of the adjoint image.
        """
        # Use instance state: the bare names device/mps/comm/ksp are not
        # defined in this method's scope.
        device = self.device
        xp = self.xp
        # show_pbar may be unset if the constructor only stored it when a
        # communicator was given.
        show_pbar = getattr(self, 'show_pbar', True)

        # Normalize using first phase.
        with device:
            mrimg_adj = 0
            for c in range(self.C):
                mrimg_c = sp.nufft_adjoint(self.bksp[0][c] * self.bdcf[0],
                                           self.bcoord[0], self.img_shape)
                mrimg_c *= xp.conj(sp.to_device(self.mps[c], device))
                mrimg_adj += mrimg_c

            if self.comm is not None:
                self.comm.allreduce(mrimg_adj)

            # Get maximum eigenvalue.
            F = sp.linop.NUFFT(self.img_shape, self.bcoord[0])
            W = sp.linop.Multiply(F.oshape, self.bdcf[0])
            max_eig = sp.app.MaxEig(F.H * W * F,
                                    max_iter=self.max_power_iter,
                                    dtype=self.bksp[0].dtype,
                                    device=device,
                                    show_pbar=show_pbar).run()

            # Normalize
            self.alpha /= max_eig
            self.lamda *= max_eig * xp.abs(mrimg_adj).max().item()
Ejemplo n.º 20
0
    def update_image(self):
        """Render the current slice of ``self.im``.

        Axes x, y, z and c keep their full extent (stepped by ``self.flips``).
        Every other axis is fixed at ``self.slices``. The slice is
        transposed to [z, y, x, c] order, converted with ``array_to_image``
        and according to ``self.mode``, and then either drawn with
        ``imshow`` or pushed into the existing image artist.
        """
        # Extract slice.
        displayed = [self.x, self.y, self.z, self.c]
        idx = tuple(
            slice(None, None, self.flips[axis]) if axis in displayed
            else self.slices[axis]
            for axis in range(self.ndim))
        imv = sp.to_device(self.im[idx])

        # Transpose to have [z, y, x, c].
        leading = [] if self.z is None else [self.z]
        trailing = [] if self.c is None else [self.c]
        order = leading + [self.y, self.x] + trailing

        imv = np.transpose(imv, np.argsort(np.argsort(order)))
        imv = array_to_image(imv, color=self.c is not None)

        mode = self.mode
        if mode == 'm':
            imv = np.abs(imv)
        elif mode == 'p':
            imv = np.angle(imv)
        elif mode == 'r':
            imv = np.real(imv)
        elif mode == 'i':
            imv = np.imag(imv)
        elif mode == 'l':
            magnitude = np.abs(imv)
            # Zero-magnitude pixels are shown as log value -31.
            imv = np.log(magnitude, out=np.ones_like(magnitude) * -31,
                         where=magnitude != 0)

        if self.vmin is None:
            self.vmin = imv.min()
        if self.vmax is None:
            self.vmax = imv.max()

        extent = [0, imv.shape[1], 0, imv.shape[0]]
        if self.axim is not None:
            self.axim.set_data(imv)
            self.axim.set_extent(extent)
            self.axim.set_clim(self.vmin, self.vmax)
            return

        self.axim = self.ax.imshow(imv,
                                   vmin=self.vmin,
                                   vmax=self.vmax,
                                   cmap='gray',
                                   origin='lower',
                                   interpolation=self.interpolation,
                                   aspect=1.0,
                                   extent=extent)
Ejemplo n.º 21
0
    def __init__(self,
                 ksp,
                 calib_width=24,
                 thresh=0.01,
                 kernel_width=6,
                 crop=0.8,
                 max_iter=100,
                 device=sp.cpu_device,
                 output_eigenvalue=False,
                 show_pbar=True):
        """Set up ESPIRiT-style sensitivity-map estimation by power iteration.

        Builds a calibration matrix from the centre of ``ksp`` and keeps the
        right singular vectors whose singular values exceed
        ``thresh * max(S)``. It accumulates the image-domain matrix ``AHA``
        from those kernels and prepares a ``PowerMethod`` on ``AHA``, with
        ``self.mps`` as the iterate.

        Args:
            ksp (array): k-space with coils along axis 0 (``len(ksp)`` is
                the coil count and ``ksp.shape[1:]`` the image shape).
            calib_width (int): width of the centre-cropped calibration
                region along each image axis.
            thresh (float): relative singular-value cutoff.
            kernel_width (int): kernel width along each image axis.
            crop (float): stored on ``self.crop``; not used in this
                constructor.
            max_iter (int): number of power iterations.
            device (Device): device for the calibration computations.
            output_eigenvalue (bool): stored on ``self.output_eigenvalue``;
                not used in this constructor.
            show_pbar (bool): forwarded to the parent constructor.
        """
        self.device = sp.Device(device)
        self.output_eigenvalue = output_eigenvalue
        self.crop = crop

        img_ndim = ksp.ndim - 1
        num_coils = len(ksp)
        with sp.get_device(ksp):
            # Get calibration region
            calib_shape = [num_coils] + [calib_width] * img_ndim
            calib = sp.resize(ksp, calib_shape)
            calib = sp.to_device(calib, device)

        xp = self.device.xp
        with self.device:
            # Get calibration matrix: one row per stride-1 block of shape
            # kernel_shape, flattened.
            kernel_shape = [num_coils] + [kernel_width] * img_ndim
            kernel_strides = [1] * (img_ndim + 1)
            mat = sp.array_to_blocks(calib, kernel_shape, kernel_strides)
            mat = mat.reshape([-1, sp.prod(kernel_shape)])

            # Perform SVD on calibration matrix and keep rows of VH whose
            # singular value exceeds thresh times the largest one.
            _, S, VH = xp.linalg.svd(mat, full_matrices=False)
            VH = VH[S > thresh * S.max(), :]

            # Get kernels
            num_kernels = len(VH)
            kernels = VH.reshape([num_kernels] + kernel_shape)
            img_shape = ksp.shape[1:]

            # Get covariance matrix in image domain: per pixel, a
            # (num_coils, num_coils) matrix. Spatial axes are in reversed
            # order because img_kernel.T reverses all axes.
            AHA = xp.zeros(img_shape[::-1] + (num_coils, num_coils),
                           dtype=ksp.dtype)
            for kernel in kernels:
                img_kernel = sp.ifft(sp.resize(kernel, ksp.shape),
                                     axes=range(-img_ndim, 0))
                aH = xp.expand_dims(img_kernel.T, axis=-1)
                a = xp.conj(aH.swapaxes(-1, -2))
                AHA += aH @ a

            AHA *= (sp.prod(img_shape) / kernel_width**img_ndim)
            # Initial maps: ones of shape (reversed image shape, coils, 1),
            # matching the per-pixel matrices in AHA.
            self.mps = xp.ones(ksp.shape[::-1] + (1, ), dtype=ksp.dtype)

            # Power iteration per pixel; the norm is taken over the coil
            # axis (-2) only.
            alg = sp.alg.PowerMethod(
                lambda x: AHA @ x,
                self.mps,
                norm_func=lambda x: xp.sum(
                    xp.abs(x)**2, axis=-2, keepdims=True)**0.5,
                max_iter=max_iter)

        super().__init__(alg, show_pbar=show_pbar)
Ejemplo n.º 22
0
 def _broadcast_check(self, x):
     """Return ``x`` with exactly ``self.max_dims`` dimensions when possible.

     An array that already has ``self.max_dims`` dimensions is returned
     untouched. Otherwise it is moved to ``self.cpu``, and singleton axes
     are prepended until it has ``self.max_dims`` dimensions. Arrays that
     already have more dimensions are only moved.
     """
     if len(x.shape) == self.max_dims:
         return x
     x = sp.to_device(x, self.cpu)
     missing = self.max_dims - len(x.shape)
     for _ in range(missing):
         x = np.expand_dims(x, axis=0)
     return x
Ejemplo n.º 23
0
 def crop_symmetric_center_(self, center, max_radius = None):
     """Crop all frames symmetrically around ``center`` (in place).

     When ``max_radius`` is None, it defaults to the largest radius that
     fits inside the frame around ``center`` on both axes. The cropping
     itself is done by ``crop_symmetric_around_center`` on device 0.
     Afterwards ``self.indices`` and ``self.frame_dimensions`` are replaced,
     and ``self.counts`` is rebuilt as a boolean mask marking every index
     that is not the dtype's maximum (the "no count" value). Old and new
     shapes are printed.
     """
     if max_radius is None:
         y_min_radius = np.min([center[0], self.frame_dimensions[0] - center[0]])
         x_min_radius = np.min([center[1], self.frame_dimensions[1] - center[1]])
         max_radius = np.min([y_min_radius, x_min_radius])
     inds = sp.to_device(self.indices, 0)
     frame_dimensions = sp.to_device(self.frame_dimensions, 0)
     new_frames, new_frame_dimensions = crop_symmetric_around_center(inds, frame_dimensions, center, max_radius)
     print(f'old frames shape: {self.indices.shape}')
     print(f'new frames shape: {new_frames.shape}')
     print(f'old frames frame_dimensions: {self.frame_dimensions}')
     print(f'new frames frame_dimensions: {new_frame_dimensions}')
     self.indices = new_frames
     # np.bool was removed in NumPy 1.24; the builtin bool gives the same
     # dtype.
     self.counts = np.zeros(self.indices.shape, dtype=bool)
     self.counts[self.indices != np.iinfo(self.indices.dtype).max] = 1
     self.frame_dimensions = new_frame_dimensions
    def __init__(self,
                 ksp,
                 coord,
                 dcf,
                 mps,
                 resp,
                 B,
                 lamda=1e-6,
                 alpha=1,
                 beta=0.5,
                 max_power_iter=10,
                 max_iter=300,
                 device=sp.cpu_device,
                 margin=10,
                 coil_batch_size=None,
                 comm=None,
                 show_pbar=True,
                 **kwargs):
        """Bin the data by the respiratory signal and normalize step sizes.

        Args:
            ksp (array): k-space, selected per bin as ``ksp[:, idx]``.
            coord (array): k-space coordinates, selected as ``coord[idx]``.
            dcf (array): density compensation, selected as ``dcf[idx]``.
            mps (array): sensitivity maps; ``len(mps)`` is the coil count
                and ``mps.shape[1:]`` the image shape.
            resp (array): signal used for binning; ``idx`` is a boolean mask
                over it.
            B (int): number of bins. Bin edges are percentiles of ``resp``
                from ``margin`` to ``100 - margin``.
            lamda, alpha, beta: algorithm parameters (``lamda`` and
                ``alpha`` are rescaled by ``_normalize``).
            max_power_iter (int): power iterations used in ``_normalize``.
            max_iter (int): stored on ``self.max_iter``.
            device (Device): compute device.
            margin (float): percentile margin excluded at each end.
            coil_batch_size: accepted but unused here.
            comm: optional communicator; only rank 0 shows progress.
            show_pbar (bool): whether to show progress bars.
        """
        self.B = B
        self.C = len(mps)
        self.mps = mps
        self.device = sp.Device(device)
        # Take xp from the normalized Device; the raw argument may not be a
        # Device (e.g. an int).
        self.xp = self.device.xp
        self.alpha = alpha
        self.beta = beta
        self.lamda = lamda
        self.max_iter = max_iter
        self.max_power_iter = max_power_iter
        self.comm = comm
        # Always define show_pbar (_normalize reads it); in distributed runs
        # only rank 0 shows progress.
        self.show_pbar = show_pbar
        if comm is not None:
            self.show_pbar = show_pbar and comm.rank == 0

        self.img_shape = list(mps.shape[1:])

        bins = np.percentile(resp, np.linspace(0 + margin, 100 - margin,
                                               B + 1))
        self.bksp = []
        self.bcoord = []
        self.bdcf = []
        for b in range(B):
            idx = (resp >= bins[b]) & (resp < bins[b + 1])
            self.bksp.append(sp.to_device(ksp[:, idx], self.device))
            self.bcoord.append(sp.to_device(coord[idx], self.device))
            self.bdcf.append(sp.to_device(dcf[idx], self.device))

        self._normalize()
Ejemplo n.º 25
0
    def fftshift_(self):
        """FFT-shift the frame indices in place and return ``self``.

        Runs ``fftshift_kernel`` on device 0 over a grid covering the first
        two axes of ``self.indices``. The dtype's maximum value is passed as
        the "no count" marker, and the shifted indices are copied back to
        the host.
        """
        indices = self.indices
        scan_dimensions = self.scan_dimensions
        frame_dimensions = self.frame_dimensions
        center_frame = frame_dimensions / 2

        threadsperblock = (16, 16)
        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        blockspergrid = tuple(np.ceil(np.array(indices.shape[:2]) / threadsperblock).astype(int))

        no_count_indicator = np.iinfo(indices.dtype).max

        inds = sp.to_device(indices, 0)
        center_frame = sp.to_device(center_frame, 0)
        scan_dimensions = sp.to_device(scan_dimensions, 0)

        fftshift_kernel[blockspergrid, threadsperblock](inds, center_frame, scan_dimensions, no_count_indicator)
        self.indices = inds.get()
        return self
    def _get_img(self, t, idx=None):
        """Synthesize frame ``t`` and return it on the CPU.

        Sums ``B_j(L[j] * R[j][t])[idx]`` over all ``J`` components on
        ``self.device``. With the default ``idx=None`` the indexing inserts
        a leading singleton axis.
        """
        with self.device:
            frame = 0
            for j in range(self.J):
                component = self._get_B(j)(self.L[j] * self.R[j][t])
                frame += component[idx]

        return sp.to_device(frame, sp.cpu_device)
Ejemplo n.º 27
0
def gridding_recon(ksp, coord, dcf, T=1, device=sp.cpu_device):
    """Gridding reconstruction with root-sum-of-squares coil combination.

    The TRs are split into ``T`` contiguous frames of ``num_tr // T`` TRs
    each. For every frame, each coil is gridded with a density-compensated
    adjoint NUFFT, and the coil images are combined as
    ``sqrt(sum_c |img_c|**2)``.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        T (int): number of frames. When num_tr is not divisible by T, the
            trailing ``num_tr % T`` TRs are not used.
        device (Device): device on which the gridding is computed.

    Returns:
        img (array): magnitude image of shape (T, N_D, ..., N_1) on the CPU,
            where (N_D, ..., N_1) is the image shape estimated from coord.
    """
    device = sp.Device(device)
    xp = device.xp
    num_coils, num_tr, num_ro = ksp.shape
    tr_per_frame = num_tr // T
    img_shape = sp.estimate_shape(coord)

    with device:
        img = []
        for t in range(T):
            tr_start = t * tr_per_frame
            tr_end = (t + 1) * tr_per_frame
            coord_t = sp.to_device(coord[tr_start:tr_end], device)
            dcf_t = sp.to_device(dcf[tr_start:tr_end], device)

            # Accumulate the sum of squared coil magnitudes for this frame.
            img_t = 0
            for c in range(num_coils):
                logging.info('Reconstructing time %s, coil %s', t, c)
                ksp_tc = sp.to_device(ksp[c, tr_start:tr_end, :], device)

                img_t += xp.abs(
                    sp.nufft_adjoint(ksp_tc * dcf_t, coord_t, img_shape))**2

            img_t = img_t**0.5
            # sp.to_device defaults to the CPU, so np.stack below sees
            # host arrays.
            img.append(sp.to_device(img_t))

    img = np.stack(img)
    return img
Ejemplo n.º 28
0
Archivo: app.py Proyecto: jtamir/sigpy
    def __init__(self,
                 y,
                 mps,
                 lamda=0,
                 weights=None,
                 coord=None,
                 device=sp.cpu_device,
                 coil_batch_size=None,
                 **kwargs):
        """Build a SENSE operator and hand the problem to the parent class.

        Args:
            y (array): measured k-space data.
            mps (array): sensitivity maps passed to ``linop.Sense``.
            lamda (float): regularization weight forwarded to the parent.
            weights: optional weights, resolved through
                ``_estimate_weights(y, weights, coord)``. When the resolved
                value is not None, ``y`` is multiplied by ``weights**0.5``.
            coord: optional k-space coordinates, passed to
                ``_estimate_weights`` and ``linop.Sense``.
            device: device to which the (possibly weighted) data are moved.
            coil_batch_size: forwarded to ``linop.Sense``.
            **kwargs: forwarded unchanged to the parent ``__init__``.
        """
        weights = _estimate_weights(y, weights, coord)

        # Scale the measurements by sqrt(weights) only when weights exist.
        if weights is None:
            y_scaled = y
        else:
            y_scaled = y * weights**0.5
        y = sp.to_device(y_scaled, device=device)

        sense_op = linop.Sense(mps,
                               coord=coord,
                               weights=weights,
                               coil_batch_size=coil_batch_size)
        super().__init__(sense_op, y, lamda=lamda, **kwargs)
    def _normalize(self):
        """Rescale ``self.dcf`` and ``self.ksp`` in place.

        1. Divides ``self.dcf`` by the maximum eigenvalue of ``F^H W F``
           estimated by ``sp.app.MaxEig`` (``max_iter=self.max_power_iter``).
           Here F is the NUFFT on the first ``self.tr_per_frame`` TRs'
           coordinates and W multiplies by the matching slice of the dcf.
        2. Divides ``self.ksp`` by the norm of the coil-combined adjoint
           image ``sum_c conj(mps_c) * nufft_adjoint(ksp_c * dcf)`` over all
           TRs. When ``self.comm`` is set, this is summed across ranks first.
        """
        with self.device:
            # Estimate maximum eigenvalue.
            # Only the first tr_per_frame TRs are used (presumably one frame).
            coord_t = sp.to_device(self.coord[:self.tr_per_frame], self.device)
            dcf_t = sp.to_device(self.dcf[:self.tr_per_frame], self.device)
            F = sp.linop.NUFFT(self.img_shape, coord_t)
            W = sp.linop.Multiply(F.oshape, dcf_t)

            max_eig = sp.app.MaxEig(F.H * W * F,
                                    max_iter=self.max_power_iter,
                                    dtype=self.dtype,
                                    device=self.device,
                                    show_pbar=self.show_pbar).run()
            # In-place: mutates the array object held in self.dcf. This must
            # happen before the copy to `dcf` below, which relies on it.
            self.dcf /= max_eig

            # Estimate scaling.
            img_adj = 0
            dcf = sp.to_device(self.dcf, self.device)
            coord = sp.to_device(self.coord, self.device)
            # Coil by coil: density-compensated adjoint NUFFT, multiplied by
            # the conjugate sensitivity map and accumulated.
            for c in range(self.C):
                mps_c = sp.to_device(self.mps[c], self.device)
                ksp_c = sp.to_device(self.ksp[c], self.device)
                img_adj_c = sp.nufft_adjoint(ksp_c * dcf, coord,
                                             self.img_shape)
                img_adj_c *= self.xp.conj(mps_c)
                img_adj += img_adj_c

            # The return value is discarded, so this relies on allreduce
            # summing img_adj in place across ranks — TODO confirm.
            if self.comm is not None:
                self.comm.allreduce(img_adj)

            img_adj_norm = self.xp.linalg.norm(img_adj).item()
            # NOTE(review): there is no guard against a zero norm here.
            self.ksp /= img_adj_norm
    def run(self):
        """Run the reconstruction, shrinking the step size on divergence.

        Calls ``_init_vars`` and ``_power_method``, then keeps copies of the
        initial factors ``L`` and ``R`` made with ``sp.to_device``'s default
        (CPU) device. Each attempt restores the factors from those copies on
        ``self.device`` and calls ``_sgd``. If ``_sgd`` raises
        ``OverflowError``, ``self.alpha`` is multiplied by ``self.beta`` and
        the attempt is repeated.

        Returns:
            MultiScaleLowRankImage: built from CPU copies of ``L`` and ``R``
            with shape ``(T, *img_shape)``. It is returned when there is no
            communicator or on rank 0. Other ranks return ``None``.
        """
        with self.device:
            self._init_vars()
            self._power_method()
            # Snapshot the initial factors so every retry restarts from the
            # same point.
            self.L_init = []
            self.R_init = []
            for j in range(self.J):
                self.L_init.append(sp.to_device(self.L[j]))
                self.R_init.append(sp.to_device(self.R[j]))

            done = False
            while not done:
                try:
                    self.L = []
                    self.R = []
                    for j in range(self.J):
                        self.L.append(sp.to_device(self.L_init[j],
                                                   self.device))
                        self.R.append(sp.to_device(self.R_init[j],
                                                   self.device))

                    self._sgd()
                    done = True
                except OverflowError:
                    # Presumably raised by _sgd on divergence; retry with a
                    # smaller step size.
                    self.alpha *= self.beta
                    if self.show_pbar:
                        tqdm.write('\nReconstruction diverged. '
                                   'Scaling step-size by {}.'.format(
                                       self.beta))

            if self.comm is None or self.comm.rank == 0:
                # img_shape may be stored as a list (e.g.
                # ``list(mps.shape[1:])``), and tuple + list raises TypeError,
                # so normalize it to a tuple first.
                return MultiScaleLowRankImage(
                    (self.T, ) + tuple(self.img_shape),
                    [sp.to_device(L_j, sp.cpu_device) for L_j in self.L],
                    [sp.to_device(R_j, sp.cpu_device) for R_j in self.R])