Example #1
0
def pure_phase_ctf(u,
                   v,
                   delta_slice,
                   beta_slice,
                   dist_nm,
                   lmbda_nm,
                   kappa=50.,
                   alpha=1e-10,
                   override_backend=None):
    """Forward-model a wavefield magnitude with the contrast transfer function.

    The 2D FFT of ``delta_slice`` is multiplied by
    ``2 * (sin(xi) + cos(xi) / kappa)`` with
    ``xi = PI * lmbda_nm * dist_nm * (u**2 + v**2)``. The product is inverse
    transformed and 1 is added to the real part to form an intensity. That
    intensity is clipped at 0 and square-rooted.

    Args:
        u, v: spatial-frequency grids, broadcast against the FFT of
            ``delta_slice``.
        delta_slice: phase (delta) map to be modelled.
        beta_slice: unused; kept for interface compatibility.
        dist_nm: propagation distance (nm, per the parameter name).
        lmbda_nm: wavelength (nm, per the parameter name).
        kappa: the cosine term of the CTF is weighted by ``1 / kappa``.
        alpha: unused; kept for interface compatibility.
        override_backend: forwarded to ``w.fft2`` / ``w.ifft2``.

    Returns:
        ``(probe_real, probe_imag)``. ``probe_real`` is the clipped,
        square-rooted intensity. ``probe_imag`` is zero with the same shape.
    """
    # Beware: CTF forward model is very sensitive to discontinuity. Delta maps with non-vacuum boundaries can
    # cause the result to blow up, which can't be solved even with edge-mode padding. Vignetting is the only
    # way through. Otherwise, use the result of NON-PADDED CTF phase retrieval (which contains vignetting
    # by itself) as the initial guess.
    probe_real, probe_imag = w.fft2(delta_slice,
                                    w.zeros_like(delta_slice,
                                                 requires_grad=False),
                                    override_backend=override_backend)
    xi = PI * lmbda_nm * dist_nm * (u**2 + v**2)
    osc = 2 * (w.sin(xi) + 1. / kappa * w.cos(xi))
    probe_real = osc * probe_real
    probe_imag = osc * probe_imag
    probe_real, probe_imag = w.ifft2(probe_real,
                                     probe_imag,
                                     override_backend=override_backend)
    # Intensity I = 1 + CTF-filtered term; negative values are unphysical and
    # are clipped before taking the magnitude.
    probe_real = probe_real + 1
    probe_real = w.sqrt(w.clip(probe_real, 0, None))
    # Multiply by 0 (rather than creating a new tensor) so the zero imaginary
    # part keeps the type/device of the backend tensor.
    probe_imag = probe_imag * 0
    return probe_real, probe_imag
Example #2
0
 def get_value(self, obj, device=None, **kwargs):
     """Return a weighted L1-type regularization term for a two-channel object.

     Channel 0 and channel 1 are taken from the last axis of ``obj``.

     * ``unknown_type == 'delta_beta'``: adds ``alpha_d * mean(|ch0|)`` and
       ``alpha_b * mean(|ch1|)``.
     * ``unknown_type == 'real_imag'``: treats the channels as real and
       imaginary parts. It adds ``alpha_d * mean(|mag - mean(mag)|)`` and
       ``alpha_b * mean(|phase|)``, where ``phase = arctan2(imag, real)``.

     A term is skipped when its weight is None or 0. Any other
     ``unknown_type`` returns the zero-initialized variable unchanged.
     """
     # Index with tuples: NumPy >= 1.23 no longer accepts a list that mixes
     # slices and integers as a multidimensional index.
     lead = (slice(None),) * (len(obj.shape) - 1)
     idx0 = lead + (0,)
     idx1 = lead + (1,)
     reg = w.create_variable(0., device=device)
     if self.unknown_type == 'delta_beta':
         if self.alpha_d not in [None, 0]:
             reg = reg + self.alpha_d * w.mean(w.abs(obj[idx0]))
         if self.alpha_b not in [None, 0]:
             reg = reg + self.alpha_b * w.mean(w.abs(obj[idx1]))
     elif self.unknown_type == 'real_imag':
         r = obj[idx0]
         i = obj[idx1]
         if self.alpha_d not in [None, 0]:
             # Penalize deviation of the magnitude from its mean value.
             om = w.sqrt(r ** 2 + i ** 2)
             reg = reg + self.alpha_d * w.mean(w.abs(om - w.mean(om)))
         if self.alpha_b not in [None, 0]:
             # Penalize the absolute phase.
             reg = reg + self.alpha_b * w.mean(w.abs(w.arctan2(i, r)))
     return reg
Example #3
0
    def get_value(self, obj, distribution_mode=None, device=None, **kwargs):
        """Return ``gamma * (w.pcc(o1) + w.pcc(o2))`` for a two-channel object.

        For ``unknown_type == 'delta_beta'``, ``o1`` and ``o2`` are channels 0
        and 1 of the last axis of ``obj``. For ``'real_imag'`` they are the
        magnitude and the phase (``arctan2(imag, real)``) built from those
        channels. The meaning of ``w.pcc`` is defined in the wrapper module
        (not visible here).

        Args:
            obj: object array whose last axis holds the two channels.
            distribution_mode: unused; kept for interface compatibility.
            device: device on which the accumulator variable is created.

        Raises:
            ValueError: if ``unknown_type`` is neither value above.
        """
        # Index with tuples: NumPy >= 1.23 no longer accepts a list that mixes
        # slices and integers as a multidimensional index.
        lead = (slice(None),) * (len(obj.shape) - 1)
        reg = w.create_variable(0., device=device)
        if self.unknown_type == 'delta_beta':
            o1 = obj[lead + (0,)]
            o2 = obj[lead + (1,)]
        elif self.unknown_type == 'real_imag':
            r = obj[lead + (0,)]
            i = obj[lead + (1,)]
            o1 = w.sqrt(r ** 2 + i ** 2)
            o2 = w.arctan2(i, r)
        else:
            raise ValueError('Invalid value for unknown_type.')

        reg = reg + self.gamma * w.pcc(o1)
        reg = reg + self.gamma * w.pcc(o2)
        return reg
Example #4
0
    def apply_gradient(self,
                       x,
                       g,
                       i_batch,
                       step_size=0.001,
                       b1=0.9,
                       b2=0.999,
                       eps=1e-7,
                       distribution_mode=False,
                       m=None,
                       v=None,
                       return_moments=False,
                       update_batch_count=True,
                       **kwargs):
        """Apply one Adam update to ``x`` given the gradient ``g``.

        Moment source: if both ``m`` and ``v`` are supplied they are used.
        Otherwise both are read from ``params_chunk_array_dict`` (when
        ``distribution_mode == 'shared_file'``) or from
        ``params_whole_array_dict``. The updated moments are always written
        back to that same dictionary.

        Bias correction uses ``i_batch + 1`` as the step count. If
        ``update_batch_count`` is True, ``self.i_batch`` is incremented.

        Returns:
            The updated ``x``, or ``(x, m, v)`` when ``return_moments`` is True.
        """
        if distribution_mode == 'shared_file':
            moment_store = self.params_chunk_array_dict
        else:
            moment_store = self.params_whole_array_dict

        if m is None or v is None:
            m, v = moment_store['m'], moment_store['v']

        # Exponential moving averages of the gradient and its square.
        m = (1 - b1) * g + b1 * m
        v = (1 - b2) * (g**2) + b2 * v

        step_count = i_batch + 1
        m_corrected = m / (1 - b1**step_count)
        v_corrected = v / (1 - b2**step_count)
        x = x - step_size * m_corrected / (w.sqrt(v_corrected) + eps)

        moment_store['m'] = m
        moment_store['v'] = v
        if update_batch_count:
            self.i_batch += 1
        return (x, m, v) if return_moments else x
Example #5
0
def alt_reconstruction_epie(obj_real,
                            obj_imag,
                            probe_real,
                            probe_imag,
                            probe_pos,
                            probe_pos_correction,
                            prj,
                            device_obj=None,
                            minibatch_size=1,
                            alpha=1.,
                            n_epochs=100,
                            **kwargs):
    """
    Reconstruct a 2D object and probe function using ePIE.

    The object is stored as stacked real/imaginary channels and padded with
    ``pad_object``. Scan positions are processed in minibatches of up to
    ``minibatch_size``. For each minibatch the exit waves are formed, their
    Fourier moduli are replaced by the measurements, and the object and probe
    are updated with the ePIE rules using step size ``alpha``. After every
    epoch the object magnitude and phase are written as TIFFs to
    ``kwargs['output_folder']``. Nothing is returned.

    Args:
        obj_real, obj_imag: real and imaginary parts of the object. They are
            stacked along a new axis 3, so presumably 3D (y, x, slices) --
            TODO confirm.
        probe_real, probe_imag: probe arrays; only index 0 of the leading axis
            is used.
        probe_pos: (row, col) of each probe window's top-left corner in the
            unpadded object. It is cast to int (a copy), and the caller's
            array is never modified.
        probe_pos_correction: per-position shifts indexed as ``[0, j]``. They
            are only applied if some entry exceeds 1e-3.
        prj: measurements indexed as ``prj[0, position, :, :]``.
        device_obj: device for tensors created here.
        minibatch_size: number of scan positions per update.
        alpha: step size for both the object and the probe update.
        n_epochs: number of passes over all scan positions.
        **kwargs: must contain 'output_folder' and 'raw_data_type'.
            'intensity' data are square-rooted before use.
    """
    with w.no_grad():
        probe_real = probe_real[0]
        probe_imag = probe_imag[0]
        probe_pos = probe_pos.astype(int)
        output_folder = kwargs['output_folder']
        raw_data_type = kwargs['raw_data_type']
        this_obj_size = obj_real.shape
        probe_size = probe_real.shape
        n_pos = len(probe_pos)
        obj_stack = w.stack([obj_real, obj_imag], axis=3)

        # Pad if needed
        obj_stack, pad_arr = pad_object(obj_stack,
                                        this_obj_size,
                                        probe_pos,
                                        probe_size,
                                        unknown_type='real_imag')
        # Offset of the unpadded origin inside the padded object. It is added
        # when slicing and never written back into probe_pos, so positions do
        # not drift across epochs or between extraction and put-back.
        offset_y = int(pad_arr[0, 0])
        offset_x = int(pad_arr[1, 0])
        use_pos_correction = len(w.nonzero(probe_pos_correction > 1e-3)) > 0

        def _corner(k):
            """Return the top-left (row, col) of scan position k in the padded object."""
            return int(probe_pos[k][0]) + offset_y, int(probe_pos[k][1]) + offset_x

        i_batch = 0
        for i_epoch in range(n_epochs):
            # Index of the first scan position of the current minibatch.
            batch_start = 0
            i_batch = 0
            subobj_ls = []
            probe_real_ls = []
            probe_imag_ls = []
            for j in range(n_pos):
                print('Batch {}/{}; Epoch {}/{}.'.format(
                    j, n_pos, i_epoch, n_epochs))
                y0, x0 = _corner(j)
                subobj = obj_stack[y0:y0 + probe_size[0],
                                   x0:x0 + probe_size[1], :, :]
                subobj_ls.append(subobj)
                if use_pos_correction:
                    this_shift = probe_pos_correction[0, j]
                    probe_real_shifted, probe_imag_shifted = realign_image_fourier(
                        probe_real,
                        probe_imag,
                        this_shift,
                        axes=(0, 1),
                        device=device_obj)
                else:
                    probe_real_shifted = probe_real
                    probe_imag_shifted = probe_imag
                probe_real_ls.append(probe_real_shifted)
                probe_imag_ls.append(probe_imag_shifted)
                i_batch += 1
                # Keep accumulating until the minibatch is full; always flush
                # at the last scan position so no position is skipped.
                if i_batch < minibatch_size and j < n_pos - 1:
                    continue

                this_prj_batch = prj[0, batch_start:batch_start + i_batch, :, :]
                this_prj_batch = w.create_variable(this_prj_batch,
                                                   requires_grad=False,
                                                   device=device_obj)
                if raw_data_type == 'intensity':
                    this_prj_batch = w.sqrt(this_prj_batch)
                subobj_ls = w.stack(subobj_ls)
                probe_real_ls = w.stack(probe_real_ls)
                probe_imag_ls = w.stack(probe_imag_ls)
                # Current object (first slice) before this step's update.
                c_real = subobj_ls[:, :, :, 0, 0]
                c_imag = subobj_ls[:, :, :, 0, 1]
                # Exit wave psi = P * O.
                ex_real_ls = probe_real_ls * c_real - probe_imag_ls * c_imag
                ex_imag_ls = probe_real_ls * c_imag + probe_imag_ls * c_real
                dp_real_ls, dp_imag_ls = w.fft2_and_shift(
                    ex_real_ls, ex_imag_ls)
                # Replace the Fourier modulus with the measurement; the small
                # epsilon keeps zero-amplitude pixels at 0 instead of NaN.
                mag_replace_factor = this_prj_batch / (
                    w.sqrt(dp_real_ls**2 + dp_imag_ls**2) + 1e-10)
                dp_real_ls = dp_real_ls * mag_replace_factor
                dp_imag_ls = dp_imag_ls * mag_replace_factor
                phi_real_ls, phi_imag_ls = w.ishift_and_ifft2(
                    dp_real_ls, dp_imag_ls)
                d_real_ls = phi_real_ls - ex_real_ls
                d_imag_ls = phi_imag_ls - ex_imag_ls

                # Object update: conj(P) * (psi' - psi) / max|P|^2.
                norm = w.max(probe_real_ls**2 + probe_imag_ls**2)
                o_up_real = (probe_real_ls * d_real_ls +
                             probe_imag_ls * d_imag_ls) / norm
                o_up_imag = (probe_real_ls * d_imag_ls -
                             probe_imag_ls * d_real_ls) / norm
                o_up = w.stack([o_up_real, o_up_imag], axis=-1)
                o_up = w.reshape(
                    o_up, [i_batch, probe_size[0], probe_size[1], 1, 2])

                # Probe update: conj(O) * (psi' - psi) / max|O|^2, computed
                # from the object *before* this step's update, as in ePIE.
                norm = w.max(c_real**2 + c_imag**2)
                p_up_real = (c_real * d_real_ls + c_imag * d_imag_ls) / norm
                p_up_imag = (c_real * d_imag_ls - c_imag * d_real_ls) / norm

                subobj_ls = subobj_ls + alpha * o_up
                # Average the per-position probe updates over the minibatch
                # and apply each channel to its own part of the probe.
                probe_real = probe_real + alpha * w.mean(p_up_real, axis=0)
                probe_imag = probe_imag + alpha * w.mean(p_up_imag, axis=0)

                # Put back.
                for i, k in enumerate(range(batch_start, batch_start + i_batch)):
                    y0, x0 = _corner(k)
                    obj_stack[y0:y0 + probe_size[0],
                              x0:x0 + probe_size[1], :, :] = subobj_ls[i]

                batch_start += i_batch
                i_batch = 0
                subobj_ls = []
                probe_real_ls = []
                probe_imag_ls = []

            fname0 = 'obj_mag_{}_{}'.format(i_epoch, i_batch)
            fname1 = 'obj_phase_{}_{}'.format(i_epoch, i_batch)
            obj0, obj1 = w.split_channel(obj_stack)
            obj0 = w.to_numpy(obj0)
            obj1 = w.to_numpy(obj1)
            dxchange.write_tiff(np.sqrt(obj0**2 + obj1**2),
                                os.path.join(output_folder, fname0),
                                dtype='float32',
                                overwrite=True)
            dxchange.write_tiff(np.arctan2(obj1, obj0),
                                os.path.join(output_folder, fname1),
                                dtype='float32',
                                overwrite=True)
Example #6
0
def multislice_backpropagate_batch(grid_batch,
                                   probe_real,
                                   probe_imag,
                                   energy_ev,
                                   psize_cm,
                                   delta_cm=None,
                                   free_prop_cm=None,
                                   obj_batch_shape=None,
                                   kernel=None,
                                   fresnel_approx=True,
                                   pure_projection=False,
                                   binning=1,
                                   device=None,
                                   type='delta_beta',
                                   normalize_fft=False,
                                   sign_convention=1,
                                   optimize_free_prop=False,
                                   u_free=None,
                                   v_free=None,
                                   scale_ri_by_k=True,
                                   is_minus_logged=False,
                                   pure_projection_return_sqrt=False,
                                   kappa=None,
                                   repeating_slice=None,
                                   return_fft_time=False,
                                   shift_exit_wave=None,
                                   return_intermediate_wavefields=False):
    """Back-propagate a batch of wavefields through a multislice object.

    Slices are visited from the last to the first, in bins of up to
    ``binning`` slices. Only the last bin (visited first) can be partial. For
    each bin, the delta/beta of its slices are summed and the wavefield is
    multiplied by ``exp_complex(-k1 * beta, sign_convention * k1 * delta)``.
    Between bins the wavefield is propagated over ``-delta_nm * binning``.
    With ``pure_projection`` the whole object is collapsed into a single
    projection instead.

    Afterwards the wavefield is optionally shifted (``shift_exit_wave``) and
    free-space propagated (``free_prop_cm``). ``'inf'`` means an FFT when
    ``sign_convention == 1`` and an inverse FFT otherwise.

    Args:
        grid_batch: object batch indexed as ``[batch, y, x, slice, channel]``
            (channel 0 = delta, 1 = beta, per the indexing below).
        probe_real, probe_imag: incoming wavefield.
        energy_ev: photon energy; wavelength is ``1240 / energy_ev`` nm.
        psize_cm, delta_cm: lateral pixel size and slice thickness in cm.
            ``delta_cm`` defaults to ``psize_cm``.
        kernel: precomputed transfer function for one bin; built with
            ``get_kernel`` if None.
        binning: number of slices per propagation step.
        kappa: if given, beta is computed as ``delta * kappa``.
        repeating_slice: if given, slice 0 is reused this many times.
        return_fft_time: also return the wall time spent in the
            modulation/propagation loop (0 for pure projection).
        return_intermediate_wavefields: also return the wavefields recorded
            at the start of every step.

    Returns:
        ``[probe_real, probe_imag]``, optionally followed by the time and/or
        the two lists of intermediate wavefields.

    Raises:
        NotImplementedError: for ``type == 'real_imag'``.
        ValueError: for any other unknown ``type``.
    """
    intermediate_wavefield_real_ls = []
    intermediate_wavefield_imag_ls = []
    # Defined up front so return_fft_time also works with pure_projection.
    t_tot = 0
    grid_shape = grid_batch.shape[1:-1]
    if delta_cm is not None:
        voxel_nm = np.array([psize_cm, psize_cm, delta_cm]) * 1.e7
    else:
        voxel_nm = np.array([psize_cm] * 3) * 1.e7

    lmbda_nm = 1240. / energy_ev
    n_slices = grid_batch.shape[-2]
    delta_nm = voxel_nm[-1]
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    if repeating_slice is not None:
        n_slices = repeating_slice

    if pure_projection:
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            p = w.sum(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            if kappa is not None:
                beta_slice = delta_slice * kappa
            else:
                beta_slice = p[:, :, :, 1]
            # In conventional tomography beta is interpreted as mu. If projection data is minus-logged,
            # the line sum of beta (mu) directly equals image intensity. If raw_data_type is set to 'intensity',
            # measured data will be taken square root at the loss calculation step. To match this, the summed
            # beta must be square-rooted as well. Otherwise, set raw_data_type to 'magnitude' to avoid square-rooting
            # the measured data, and skip sqrt to summed beta here accordingly.
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(beta_slice +
                                            1e-10), delta_slice * 0
                else:
                    c_real, c_imag = beta_slice, delta_slice * 0
            else:
                # exp(-ikn*)
                c_real, c_imag = w.exp_complex(
                    -k1 * beta_slice, sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            raise NotImplementedError('Backprop not done for real_imag.')
        else:
            raise ValueError('unknown_type must be real_imag or delta_beta.')
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                  probe_real * c_imag + probe_imag * c_real)

    else:
        if kernel is not None:
            h = kernel
        else:
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            # Negative distance for backpropagation.
            h = get_kernel(-delta_nm * binning,
                           lmbda_nm,
                           voxel_nm,
                           grid_shape,
                           fresnel_approx=fresnel_approx,
                           sign_convention=sign_convention)
        h_real = w.create_variable(np.real(h), requires_grad=False, device=device)
        h_imag = w.create_variable(np.imag(h), requires_grad=False, device=device)

        n_steps = int(np.ceil(n_slices / binning))
        # Exclusive upper index of the slices not yet processed.
        i_slice = n_slices
        for i_step in range(n_steps):
            if return_intermediate_wavefields:
                intermediate_wavefield_real_ls.append(probe_real)
                intermediate_wavefield_imag_ls.append(probe_imag)
            # ==========================================
            # Sampling
            # ==========================================
            # The forward model bins slices from the front, so only the last
            # bin (visited first here) can be partial.
            if i_step == 0 and n_slices % binning != 0:
                this_step = n_slices % binning
            else:
                this_step = binning
            if repeating_slice is not None:
                # Slice 0 is reused at every depth; take it as a single 2D
                # slice and scale it by the bin size below.
                index = 0
            elif this_step > 1:
                index = slice(i_slice - this_step, i_slice)
            else:
                index = i_slice - 1
            delta_slice = grid_batch[:, :, :, index, 0]
            if kappa is not None:
                # In sign = +1 convention, phase (delta) should be positive, and kappa is positive too.
                beta_slice = delta_slice * kappa
            else:
                beta_slice = grid_batch[:, :, :, index, 1]
            t0 = time.time()
            # ==========================================
            # Modulation
            # ==========================================
            if type == 'delta_beta':
                # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
                # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
                if this_step > 1:
                    if repeating_slice is None:
                        delta_slice = w.sum(delta_slice, axis=3)
                        beta_slice = w.sum(beta_slice, axis=3)
                    else:
                        # A bin of identical slices sums to this_step copies.
                        delta_slice = delta_slice * this_step
                        beta_slice = beta_slice * this_step
                # exp(-ikn*)
                c_real, c_imag = w.exp_complex(
                    -k1 * beta_slice, sign_convention * k1 * delta_slice)
            elif type == 'real_imag':
                raise NotImplementedError('Backprop not done for real_imag.')
            else:
                raise ValueError(
                    'unknown_type must be delta_beta or real_imag.')
            probe_real, probe_imag = (probe_real * c_real -
                                      probe_imag * c_imag,
                                      probe_real * c_imag +
                                      probe_imag * c_real)
            # ==========================================
            # Backpropagate over -z between bins.
            # ==========================================
            if i_step < n_steps - 1:
                # Each boundary between bins follows a *full* bin in the
                # forward model, so the inverse is always a propagation over
                # -delta_nm * binning, even right after a partial last bin.
                probe_real, probe_imag = w.convolve_with_transfer_function(
                    probe_real, probe_imag, h_real, h_imag)
            i_slice -= this_step
            t_tot += (time.time() - t0)

    if shift_exit_wave is not None:
        probe_real, probe_imag = realign_image_fourier(probe_real,
                                                       probe_imag,
                                                       shift_exit_wave,
                                                       axes=(1, 2),
                                                       device=device)

    # NOTE(review): free-space propagation here uses +free_prop_cm, the same
    # direction as the forward model; confirm callers pass the intended sign.
    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(
                    probe_real,
                    probe_imag,
                    axes=[1, 2],
                    normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(
                    probe_real,
                    probe_imag,
                    axes=[1, 2],
                    normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            if optimize_free_prop:
                probe_real, probe_imag = fresnel_propagate_wrapped(
                    u_free,
                    v_free,
                    probe_real,
                    probe_imag,
                    dist_nm,
                    lmbda_nm,
                    voxel_nm,
                    device=device,
                    sign_convention=sign_convention)
            else:
                probe_real, probe_imag = fresnel_propagate(
                    probe_real,
                    probe_imag,
                    dist_nm,
                    lmbda_nm,
                    voxel_nm,
                    device=device,
                    sign_convention=sign_convention)
    return_ls = [probe_real, probe_imag]
    if return_fft_time:
        return_ls.append(t_tot)
    if return_intermediate_wavefields:
        return_ls = return_ls + [
            intermediate_wavefield_real_ls, intermediate_wavefield_imag_ls
        ]
    return return_ls
Example #7
0
def multislice_propagate_batch(grid_batch, probe_real, probe_imag, energy_ev, psize_cm, delta_cm=None,
                               free_prop_cm=None, obj_batch_shape=None, kernel=None, fresnel_approx=True,
                               pure_projection=False, binning=1, device=None, type='delta_beta',
                               normalize_fft=False, sign_convention=1, optimize_free_prop=False, u_free=None, v_free=None,
                               scale_ri_by_k=True, is_minus_logged=False, pure_projection_return_sqrt=False,
                               kappa=None, repeating_slice=None, return_fft_time=False):
    """Propagate a batch of wavefields forward through a multislice object.

    Each slice modulates the wavefield. For ``type == 'delta_beta'`` the
    modulation is ``exp_complex(-k1 * beta, -sign_convention * k1 * delta)``.
    For ``'real_imag'`` the slice's two channels are used directly as the
    complex factor. After every full bin of ``binning`` slices the wavefield
    is propagated over ``delta_nm * binning``, except after the last slice.
    With ``pure_projection`` the whole object is collapsed into one
    projection instead. Finally the wavefield is optionally free-space
    propagated (``free_prop_cm``). ``'inf'`` means an FFT when
    ``sign_convention == 1`` and an inverse FFT otherwise.

    Args:
        grid_batch: object batch indexed as ``[batch, y, x, slice, channel]``
            (per the indexing below).
        probe_real, probe_imag: incoming wavefield.
        energy_ev: photon energy; wavelength is ``1240 / energy_ev`` nm.
        psize_cm, delta_cm: lateral pixel size and slice thickness in cm.
            ``delta_cm`` defaults to ``psize_cm``.
        kernel: precomputed transfer function for one bin; built with
            ``get_kernel`` if None.
        kappa: if given, beta is computed as ``delta * kappa``.
        repeating_slice: if given, slice 0 is applied this many times.
        return_fft_time: also return the wall time spent in the
            modulation/propagation loop (0 for pure projection).

    Returns:
        ``(probe_real, probe_imag)``, or ``(probe_real, probe_imag, t_tot)``
        when ``return_fft_time`` is True.

    Raises:
        ValueError: if ``type`` is not 'delta_beta' or 'real_imag'.
    """
    # Defined up front so return_fft_time also works with pure_projection.
    t_tot = 0
    grid_shape = grid_batch.shape[1:-1]
    if delta_cm is not None:
        voxel_nm = np.array([psize_cm, psize_cm, delta_cm]) * 1.e7
    else:
        voxel_nm = np.array([psize_cm] * 3) * 1.e7

    lmbda_nm = 1240. / energy_ev
    n_slices = grid_batch.shape[-2]
    delta_nm = voxel_nm[-1]
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    if repeating_slice is not None:
        n_slices = repeating_slice

    if pure_projection:
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            p = w.sum(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            if kappa is not None:
                beta_slice = delta_slice * kappa
            else:
                beta_slice = p[:, :, :, 1]
            # In conventional tomography beta is interpreted as mu. If projection data is minus-logged,
            # the line sum of beta (mu) directly equals image intensity. If raw_data_type is set to 'intensity',
            # measured data will be taken square root at the loss calculation step. To match this, the summed
            # beta must be square-rooted as well. Otherwise, set raw_data_type to 'magnitude' to avoid square-rooting
            # the measured data, and skip sqrt to summed beta here accordingly.
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(beta_slice + 1e-10), delta_slice * 0
                else:
                    c_real, c_imag = beta_slice, delta_slice * 0
            else:
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            p = w.prod(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            beta_slice = p[:, :, :, 1]
            c_real, c_imag = delta_slice, beta_slice
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(-w.log(c_real ** 2 + c_imag ** 2) + 1e-10), 0
                else:
                    c_real, c_imag = -w.log(c_real ** 2 + c_imag ** 2), 0
        else:
            raise ValueError('unknown_type must be real_imag or delta_beta.')
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag, probe_real * c_imag + probe_imag * c_real)

    else:
        if kernel is not None:
            h = kernel
        else:
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            h = get_kernel(delta_nm * binning, lmbda_nm, voxel_nm, grid_shape, fresnel_approx=fresnel_approx, sign_convention=sign_convention)
        h_real = w.create_variable(np.real(h), requires_grad=False, device=device)
        h_imag = w.create_variable(np.imag(h), requires_grad=False, device=device)

        i_bin = 0
        for i in range(n_slices):
            # A repeating object applies slice 0 at every depth.
            i_src = i if repeating_slice is None else 0
            delta_slice = grid_batch[:, :, :, i_src, 0]
            if kappa is not None:
                # In sign = +1 convention, phase (delta) should be positive, and kappa is positive too.
                beta_slice = delta_slice * kappa
            else:
                beta_slice = grid_batch[:, :, :, i_src, 1]
            t0 = time.time()
            if type == 'delta_beta':
                # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
                # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
            elif type == 'real_imag':
                c_real, c_imag = delta_slice, beta_slice
            else:
                raise ValueError('unknown_type must be delta_beta or real_imag.')
            probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag, probe_real * c_imag + probe_imag * c_real)
            i_bin += 1

            # A full bin has been applied: propagate over its thickness,
            # except after the object's last slice. (A trailing partial bin
            # is never followed by propagation.)
            if i_bin == binning:
                if i < n_slices - 1:
                    probe_real, probe_imag = w.convolve_with_transfer_function(probe_real, probe_imag, h_real, h_imag)
                i_bin = 0
            t_tot += (time.time() - t0)

    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(probe_real, probe_imag, axes=[1, 2], normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(probe_real, probe_imag, axes=[1, 2], normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            if optimize_free_prop:
                probe_real, probe_imag = fresnel_propagate_wrapped(u_free, v_free, probe_real, probe_imag, dist_nm,
                                                                   lmbda_nm, voxel_nm,
                                                                   device=device, sign_convention=sign_convention)
            else:
                probe_real, probe_imag = fresnel_propagate(probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                                                           device=device, sign_convention=sign_convention)
    if return_fft_time:
        return probe_real, probe_imag, t_tot
    else:
        return probe_real, probe_imag