コード例 #1
0
ファイル: conventional.py プロジェクト: zhenchen16/adorym
def multidistance_ctf_wrapped(this_prj_batch,
                              free_prop_cm,
                              energy_ev,
                              psize_cm,
                              kappa=50,
                              safe_zone_width=0,
                              prj_affine_ls=None,
                              device=None):
    """Recover a phase map from projections taken at several distances (CTF).

    The projections are optionally affine-transformed and edge-padded, then
    Fourier transformed (after subtracting 1). For every propagation distance
    a factor ``sin(xi) + cos(xi) / kappa`` is built, with
    ``xi = PI * lambda * distance * (u**2 + v**2)``. The transforms weighted
    by that factor are summed over distances and divided by the summed
    ``2 * factor**2`` (plus ``1e-10`` to avoid division by zero) before the
    inverse transform.

    Args:
        this_prj_batch: Stack of projections, indexed by distance on axis 0
            and by image rows/columns on axes 1 and 2.
        free_prop_cm: Propagation distances; multiplied by 1e7 and then
            indexed/iterated per projection (so it must support both).
        energy_ev: Beam energy; the wavelength is taken as ``1240 / energy_ev``.
        psize_cm: Pixel size; multiplied by 1e7 before building the
            frequency mesh.
        kappa: Divisor applied to the cosine term of the CTF factor.
        safe_zone_width: Number of edge-padding pixels added on each side of
            both image axes and cropped from the result.
        prj_affine_ls: Optional per-projection affine transforms passed to
            ``w.affine_transform``.
        device: Forwarded to the ``w`` backend when creating variables.

    Returns:
        The first output of ``w.ifft2`` (the retrieved phase), cropped by
        ``safe_zone_width`` on both axes.
    """
    pad = safe_zone_width
    padded_shape = [this_prj_batch.shape[1] + 2 * pad,
                    this_prj_batch.shape[2] + 2 * pad]
    u_free, v_free = gen_freq_mesh(np.array([psize_cm * 1e7] * 3),
                                   padded_shape)
    u_free = w.create_variable(u_free, requires_grad=False, device=device)
    v_free = w.create_variable(v_free, requires_grad=False, device=device)

    this_prj_batch = w.create_variable(this_prj_batch,
                                       requires_grad=False,
                                       device=device)
    if prj_affine_ls is not None:
        for idx, affine in enumerate(prj_affine_ls):
            this_prj_batch[idx] = w.affine_transform(
                this_prj_batch[idx:idx + 1], affine)
    if pad > 0:
        this_prj_batch = w.pad(this_prj_batch,
                               [(0, 0), (pad, pad), (pad, pad)],
                               mode='edge')

    # The input is treated as purely real: the imaginary part is all zeros.
    zero_imag = w.zeros_like(this_prj_batch,
                             requires_grad=False,
                             device=device)
    ft_real, ft_imag = w.fft2(this_prj_batch - 1, zero_imag, normalize=True)

    dist_nm_ls = free_prop_cm * 1e7
    lmbda_nm = 1240. / energy_ev
    freq_sq = u_free**2 + v_free**2

    # One pass per distance: the same CTF factor weights the numerator terms
    # and, squared, contributes to the denominator.
    weighted_real = []
    weighted_imag = []
    osc_terms = []
    for idx, dist_nm in enumerate(dist_nm_ls):
        xi = PI * lmbda_nm * dist_nm * freq_sq
        ctf = w.sin(xi) + 1. / kappa * w.cos(xi)
        weighted_real.append(ctf * ft_real[idx])
        weighted_imag.append(ctf * ft_imag[idx])
        osc_terms.append(2 * ctf**2)

    numer_real = w.sum(w.stack(weighted_real), axis=0)
    numer_imag = w.sum(w.stack(weighted_imag), axis=0)
    # Small offset keeps the division finite where the summed CTF vanishes.
    osc = w.sum(w.stack(osc_terms), axis=0) + 1e-10

    phase, _ = w.ifft2(numer_real / osc, numer_imag / osc, normalize=True)

    return phase[pad:phase.shape[0] - pad, pad:phase.shape[1] - pad]
コード例 #2
0
def modulate_and_get_ctf(grid_batch, energy_ev, free_prop_cm, u_free=None, v_free=None, kappa=50.):
    """Project the object grid and evaluate ``pure_phase_ctf`` on it.

    The grid is summed over its second-to-last axis; channel 0 of the last
    axis is passed on as delta and channel 1 as beta.

    Args:
        grid_batch: Object grid whose last axis holds (delta, beta).
        energy_ev: Beam energy; the wavelength is taken as ``1240 / energy_ev``.
        free_prop_cm: Propagation distance; multiplied by 1e7 before use.
        u_free, v_free: Frequency meshes forwarded to ``pure_phase_ctf``.
        kappa: Forwarded to ``pure_phase_ctf``.

    Returns:
        Tuple ``(probe_real, probe_imag)`` as produced by ``pure_phase_ctf``.
    """
    wavelength_nm = 1240. / energy_ev
    distance_nm = free_prop_cm * 1e7

    projected = w.sum(grid_batch, axis=-2)
    delta_proj = projected[:, :, :, 0]
    beta_proj = projected[:, :, :, 1]

    out_real, out_imag = pure_phase_ctf(u_free, v_free, delta_proj, beta_proj,
                                        distance_nm, wavelength_nm, kappa=kappa)
    return out_real, out_imag
コード例 #3
0
ファイル: linesearch.py プロジェクト: zhenchen16/adorym
    def search(self,
               objective_and_update: Callable,
               x0,
               descent_dir,
               gradient,
               f0=None):
        """Backtracking line search along ``descent_dir`` from ``x0``.

        The first trial step is extrapolated from the previous call's cost
        (``self._oldf0``) when that cost was not lower than ``f0``; otherwise,
        or when the extrapolated step is negligibly small, the configured
        initial step size is used. The step is then shrunk by
        ``self.contraction_factor`` while the Armijo sufficient-decrease test
        fails, the evaluation budget is not exhausted, and the step stays
        above ``self.stepsize_threshold_low``.

        Args:
            objective_and_update: Callable ``(x, update) -> (cost, new_x)``.
            x0: Starting point.
            descent_dir: Search direction.
            gradient: Gradient at ``x0``; used for the directional derivative.
            f0: Cost at ``x0``. Evaluated with a zero update when omitted.

        Returns:
            An ``LSState``. If the final trial did not produce a cost
            ``<= f0``, the state holds ``f0``/``x0`` with ``alpha=0.``.
        """
        if f0 is None:
            # A zero update evaluates the objective at x0 itself.
            f0, _ = objective_and_update(x0, w.zeros_like(x0))

        descent_norm = w.vec_norm(descent_dir)
        # Directional derivative along the descent direction.
        df0 = w.sum(descent_dir * gradient)

        use_default_step = True
        if self._oldf0 >= f0:
            # Initial step guessed from where we were last time, then
            # stretched a little ("look a little further").
            alpha = 2 * (f0 - self._oldf0) / df0
            alpha *= self.optimism
            use_default_step = alpha * descent_norm < self._machine_eps
        if use_default_step:
            if self.normalize_alpha:
                alpha = self.initial_stepsize / descent_norm
            else:
                alpha = self.initial_stepsize

        def _trial(step, count):
            # Take the step and record the cost reached there.
            cost, point = objective_and_update(x0, step * descent_dir)
            return LSState(newf=cost, newx=point, alpha=step, step_count=count)

        state = _trial(alpha, 1)

        # Backtrack while the Armijo criterion is not satisfied.
        while True:
            armijo_violated = (state.newf >
                               f0 + self.suff_decr * state.alpha * df0)
            within_budget = state.step_count <= self.maxiter
            step_not_tiny = state.alpha > self.stepsize_threshold_low
            if not (armijo_violated and within_budget and step_not_tiny):
                break
            state = _trial(self.contraction_factor * state.alpha,
                           state.step_count + 1)

        self._oldf0 = f0
        self._alpha = state.alpha

        if state.newf <= f0:
            return state
        # No improvement found: report a null step at the starting point.
        return LSState(newf=f0,
                       newx=x0,
                       alpha=0.,
                       step_count=state.step_count)
コード例 #4
0
ファイル: linesearch.py プロジェクト: zhenchen16/adorym
    def search(self,
               objective_and_update: Callable,
               x0,
               descent_dir,
               gradient,
               f0=None):
        """Adaptive backtracking line search along ``descent_dir`` from ``x0``.

        The first trial step is the suggestion stored by the previous call
        (``self._alpha_suggested``) when it is positive, otherwise the
        configured initial step size. The step is shrunk by
        ``self.contraction_factor`` while the Armijo sufficient-decrease test
        fails, the evaluation budget is not exhausted, and the step stays
        above ``self.stepsize_threshold_low``. A new suggestion for the next
        call is stored afterwards.

        Args:
            objective_and_update: Callable ``(x, update) -> (cost, new_x)``.
            x0: Starting point.
            descent_dir: Search direction.
            gradient: Gradient at ``x0``; used for the directional derivative.
            f0: Cost at ``x0``. Evaluated with a zero update when omitted.

        Returns:
            An ``LSState``. If the final trial did not produce a cost
            ``<= f0``, a message is printed and the state holds ``f0``/``x0``
            with ``alpha=0.``.
        """
        if f0 is None:
            # A zero update evaluates the objective at x0 itself.
            f0, _ = objective_and_update(x0, w.zeros_like(x0))

        descent_norm = w.vec_norm(descent_dir)
        # Directional derivative along the descent direction.
        df0 = w.sum(descent_dir * gradient)

        if self._alpha_suggested > 0:
            alpha = self._alpha_suggested
        elif self.normalize_alpha:
            alpha = self.initial_stepsize / descent_norm
        else:
            alpha = self.initial_stepsize

        def _trial(step, count):
            # Take the step and record the cost reached there.
            cost, point = objective_and_update(x0, step * descent_dir)
            return LSState(newf=cost, newx=point, alpha=step, step_count=count)

        state = _trial(alpha, 1)

        # Backtrack while the Armijo criterion is not satisfied.
        while True:
            armijo_violated = (state.newf >
                               f0 + self.suff_decr * state.alpha * df0)
            within_budget = state.step_count <= self.maxiter
            step_not_tiny = state.alpha > self.stepsize_threshold_low
            if not (armijo_violated and within_budget and step_not_tiny):
                break
            state = _trial(self.contraction_factor * state.alpha,
                           state.step_count + 1)

        # New suggestion for the next call's initial step size:
        #   - no backtracking: things went very well, push our luck;
        #   - exactly one backtrack: things went reasonably well, keep pace;
        #   - more backtracks: the step is probably quite small, try to recover.
        backtracks = state.step_count - 1
        if backtracks == 1:
            suggestion = state.alpha
        else:
            suggestion = self.optimism * state.alpha

        self._alpha_suggested = suggestion
        self._alpha = state.alpha

        if state.newf <= f0:
            return state

        print('Line search is unable to find a smaller loss ({} > {})!'.
              format(state.newf, f0))
        # No improvement found: report a null step at the starting point.
        return LSState(newf=f0,
                       newx=x0,
                       alpha=0.,
                       step_count=state.step_count)
コード例 #5
0
def multislice_backpropagate_batch(grid_batch,
                                   probe_real,
                                   probe_imag,
                                   energy_ev,
                                   psize_cm,
                                   delta_cm=None,
                                   free_prop_cm=None,
                                   obj_batch_shape=None,
                                   kernel=None,
                                   fresnel_approx=True,
                                   pure_projection=False,
                                   binning=1,
                                   device=None,
                                   type='delta_beta',
                                   normalize_fft=False,
                                   sign_convention=1,
                                   optimize_free_prop=False,
                                   u_free=None,
                                   v_free=None,
                                   scale_ri_by_k=True,
                                   is_minus_logged=False,
                                   pure_projection_return_sqrt=False,
                                   kappa=None,
                                   repeating_slice=None,
                                   return_fft_time=False,
                                   shift_exit_wave=None,
                                   return_intermediate_wavefields=False):
    """Back-propagate a wavefield through the object grid, last slice first.

    In multislice mode the slices are visited from the highest index down in
    groups of ``binning``. Each group modulates the wave with
    ``exp_complex(-k1 * beta, sign_convention * k1 * delta)`` and, except
    after the final group, the wave is propagated over a negative distance.
    In ``pure_projection`` mode the grid is summed along the slice axis and
    applied as a single modulation. Optional exit-wave shifting and
    free-space propagation follow.

    Args:
        grid_batch: Object grid; axis 0 is the batch, axis -2 the slices and
            the last axis holds the two object channels (delta, beta).
        probe_real, probe_imag: Real and imaginary parts of the input wave.
        energy_ev: Beam energy; the wavelength is taken as ``1240 / energy_ev``.
        psize_cm: Lateral pixel size (multiplied by 1e7 internally).
        delta_cm: Slice spacing; defaults to ``psize_cm`` when ``None``.
        free_prop_cm: Free-space propagation after the object: ``0``/``None``
            to skip, ``'inf'`` for a far-field (shifted FFT) transform,
            otherwise a distance (multiplied by 1e7 internally).
        obj_batch_shape: Unused; kept for interface compatibility.
        kernel: Precomputed transfer function; built with ``get_kernel``
            (negative distance) when ``None``.
        fresnel_approx: Forwarded to ``get_kernel``.
        pure_projection: Skip inter-slice propagation and use the line sum.
        binning: Number of slices merged between two propagation steps.
        device: Forwarded to the ``w`` backend and propagation helpers.
        type: Only ``'delta_beta'`` is implemented; ``'real_imag'`` raises.
        normalize_fft: Forwarded to the far-field transform.
        sign_convention: ``1`` for the Goodman convention (``exp(ikz)``,
            ``n = 1 - delta + i * beta``); ``-1`` for the opposite one.
        optimize_free_prop: Use ``fresnel_propagate_wrapped`` with
            ``u_free``/``v_free`` for the free-space step.
        u_free, v_free: Frequency meshes used when ``optimize_free_prop``.
        scale_ri_by_k: Scale the object by ``2 * PI * delta_nm / lambda``;
            otherwise the scale factor is 1.
        is_minus_logged: In pure-projection mode, use the summed beta
            directly instead of the exponential modulation.
        pure_projection_return_sqrt: With ``is_minus_logged``, use the
            square root of the summed beta (plus ``1e-10``).
        kappa: When given, beta is taken as ``delta * kappa``.
        repeating_slice: When given, slice 0 of the grid is reused and this
            value replaces the slice count.
        return_fft_time: Append the accumulated modulation/propagation time
            (``0`` in pure-projection mode) to the returned list.
        shift_exit_wave: Optional shift passed to ``realign_image_fourier``.
        return_intermediate_wavefields: Append the lists of wavefields
            recorded before each multislice step to the returned list.

    Returns:
        list: ``[probe_real, probe_imag]``, followed by the elapsed time if
        ``return_fft_time`` and by the two intermediate-wavefield lists if
        ``return_intermediate_wavefields``.

    Raises:
        NotImplementedError: If ``type == 'real_imag'``.
        ValueError: If ``type`` is neither ``'delta_beta'`` nor ``'real_imag'``.
    """
    intermediate_wavefield_real_ls = []
    intermediate_wavefield_imag_ls = []
    grid_shape = grid_batch.shape[1:-1]
    if delta_cm is not None:
        voxel_nm = np.array([psize_cm, psize_cm, delta_cm]) * 1.e7
    else:
        voxel_nm = np.array([psize_cm] * 3) * 1.e7

    lmbda_nm = 1240. / energy_ev

    n_slices = grid_batch.shape[-2]
    delta_nm = voxel_nm[-1]

    if repeating_slice is not None:
        n_slices = repeating_slice

    # Loop-invariant scaling applied to the object before exponentiation.
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    # Bound up front so that return_fft_time also works with
    # pure_projection=True (previously an UnboundLocalError).
    t_tot = 0

    if pure_projection:
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            p = w.sum(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            if kappa is not None:
                beta_slice = delta_slice * kappa
            else:
                beta_slice = p[:, :, :, 1]
            # In conventional tomography beta is interpreted as mu. If projection data is minus-logged,
            # the line sum of beta (mu) directly equals image intensity. If raw_data_type is set to 'intensity',
            # measured data will be taken square root at the loss calculation step. To match this, the summed
            # beta must be square-rooted as well. Otherwise, set raw_data_type to 'magnitude' to avoid square-rooting
            # the measured data, and skip sqrt to summed beta here accordingly.
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(beta_slice +
                                            1e-10), delta_slice * 0
                else:
                    c_real, c_imag = beta_slice, delta_slice * 0
            else:
                # exp(-ikn*)
                c_real, c_imag = w.exp_complex(
                    -k1 * beta_slice, sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            raise NotImplementedError('Backprop not done for real_imag.')
        else:
            raise ValueError('unknown_type must be real_imag or delta_beta.')
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                  probe_real * c_imag + probe_imag * c_real)

    else:
        if kernel is not None:
            h = kernel
        else:
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            # Negative distance for backpropagation.
            h = get_kernel(-delta_nm * binning,
                           lmbda_nm,
                           voxel_nm,
                           grid_shape,
                           fresnel_approx=fresnel_approx,
                           sign_convention=sign_convention)
        h_real, h_imag = np.real(h), np.imag(h)
        h_real = w.create_variable(h_real, requires_grad=False, device=device)
        h_imag = w.create_variable(h_imag, requires_grad=False, device=device)

        n_steps = int(np.ceil(n_slices / binning))
        i_slice = n_slices
        for i_step in range(n_steps):
            if return_intermediate_wavefields:
                intermediate_wavefield_real_ls.append(probe_real)
                intermediate_wavefield_imag_ls.append(probe_imag)
            # ==========================================
            # Sampling
            # ==========================================
            # Walking backwards, the first group takes the remainder slices
            # (if any); every later group is a full bin.
            if i_step == 0:
                this_step = n_slices % binning if n_slices % binning != 0 else binning
            else:
                this_step = binning
            if repeating_slice is None:
                if this_step > 1:
                    delta_slice = grid_batch[:, :, :,
                                             i_slice - this_step:i_slice, 0]
                else:
                    delta_slice = grid_batch[:, :, :, i_slice - 1, 0]
            else:
                # NOTE(review): the 0:1 slice keeps a length-1 axis 3 that is
                # only reduced below when this_step > 1 - confirm shapes line
                # up with the probe when binning == 1.
                delta_slice = grid_batch[:, :, :, 0:1, 0]
            if kappa is not None:
                # In sign = +1 convention, phase (delta) should be positive, and kappa is positive too.
                beta_slice = delta_slice * kappa
            else:
                if repeating_slice is None:
                    if this_step > 1:
                        beta_slice = grid_batch[:, :, :,
                                                i_slice - this_step:i_slice, 1]
                    else:
                        beta_slice = grid_batch[:, :, :, i_slice - 1, 1]
                else:
                    beta_slice = grid_batch[:, :, :, 0:1, 1]
            t0 = time.time()
            # ==========================================
            # Modulation
            # ==========================================
            if type == 'delta_beta':
                # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
                # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
                if this_step > 1:
                    delta_slice = w.sum(delta_slice, axis=3)
                    beta_slice = w.sum(beta_slice, axis=3)
                # exp(-ikn*)
                c_real, c_imag = w.exp_complex(
                    -k1 * beta_slice, sign_convention * k1 * delta_slice)
            elif type == 'real_imag':
                raise NotImplementedError('Backprop not done for real_imag.')
            else:
                raise ValueError(
                    'unknown_type must be delta_beta or real_imag.')
            probe_real, probe_imag = (probe_real * c_real -
                                      probe_imag * c_imag,
                                      probe_real * c_imag +
                                      probe_imag * c_real)
            # ==========================================
            # When arriving at the last slice of bin or object, do (back)propagation.
            # ==========================================
            if i_step < n_steps - 1:
                # Backpropagate over -z
                if this_step == binning:
                    # Full bin: the precomputed kernel has the right distance.
                    probe_real, probe_imag = w.convolve_with_transfer_function(
                        probe_real, probe_imag, h_real, h_imag)
                else:
                    # Partial bin: propagate over the actual group thickness.
                    probe_real, probe_imag = fresnel_propagate(
                        probe_real,
                        probe_imag,
                        -delta_nm * this_step,
                        lmbda_nm,
                        voxel_nm,
                        device=device,
                        sign_convention=sign_convention)
            i_slice -= this_step
            t_tot += (time.time() - t0)

    if shift_exit_wave is not None:
        probe_real, probe_imag = realign_image_fourier(probe_real,
                                                       probe_imag,
                                                       shift_exit_wave,
                                                       axes=(1, 2),
                                                       device=device)

    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(
                    probe_real,
                    probe_imag,
                    axes=[1, 2],
                    normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(
                    probe_real,
                    probe_imag,
                    axes=[1, 2],
                    normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            if optimize_free_prop:
                probe_real, probe_imag = fresnel_propagate_wrapped(
                    u_free,
                    v_free,
                    probe_real,
                    probe_imag,
                    dist_nm,
                    lmbda_nm,
                    voxel_nm,
                    device=device,
                    sign_convention=sign_convention)
            else:
                probe_real, probe_imag = fresnel_propagate(
                    probe_real,
                    probe_imag,
                    dist_nm,
                    lmbda_nm,
                    voxel_nm,
                    device=device,
                    sign_convention=sign_convention)
    return_ls = [probe_real, probe_imag]
    if return_fft_time:
        return_ls.append(t_tot)
    if return_intermediate_wavefields:
        # The wavefields are returned as plain Python lists (not stacked).
        return_ls = return_ls + [
            intermediate_wavefield_real_ls, intermediate_wavefield_imag_ls
        ]
    return return_ls
コード例 #6
0
def multislice_propagate_batch(grid_batch, probe_real, probe_imag, energy_ev, psize_cm, delta_cm=None,
                               free_prop_cm=None, obj_batch_shape=None, kernel=None, fresnel_approx=True,
                               pure_projection=False, binning=1, device=None, type='delta_beta',
                               normalize_fft=False, sign_convention=1, optimize_free_prop=False, u_free=None, v_free=None,
                               scale_ri_by_k=True, is_minus_logged=False, pure_projection_return_sqrt=False,
                               kappa=None, repeating_slice=None, return_fft_time=False):
    """Propagate a wavefield through the object grid slice by slice.

    In multislice mode every slice modulates the wave (for ``'delta_beta'``
    with ``exp_complex(-k1 * beta, -sign_convention * k1 * delta)``; for
    ``'real_imag'`` with the slice values themselves), and after every
    ``binning`` slices - except after the final slice - the wave is
    propagated to the next group. In ``pure_projection`` mode the grid is
    reduced along the slice axis and applied as a single modulation.
    Optional free-space propagation follows.

    Args:
        grid_batch: Object grid; axis 0 is the batch, axis -2 the slices and
            the last axis holds the two object channels.
        probe_real, probe_imag: Real and imaginary parts of the input wave.
        energy_ev: Beam energy; the wavelength is taken as ``1240 / energy_ev``.
        psize_cm: Lateral pixel size (multiplied by 1e7 internally).
        delta_cm: Slice spacing; defaults to ``psize_cm`` when ``None``.
        free_prop_cm: Free-space propagation after the object: ``0``/``None``
            to skip, ``'inf'`` for a far-field (shifted FFT) transform,
            otherwise a distance (multiplied by 1e7 internally).
        obj_batch_shape: Unused; kept for interface compatibility.
        kernel: Precomputed transfer function; built with ``get_kernel``
            when ``None``.
        fresnel_approx: Forwarded to ``get_kernel``.
        pure_projection: Skip inter-slice propagation and use the line
            sum (``'delta_beta'``) or product (``'real_imag'``).
        binning: Number of slices traversed between two propagation steps.
        device: Forwarded to the ``w`` backend and propagation helpers.
        type: ``'delta_beta'`` or ``'real_imag'``.
        normalize_fft: Forwarded to the far-field transform.
        sign_convention: ``1`` for the Goodman convention (``exp(ikz)``,
            ``n = 1 - delta + i * beta``); ``-1`` for the opposite one.
        optimize_free_prop: Use ``fresnel_propagate_wrapped`` with
            ``u_free``/``v_free`` for the free-space step.
        u_free, v_free: Frequency meshes used when ``optimize_free_prop``.
        scale_ri_by_k: Scale the object by ``2 * PI * delta_nm / lambda``;
            otherwise the scale factor is 1.
        is_minus_logged: In pure-projection mode, return the (minus-logged)
            attenuation instead of the exponential modulation.
        pure_projection_return_sqrt: With ``is_minus_logged``, take the
            square root of that quantity (plus ``1e-10``).
        kappa: When given, beta is taken as ``delta * kappa``.
        repeating_slice: When given, slice 0 of the grid is reused and this
            value replaces the slice count.
        return_fft_time: Also return the accumulated modulation/propagation
            time (``0`` in pure-projection mode).

    Returns:
        tuple: ``(probe_real, probe_imag)``, or
        ``(probe_real, probe_imag, t_tot)`` when ``return_fft_time``.

    Raises:
        ValueError: If ``type`` is neither ``'delta_beta'`` nor ``'real_imag'``.
    """
    grid_shape = grid_batch.shape[1:-1]
    if delta_cm is not None:
        voxel_nm = np.array([psize_cm, psize_cm, delta_cm]) * 1.e7
    else:
        voxel_nm = np.array([psize_cm] * 3) * 1.e7

    lmbda_nm = 1240. / energy_ev

    n_slices = grid_batch.shape[-2]
    delta_nm = voxel_nm[-1]

    if repeating_slice is not None:
        n_slices = repeating_slice

    # Loop-invariant scaling applied to the object before exponentiation.
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    # Bound up front so that return_fft_time also works with
    # pure_projection=True (previously an UnboundLocalError).
    t_tot = 0

    if pure_projection:
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            p = w.sum(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            if kappa is not None:
                beta_slice = delta_slice * kappa
            else:
                beta_slice = p[:, :, :, 1]
            # In conventional tomography beta is interpreted as mu. If projection data is minus-logged,
            # the line sum of beta (mu) directly equals image intensity. If raw_data_type is set to 'intensity',
            # measured data will be taken square root at the loss calculation step. To match this, the summed
            # beta must be square-rooted as well. Otherwise, set raw_data_type to 'magnitude' to avoid square-rooting
            # the measured data, and skip sqrt to summed beta here accordingly.
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(beta_slice + 1e-10), delta_slice * 0
                else:
                    c_real, c_imag = beta_slice, delta_slice * 0
            else:
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            p = w.prod(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            beta_slice = p[:, :, :, 1]
            c_real, c_imag = delta_slice, beta_slice
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(-w.log(c_real ** 2 + c_imag ** 2) + 1e-10), 0
                else:
                    c_real, c_imag = -w.log(c_real ** 2 + c_imag ** 2), 0
        else:
            raise ValueError('unknown_type must be real_imag or delta_beta.')
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag, probe_real * c_imag + probe_imag * c_real)

    else:
        if kernel is not None:
            h = kernel
        else:
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            h = get_kernel(delta_nm * binning, lmbda_nm, voxel_nm, grid_shape, fresnel_approx=fresnel_approx, sign_convention=sign_convention)
        h_real, h_imag = np.real(h), np.imag(h)
        h_real = w.create_variable(h_real, requires_grad=False, device=device)
        h_imag = w.create_variable(h_imag, requires_grad=False, device=device)

        # Number of slices traversed since the last propagation step.
        i_bin = 0
        for i in range(n_slices):
            if repeating_slice is None:
                delta_slice = grid_batch[:, :, :, i, 0]
            else:
                delta_slice = grid_batch[:, :, :, 0, 0]
            if kappa is not None:
                # In sign = +1 convention, phase (delta) should be positive, and kappa is positive too.
                beta_slice = delta_slice * kappa
            else:
                if repeating_slice is None:
                    beta_slice = grid_batch[:, :, :, i, 1]
                else:
                    beta_slice = grid_batch[:, :, :, 0, 1]
            t0 = time.time()
            if type == 'delta_beta':
                # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
                # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
            elif type == 'real_imag':
                c_real, c_imag = delta_slice, beta_slice
            else:
                raise ValueError('unknown_type must be delta_beta or real_imag.')
            probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag, probe_real * c_imag + probe_imag * c_real)
            i_bin += 1

            # When arriving at the last slice of bin or object, do propagation.
            if i_bin == binning or i == n_slices - 1:
                if i < n_slices - 1:
                    if i_bin == binning:
                        # Full bin: the precomputed kernel has the right distance.
                        probe_real, probe_imag = w.convolve_with_transfer_function(probe_real, probe_imag, h_real, h_imag)
                    else:
                        # Partial bin: propagate over the actual group thickness.
                        probe_real, probe_imag = fresnel_propagate(probe_real, probe_imag, delta_nm * i_bin, lmbda_nm, voxel_nm, device=device, sign_convention=sign_convention)
                i_bin = 0
            t_tot += (time.time() - t0)

    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(probe_real, probe_imag, axes=[1, 2], normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(probe_real, probe_imag, axes=[1, 2], normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            if optimize_free_prop:
                probe_real, probe_imag = fresnel_propagate_wrapped(u_free, v_free, probe_real, probe_imag, dist_nm,
                                                                   lmbda_nm, voxel_nm,
                                                                   device=device, sign_convention=sign_convention)
            else:
                probe_real, probe_imag = fresnel_propagate(probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                                                           device=device, sign_convention=sign_convention)
    if return_fft_time:
        return probe_real, probe_imag, t_tot
    else:
        return probe_real, probe_imag