    def _AHyH_L(self, t):
        # Download k-space arrays.
        tr_start = t * self.tr_per_frame
        tr_end = (t + 1) * self.tr_per_frame
        coord_t = sp.to_device(self.coord[tr_start:tr_end], self.device)
        dcf_t = sp.to_device(self.dcf[tr_start:tr_end], self.device)
        ksp_t = sp.to_device(self.ksp[:, tr_start:tr_end], self.device)

        # A^H(y_t)
        AHy_t = 0
        for c in range(self.C):
            mps_c = sp.to_device(self.mps[c], self.device)
            AHy_tc = sp.nufft_adjoint(dcf_t * ksp_t[c],
                                      coord_t,
                                      oshape=self.img_shape)
            AHy_tc *= self.xp.conj(mps_c)
            AHy_t += AHy_tc

        if self.comm is not None:
            self.comm.allreduce(AHy_t)

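        # Project A^H(y_t) onto each spatial basis block and take its inner
        # product with L_j to form the temporal coefficients R_j[t].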
        for j in range(self.J):
            AHy_tj = self.B[j].H(AHy_t)
            self.R[j][t] = self.xp.sum(AHy_tj * self.xp.conj(self.L[j]),
                                       axis=range(-self.D, 0),
                                       keepdims=True)
    def gradf(self, mrimg):
        out = self.xp.zeros_like(mrimg)
        for b in range(self.B):
            for c in range(self.C):
                mps_c = sp.to_device(self.mps[c], self.device)
                out[b] += sp.nufft_adjoint(
                    self.bdcf[b] *
                    (sp.nufft(mrimg[b] * mps_c, self.bcoord[b]) -
                     self.bksp[b][c]),
                    self.bcoord[b],
                    oshape=mrimg.shape[1:]) * self.xp.conj(mps_c)

        if self.comm is not None:
            self.comm.allreduce(out)

        eps = 1e-31
        for b in range(self.B):
            if b > 0:
                diff = mrimg[b] - mrimg[b - 1]
                sp.axpy(out[b], self.lamda, diff / (self.xp.abs(diff) + eps))

            if b < self.B - 1:
                diff = mrimg[b] - mrimg[b + 1]
                sp.axpy(out[b], self.lamda, diff / (self.xp.abs(diff) + eps))

        return out
    def _normalize(self):
        with self.device:
            # Estimate maximum eigenvalue.
            coord_t = sp.to_device(self.coord[:self.tr_per_frame], self.device)
            dcf_t = sp.to_device(self.dcf[:self.tr_per_frame], self.device)
            F = sp.linop.NUFFT(self.img_shape, coord_t)
            W = sp.linop.Multiply(F.oshape, dcf_t)

            max_eig = sp.app.MaxEig(F.H * W * F,
                                    max_iter=self.max_power_iter,
                                    dtype=self.dtype,
                                    device=self.device,
                                    show_pbar=self.show_pbar).run()
            self.dcf /= max_eig

            # Estimate scaling.
            img_adj = 0
            dcf = sp.to_device(self.dcf, self.device)
            coord = sp.to_device(self.coord, self.device)
            for c in range(self.C):
                mps_c = sp.to_device(self.mps[c], self.device)
                ksp_c = sp.to_device(self.ksp[c], self.device)
                img_adj_c = sp.nufft_adjoint(ksp_c * dcf, coord,
                                             self.img_shape)
                img_adj_c *= self.xp.conj(mps_c)
                img_adj += img_adj_c

            if self.comm is not None:
                self.comm.allreduce(img_adj)

            img_adj_norm = self.xp.linalg.norm(img_adj).item()
            self.ksp /= img_adj_norm
    def _normalize(self):
        # Normalize using first phase.
        with self.device:
            mrimg_adj = 0
            for c in range(self.C):
                mrimg_c = sp.nufft_adjoint(self.bksp[0][c] * self.bdcf[0],
                                           self.bcoord[0], self.img_shape)
                mrimg_c *= self.xp.conj(sp.to_device(self.mps[c], self.device))
                mrimg_adj += mrimg_c

            if self.comm is not None:
                self.comm.allreduce(mrimg_adj)

            # Get maximum eigenvalue.
            F = sp.linop.NUFFT(self.img_shape, self.bcoord[0])
            W = sp.linop.Multiply(F.oshape, self.bdcf[0])
            max_eig = sp.app.MaxEig(F.H * W * F,
                                    max_iter=self.max_power_iter,
                                    dtype=self.bksp[0].dtype,
                                    device=self.device,
                                    show_pbar=self.show_pbar).run()

            # Normalize
            self.alpha /= max_eig
            self.lamda *= max_eig * self.xp.abs(mrimg_adj).max().item()
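
The smoothed temporal total-variation term in `gradf` above replaces the
non-differentiable gradient of |d| with d / (|d| + eps). A minimal standalone
sketch of that term on a toy series, assuming only NumPy (the helper name
`smoothed_tv_grad` is illustrative and not part of the original code):

import numpy as np

def smoothed_tv_grad(x, lamda=1.0, eps=1e-31):
    """Gradient of lamda * sum_b |x[b] - x[b - 1]| with |.| smoothed by eps."""
    g = np.zeros_like(x)
    for b in range(len(x)):
        if b > 0:
            d = x[b] - x[b - 1]
            g[b] += lamda * d / (np.abs(d) + eps)
        if b < len(x) - 1:
            d = x[b] - x[b + 1]
            g[b] += lamda * d / (np.abs(d) + eps)
    return g

print(smoothed_tv_grad(np.array([0., 0., 1., 1.])))  # ~[0., -1., 1., 0.]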
Example #5
def autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1):
    """Automatic estimation of FOV.

    FOV is estimated by thresholding a low resolution gridded image.
    coord will be modified in-place.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs,
            and num_ro is the number of readout points per TR.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factors of shape (num_tr, num_ro).
        num_ro (int): number of readout points (counted from the start of
            each readout) used for the low-resolution FOV estimate.
        device (Device): computing device.
        thresh (float): threshold between 0 and 1.

    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        kspc = ksp[:, :, :num_ro]
        coordc = coord[:, :num_ro, :]
        dcfc = dcf[:, :num_ro]
        coordc2 = sp.to_device(coordc * 2, device)

        num_coils = len(kspc)
        imgc_shape = np.array(sp.estimate_shape(coordc))
        imgc2_shape = sp.estimate_shape(coordc2)
        imgc2_center = [i // 2 for i in imgc2_shape]
        imgc2 = sp.nufft_adjoint(sp.to_device(dcfc * kspc, device), coordc2,
                                 [num_coils] + imgc2_shape)
        imgc2 = xp.sum(xp.abs(imgc2)**2, axis=0)**0.5
        if imgc2.ndim == 3:
            imgc2_cor = imgc2[:, imgc2.shape[1] // 2, :]
            thresh *= imgc2_cor.max()
        else:
            thresh *= imgc2.max()

        boxc = imgc2 > thresh

        boxc = sp.to_device(boxc)
        boxc_idx = np.nonzero(boxc)
        boxc_shape = np.array([
            int(np.abs(boxc_idx[i] - imgc2_center[i]).max()) * 2
            for i in range(imgc2.ndim)
        ])

        img_scale = boxc_shape / imgc_shape
        coord *= img_scale
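
A hedged usage sketch (not from the original source): after `autofov`
rescales `coord` in place, the adjusted image shape can be read back with
`sp.estimate_shape`. The arrays are assumed to follow the shapes documented
above.

# Sketch only: assumes ksp (C, num_tr, num_ro), coord (num_tr, num_ro, D),
# and dcf (num_tr, num_ro) are already loaded, e.g. from np.load.
autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1)
img_shape = sp.estimate_shape(coord)  # image shape for the adjusted FOV
print('Image shape after autofov:', img_shape)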
Example #6
def gridding_recon(ksp, coord, dcf, T=1, device=sp.cpu_device):
    """ Gridding reconstruction.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs,
            and num_ro is the number of readout points per TR.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factors of shape (num_tr, num_ro).
        T (int): number of frames.
        device (Device): computing device.

    Returns:
        img (array): root-sum-of-squares images of shape (T, N_D, ..., N_1),
            where (N_D, ..., N_1) is the image shape estimated from coord.
    """
    device = sp.Device(device)
    xp = device.xp
    num_coils, num_tr, num_ro = ksp.shape
    tr_per_frame = num_tr // T
    img_shape = sp.estimate_shape(coord)

    with device:
        img = []
        for t in range(T):
            tr_start = t * tr_per_frame
            tr_end = (t + 1) * tr_per_frame
            coord_t = sp.to_device(coord[tr_start:tr_end], device)
            dcf_t = sp.to_device(dcf[tr_start:tr_end], device)

            img_t = 0
            for c in range(num_coils):
                logging.info(f'Reconstructing time {t}, coil {c}')
                ksp_tc = sp.to_device(ksp[c, tr_start:tr_end, :], device)

                img_t += xp.abs(
                    sp.nufft_adjoint(ksp_tc * dcf_t, coord_t, img_shape))**2

            img_t = img_t**0.5
            img.append(sp.to_device(img_t))

    img = np.stack(img)
    return img
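
A possible call of `gridding_recon`, assuming the same non-Cartesian arrays
as above and 10 temporal frames (a sketch, not taken from the original
source):

# Sketch: coil-combined (root-sum-of-squares) gridding of 10 frames.
img = gridding_recon(ksp, coord, dcf, T=10, device=sp.cpu_device)
print(img.shape)  # (10,) + image shape estimated from coord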
Example #7
    def time_nufft_adjoint(self):
        y = sp.nufft_adjoint(self.x, self.coord)
Example #8
def nufft_adj1(data,
               traj,
               dcf,
               device=sp.Device(-1),
               smap=None,
               batch=40000,
               id_channel=False,
               ishape=None):
    """Adjoint NUFFT reconstruction per phase, batched over TR segments.

    For each phase, coil k-space is gridded in segments of `batch` TRs and
    the coil images are combined either per channel (id_channel=True), with
    conjugate sensitivity maps (smap), or as a sum of squared magnitudes.
    """
    xp = device.xp
    N_phase = data.shape[0]
    img = []
    for num_ph in range(N_phase):
        ksp_t = data[num_ph]
        coord_t = traj[num_ph]
        dcf_t = dcf[num_ph]

        if smap is not None:
            img_shape = list(smap.shape[-3:])
        elif ishape is not None:
            img_shape = ishape
        else:
            img_shape = sp.estimate_shape(coord_t)

        num_coils, num_tr, num_ro = ksp_t.shape
        ndim = coord_t.shape[-1]
        if id_channel is True:
            img_t = np.zeros((num_coils, ) + tuple(img_shape),
                             dtype=np.complex64)
        else:
            img_t = 0

        with device:
            for c in range(num_coils):
                img_tt = 0
                for seg in range((num_tr - 1) // batch + 1):
                    ksp_ttc = sp.to_device(
                        ksp_t[c, seg * batch:np.minimum((seg + 1) *
                                                        batch, num_tr), ...],
                        device)
                    coord_tt = sp.to_device(
                        coord_t[seg * batch:np.minimum((seg + 1) *
                                                       batch, num_tr), ...],
                        device)
                    dcf_tt = sp.to_device(
                        dcf_t[seg * batch:np.minimum((seg + 1) *
                                                     batch, num_tr), ...],
                        device)
                    img_tt += sp.nufft_adjoint(ksp_ttc * dcf_tt, coord_tt,
                                               img_shape)
                # TODO smap
                if id_channel is True:
                    img_t[c, ...] = sp.to_device(img_tt)
                else:
                    if smap is None:
                        img_t += xp.abs(img_tt)**2
                    else:
                        img_t += sp.to_device(
                            img_tt *
                            xp.conj(sp.to_device(smap[c, ...], device)))

            img_t = sp.to_device(img_t)
        img.append(img_t)

    return np.asarray(img)
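
A hedged example call: with `data`, `traj`, and `dcf` organized per phase as
(N_phase, C, num_tr, num_ro), (N_phase, num_tr, num_ro, ndim), and
(N_phase, num_tr, num_ro), `nufft_adj1` returns one coil-combined image per
phase (a sum of squared coil magnitudes when `smap` is None):

# Sketch only: array names and shapes are assumptions, not from this file.
imgs = nufft_adj1(data, traj, dcf, device=sp.Device(-1), batch=40000)
print(imgs.shape)  # (N_phase,) + image shape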
Example #9
def circulant_precond(mps,
                      weights=None,
                      coord=None,
                      lamda=0,
                      device=sp.cpu_device):
    r"""Compute circulant preconditioner.

    Considers the optimization problem:

    .. math::
        \min_P \| A^H A - F P F^H  \|_2^2

    where A is the Sense operator,
    and F is a unitary Fourier transform operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization.
        device (Device): computing device.

    Returns:
        array: circulant preconditioner of image shape.

    """
    if coord is not None:
        coord = sp.to_device(coord, device)

    if weights is not None:
        weights = sp.to_device(weights, device)

    dtype = mps.dtype
    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)**2
    with device:
        idx = (slice(None, None, 2), ) * ndim
        if coord is None:
            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape)

        p_inv = 0
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            xcorr_fourier = xp.abs(sp.fft(xp.conj(mps_i), img2_shape))**2
            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            p_inv_i = sp.fft(xcorr)
            p_inv_i = p_inv_i[idx]
            p_inv_i *= scale
            if weights is not None:
                p_inv_i *= weights**0.5

            p_inv += p_inv_i

        p_inv += lamda
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
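
A hedged usage sketch: the image-space preconditioner can be wrapped in a
`Multiply` linop and passed to a conjugate-gradient SENSE reconstruction.
The `P` keyword (forwarded here through `SenseRecon`'s kwargs to
`sigpy.app.LinearLeastSquares`) is an assumption about the plumbing, not
something stated in this file.

# Sketch: preconditioned CG SENSE (keyword plumbing assumed).
# Assumes: import sigpy as sp; import sigpy.mri as mr; mps, ksp, coord, dcf.
device = sp.cpu_device
p = circulant_precond(mps, weights=dcf, coord=coord, device=device)
P = sp.linop.Multiply(mps.shape[1:], p)
img = mr.app.SenseRecon(ksp, mps, weights=dcf, coord=coord,
                        device=device, P=P, max_iter=30).run()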
Example #10
def kspace_precond(mps,
                   weights=None,
                   coord=None,
                   lamda=0,
                   device=sp.cpu_device,
                   oversamp=1.25):
    r"""Compute a diagonal preconditioner in k-space.

    Considers the optimization problem:

    .. math::
        \min_P \| P A A^H - I \|_F^2

    where A is the Sense operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization.
        device (Device): computing device.
        oversamp (float): oversampling factor for the NUFFT.

    Returns:
        array: k-space preconditioner of same shape as k-space.

    """
    dtype = mps.dtype

    if weights is not None:
        weights = sp.to_device(weights, device)

    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)
    with device:
        if coord is None:
            idx = (slice(None, None, 2), ) * ndim

            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape, oversamp=oversamp)

        p_inv = []
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            mps_i_norm2 = xp.linalg.norm(mps_i)**2
            xcorr_fourier = 0
            for mps_j in mps:
                mps_j = sp.to_device(mps_j, device)
                xcorr_fourier += xp.abs(
                    sp.fft(mps_i * xp.conj(mps_j), img2_shape))**2

            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            if coord is None:
                p_inv_i = sp.fft(xcorr)[idx]
            else:
                p_inv_i = sp.nufft(xcorr, coord2, oversamp=oversamp)

            if weights is not None:
                p_inv_i *= weights**0.5

            p_inv.append(p_inv_i * scale / mps_i_norm2)

        p_inv = (xp.abs(xp.stack(p_inv, axis=0)) + lamda) / (1 + lamda)
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
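
A hedged usage sketch: the diagonal k-space preconditioner is typically used
as the dual step size of a primal-dual solver. The `solver`/`sigma` keywords
below (forwarded through `SenseRecon` to `sigpy.app.LinearLeastSquares`) are
assumptions about that plumbing, not something stated in this file.

# Sketch: k-space preconditioned primal-dual SENSE (keywords assumed).
# Assumes: import sigpy as sp; import sigpy.mri as mr; mps, ksp, coord, dcf.
device = sp.cpu_device
p = kspace_precond(mps, weights=dcf, coord=coord, device=device)
img = mr.app.SenseRecon(ksp, mps, weights=dcf, coord=coord, device=device,
                        solver='PrimalDualHybridGradient', sigma=p,
                        max_iter=100).run()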
    def _update(self, t):
        # Form image.
        img_t = 0
        for j in range(self.J):
            img_t += self.B[j](self.L[j] * self.R[j][t])

        # Download k-space arrays.
        tr_start = t * self.tr_per_frame
        tr_end = (t + 1) * self.tr_per_frame
        coord_t = sp.to_device(self.coord[tr_start:tr_end], self.device)
        dcf_t = sp.to_device(self.dcf[tr_start:tr_end], self.device)
        ksp_t = sp.to_device(self.ksp[:, tr_start:tr_end], self.device)

        # Data consistency.
        e_t = 0
        loss_t = 0
        for c in range(self.C):
            mps_c = sp.to_device(self.mps[c], self.device)
            e_tc = sp.nufft(img_t * mps_c, coord_t)
            e_tc -= ksp_t[c]
            e_tc *= dcf_t**0.5
            loss_t += self.xp.linalg.norm(e_tc)**2
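            # Apply the second half of the DCF so the adjoint NUFFT below
            # sees the fully density-compensated residual (dcf_t * residual).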
            e_tc *= dcf_t**0.5
            e_tc = sp.nufft_adjoint(e_tc, coord_t, oshape=self.img_shape)
            e_tc *= self.xp.conj(mps_c)
            e_t += e_tc

        if self.comm is not None:
            self.comm.allreduce(e_t)
            self.comm.allreduce(loss_t)

        loss_t = loss_t.item()

        # Compute gradient.
        for j in range(self.J):
            lamda_j = self.lamda * self.G[j]

            # Loss.
            loss_t += lamda_j / self.T * self.xp.linalg.norm(
                self.L[j]).item()**2
            loss_t += lamda_j * self.xp.linalg.norm(self.R[j][t]).item()**2
            if np.isinf(loss_t) or np.isnan(loss_t):
                raise OverflowError

            # L gradient.
            g_L_j = self.B[j].H(e_t)
            g_L_j *= self.xp.conj(self.R[j][t])
            g_L_j += lamda_j / self.T * self.L[j]
            g_L_j *= self.T

            # R gradient.
            g_R_jt = self.B[j].H(e_t)
            g_R_jt *= self.xp.conj(self.L[j])
            g_R_jt = self.xp.sum(g_R_jt, axis=range(-self.D, 0), keepdims=True)
            g_R_jt += lamda_j * self.R[j][t]

            # Precondition.
            g_L_j /= self.J * self.sigma[j] + lamda_j
            g_R_jt /= self.J * self.sigma[j] + lamda_j

            # Add.
            self.L[j] -= self.alpha * self.beta**(self.epoch //
                                                  self.decay_epoch) * g_L_j
            self.R[j][t] -= self.alpha * g_R_jt

        loss_t /= 2
        return loss_t
ksp = xp.load(ksp_file)
coord = xp.load(coord_file)


def show_data_info(data, name):
    print("{}: shape={}, dtype={}".format(name, data.shape, data.dtype))


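# For radial-like trajectories, |k| is a simple analytic density
# compensation estimate, used here before gridding.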
dcf = (coord[..., 0]**2 + coord[..., 1]**2)**0.5
pl.ScatterPlot(coord, dcf, title='Density compensation')

show_data_info(ksp, "ksp")
show_data_info(coord, "coord")
show_data_info(dcf, "dcf")

img_grid = sp.nufft_adjoint(ksp * dcf, coord)
pl.ImagePlot(img_grid, z=0, title='Multi-channel Gridding')

#%% md

## Estimate sensitivity maps using JSENSE

Here we use [JSENSE](https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21245) to estimate sensitivity maps.

#%%

mps = mr.app.JsenseRecon(ksp, coord=coord, device=device).run()

#%% md

## CG
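
#%%

# A sketch (added here, not from the original notebook) of the CG SENSE
# reconstruction this heading refers to, assuming the ksp, coord, dcf, mps,
# and device defined above; with the default lamda=0, SenseRecon solves the
# least-squares problem with conjugate gradient.
img_cg = mr.app.SenseRecon(ksp, mps, weights=dcf, coord=coord,
                           device=device, max_iter=30).run()
pl.ImagePlot(img_cg, title='CG SENSE Reconstruction')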
################################################################################
# Estimate coil maps using Walsh's method from temporally averaged data
#
with device:

    # assume the spiral sampling pattern repeats every n_full_arms arms
    n_avr = int(xp.floor(ns / n_full_arms))

    kdata_avr = kdata[:n_avr * n_full_arms, :, :].reshape(
        n_avr, n_full_arms, nc, nk)
    kdata_avr = xp.mean(kdata_avr, axis=0)
    kdata_avr = xp.transpose(kdata_avr, (1, 0, 2))

    kloc_avr = kloc[:n_full_arms, :, :]

    avr_img = sp.nufft_adjoint(kdata_avr * kweight[xp.newaxis, xp.newaxis, :],
                               kloc_avr)

    # needs to be processed on the CPU
    sens_map = estimate_coilmap_walsh(sp.to_device(avr_img, -1),
                                      smoothing=20,
                                      thresh=0.0)

    # copy it to GPU
    sens_map = sp.to_device(sens_map, device_id)
    # pl.ImagePlot(xp.squeeze(avr_img), z=0, title='Multi-channel Time Averaged Image')
    # pl.ImagePlot(xp.squeeze(xp.abs(sens_map)), z=0, title='Walsh (Python)')

    # TODO Espirit coil map estimation needs to be improved

################################################################################
# Reshape Data
Example #14
    def test_shepp_logan_dcf(self):
        img, coord, ksp = self.shepp_logan_setup()
        pm_dcf = dcf.pipe_menon_dcf(coord, show_pbar=False)
        img_dcf = sp.nufft_adjoint(ksp * pm_dcf, coord, oshape=img.shape)
        img_dcf /= np.abs(img_dcf).max()
        npt.assert_allclose(img, img_dcf, atol=1, rtol=1e-1)