Example #1
0
def calc_shims(shim_roi, sens, x0, dt, lamb=0, max_iter=50):
    """Design RF shim weights with the Gerchberg-Saxton algorithm.

    The system operator is built from a single trajectory point located at
    k = (0, 0, 0), and the solver is iterated until it reports completion.

    Args:
        shim_roi (array): region within the volume to be shimmed; a mask of
            1's and 0's. [dim_x dim_y dim_z]
        sens (array): sensitivity maps. [Nc dim_x dim_y dim_z]
        x0 (array): initial guess for the shim values. [Nc 1]
        dt (float): hardware sampling dwell time.
        lamb (float): regularization term.
        max_iter (int): maximum number of iterations.

    Returns:
        Vector of complex shim weights.
    """
    dc_coord = np.array((0, 0, 0))[np.newaxis, :]  # shape (1, 3)
    system_op = rf.PtxSpatialExplicit(sens, coord=dc_coord, dt=dt,
                                      img_shape=shim_roi.shape,
                                      ret_array=False)

    # NOTE(review): 10E-9 == 1e-8; confirm this (rather than 1e-9) is the
    # intended convergence tolerance.
    solver = sp.alg.GerchbergSaxton(system_op, shim_roi, x0,
                                    max_iter=max_iter, tol=10E-9, lamb=lamb)
    while not solver.done():
        solver.update()

    return solver.x
Example #2
0
    def test_stspa_2d_explicit(self):
        """STSPA with explicit=True should reproduce the 2D target."""
        target, sens = self.problem_2d(8)
        n_pix = target.shape[0]
        dwell = 4e-6

        # Only the k-space trajectory of the spiral design is used; it is
        # divided by the image dimension before use.
        _, coord, _, _ = rf.spiral_arch(0.24, n_pix, dwell, 200, 0.035)
        coord = coord / n_pix

        system = rf.PtxSpatialExplicit(sens, coord, dt=dwell,
                                       img_shape=target.shape, b0=None)
        pulses = sp.mri.rf.stspa(target, sens, st=None, coord=coord, dt=dwell,
                                 max_iter=100, alpha=10, tol=1E-4,
                                 phase_update_interval=200, explicit=True)

        # NOTE(review): the third parameter of assert_array_almost_equal is
        # `decimal`, so 1E-3 checks |diff| < 1.5 * 10**(-0.001) (~1.5), which
        # is very loose. If an absolute tolerance of ~1e-3 was intended, use
        # decimal=3 -- verify the test still passes before tightening.
        npt.assert_array_almost_equal(system * pulses, target, decimal=1E-3)
Example #3
0
def stspk(mask,
          sens,
          n_spokes,
          fov,
          dx_max,
          gts,
          sl_thick,
          tbw,
          dgdtmax,
          gmax,
          alpha=1,
          iter_dif=0.01):
    """Small tip spokes and k-t points parallel transmit pulse designer.

    Spoke (kx, ky) locations are chosen greedily from a Cartesian candidate
    grid, starting from DC. After each spoke set is fixed, the per-coil,
    per-spoke weights are re-optimized with magnitude least squares (MLS):
    a regularized least-squares solve alternates with replacing the target
    phase by the phase of the achieved pattern.

    Args:
        mask (ndarray): region in which to optimize flip angle uniformity
            in slice. [dim dim]
        sens (ndarray): sensitivity maps. [nc dim dim]
        n_spokes (int): number of spokes to be created in the design (>= 1).
        fov (float): excitation FOV (cm).
        dx_max (float): max. resolution of the trajectory (cm).
        gts (float): hardware sampling dwell time (s).
        sl_thick (float): slice thickness (mm).
        tbw (int): time-bandwidth product.
        dgdtmax (float): max gradient slew (g/cm/s).
        gmax (float): max gradient amplitude (g/cm).
        alpha (float): Tikhonov regularization parameter on the weights.
        iter_dif (float): for each spoke, the relative change in cost between
            successive iterations at which to terminate MLS iterations.

    Returns:
        2-element tuple containing

        - **pulses** (*array*): RF waveform out, one row per coil.
        - **g** (*array*): corresponding gradient, in g/cm.

    Raises:
        ValueError: if ``n_spokes`` is less than 1.

    References:
        Grissom, W., Khalighi, M., Sacolick, L., Rutt, B. & Vogel, M (2012).
        Small-tip-angle spokes pulse design using interleaved greedy and
        local optimization methods. Magnetic Resonance in Medicine, 68(5),
        1553-62.
    """
    # Without at least one spoke no weights are ever designed.
    if n_spokes < 1:
        raise ValueError(
            'n_spokes must be at least 1, got {}'.format(n_spokes))

    device = backend.get_device(sens)
    xp = device.xp
    with device:
        nc = sens.shape[0]

        # Greedy candidate (kx, ky) grid, spacing ~1/fov, out to the max
        # spatial frequency of the trajectory.
        kmax = 1 / dx_max  # /cm
        # Builtin int(): the xp.int alias was removed in NumPy 1.24.
        n_grid = int(fov * kmax)
        k_axis = xp.linspace(-kmax / 2, kmax / 2 - 1 / fov, n_grid)
        kxs, kys = xp.meshgrid(k_axis, k_axis)
        kxs = kxs.flatten()
        kys = kys.flatten()

        # Remove DC from the candidates (it is the initial spoke). linspace
        # rounding can leave the DC sample a few ulp off exact zero, so match
        # within a quarter grid step rather than with ==. If the grid has no
        # DC point this still raises IndexError.
        dc_tol = 0.25 / fov
        near_dc = (xp.abs(kxs) < dc_tol) & (xp.abs(kys) < dc_tol)
        dc = int(xp.flatnonzero(near_dc)[0])
        kxs = xp.concatenate([kxs[:dc], kxs[dc + 1:]])
        kys = xp.concatenate([kys[:dc], kys[dc + 1:]])

        # The first spoke is at DC.
        k = xp.expand_dims(xp.array([0, 0]), 0)

        # Initial target phase: zero at every nonzero mask pixel.
        # complex128 replaces the removed alias xp.complex (builtin complex).
        phs = xp.zeros((xp.count_nonzero(mask), 1), dtype=xp.complex128)

        for ii in range(n_spokes):
            # NOTE(review): phs has one row per nonzero mask pixel, so the
            # all-zero-row pruning presumably relies on sens being zero
            # outside the mask -- confirm with callers.
            a_sys = _stspk_system(sens, k, gts, mask.shape)
            a_sys_h = a_sys.conj().T
            # The regularized normal matrix depends only on the current spoke
            # set, so build it once and reuse it in every MLS iteration.
            normal = a_sys_h @ a_sys + alpha * xp.eye((ii + 1) * nc)

            # MLS initialization, then iterate until the relative change in
            # cost drops to iter_dif.
            target = xp.exp(1j * phs)
            w_full = xp.linalg.solve(normal, a_sys_h @ target)
            cost = _stspk_mls_cost(a_sys, w_full, target, alpha)
            cost_old = 10 * cost  # to get the loop going
            while xp.absolute(cost - cost_old) > iter_dif * cost_old:
                cost_old = cost
                phs = xp.angle(a_sys @ w_full)
                target = xp.exp(1j * phs)
                w_full = xp.linalg.solve(normal, a_sys_h @ target)
                cost = _stspk_mls_cost(a_sys, w_full, target, alpha)

            # Add a spoke using the greedy method: pick the candidate whose
            # single-spoke LS fit to the current residual has the largest
            # weight norm.
            if ii < n_spokes - 1:
                r = xp.exp(1j * phs) - a_sys @ w_full
                # Norms are real-valued, so store them in a real array.
                rfnorm = xp.zeros(kxs.shape, dtype=xp.float64)
                for jj in range(kxs.size):
                    ks_test = xp.expand_dims(xp.array([kxs[jj], kys[jj]]), 0)
                    a_test = _stspk_system(sens, ks_test, gts, mask.shape)
                    a_test_h = a_test.conj().T
                    rfm = xp.linalg.solve(a_test_h @ a_test, a_test_h @ r)
                    rfnorm[jj] = xp.linalg.norm(rfm)

                ind = int(xp.argmax(rfnorm))
                k_new = xp.expand_dims(xp.array([kxs[ind], kys[ind]]), 0)

                if ii % 2 != 0:  # add to end of pulse
                    k = xp.concatenate((k, k_new))
                else:  # add to beginning of pulse
                    k = xp.concatenate((k_new, k))

                # Remove the chosen point from the candidates.
                kxs = xp.concatenate([kxs[:ind], kxs[ind + 1:]])
                kys = xp.concatenate([kys[:ind], kys[ind + 1:]])

        # From the spoke selections, build the whole waveforms.
        g = rf.spokes_grad(k, tbw, sl_thick, gmax, dgdtmax, gts)

        # Size of the gz trapezoids, used to build the RF sub-pulse.
        # sl_thick is in mm (/10 -> cm); 4257 is presumably the 1H
        # gyromagnetic ratio in Hz/G. thick*kwid=tbw, kwid=gam*area.
        area = tbw / (sl_thick / 10) / 4257
        [subgz, nramp] = rf.min_trap_grad(area, gmax, dgdtmax, gts)
        npts = 128
        subrf = rf.dzrf(npts, tbw, 'st')

        n_plat = subgz.size - 2 * nramp  # time points on trap plateau
        # Interpolate to stretch the sub-pulse over the plateau.
        # NOTE(review): scipy's interp1d operates on host (NumPy) arrays, so
        # this step is presumably CPU-only.
        f = interp1d(np.arange(0, npts, 1) / npts,
                     subrf,
                     fill_value='extrapolate')
        subrf = f(xp.arange(0, n_plat, 1) / n_plat)
        subrf = xp.concatenate((xp.zeros(nramp), subrf, xp.zeros(nramp)))

        # NOTE(review): C-order reshape to (nc, n_spokes) assumes the weight
        # vector is coil-major -- confirm against PtxSpatialExplicit.
        pulses = xp.kron(xp.reshape(w_full, (nc, n_spokes)), subrf)

        # Pad with zeros for the gz refocusing lobe.
        rf_ref = xp.zeros((nc, g.shape[1] - pulses.shape[1]))
        pulses = xp.concatenate((pulses, rf_ref), 1)

        return pulses, g


def _stspk_system(sens, k, gts, img_shape):
    """Build the explicit PTx system matrix for ``k``, dropping all-zero rows."""
    a_sys = rf.PtxSpatialExplicit(sens, k, gts, img_shape, ret_array=True)
    return a_sys[~(a_sys == 0).all(1)]


def _stspk_mls_cost(a_sys, w, target, alpha):
    """Return the real part of ||a_sys @ w - target||^2 + alpha * w^H w."""
    err = a_sys @ w - target
    return (err.conj().T @ err + alpha * w.conj().T @ w).real