def sparse_multislice_propagate_batch(u, v, grid_batch, probe_real, probe_imag, energy_ev, psize_cm,
                                      slice_pos_cm_ls, free_prop_cm=None, obj_batch_shape=None,
                                      fresnel_approx=True, device=None, type='delta_beta',
                                      normalize_fft=False, sign_convention=1, scale_ri_by_k=True):
    """
    Multislice propagation through slices located at arbitrary (sparse) z positions.

    Each slice ``grid_batch[:, :, :, i, :]`` modulates the wavefield. The wavefield is then propagated
    over the gap to the next slice position with ``fresnel_propagate_wrapped(u, v, ...)``. After
    the last slice the wave is optionally propagated further by ``free_prop_cm``.

    :param u, v: forwarded unchanged to ``fresnel_propagate_wrapped``. These are presumably
                 spatial-frequency grids; confirm against that helper.
    :param grid_batch: indexed as ``[batch, y, x, slice, channel]``. Channel 0/1 hold delta/beta
                       (``type='delta_beta'``) or the real/imaginary transmission (``type='real_imag'``).
    :param probe_real, probe_imag: real and imaginary parts of the incident wavefield.
    :param energy_ev: photon energy in eV. The wavelength is taken as ``1240 / energy_ev`` nm.
    :param psize_cm: voxel size in cm, used for all three axes.
    :param slice_pos_cm_ls: z position of every slice in cm. Arrays work; lists and tuples are
                            converted to a numpy array.
    :param free_prop_cm: extra propagation distance in cm after the last slice. ``'inf'`` gives the
                         far field via an (i)FFT. ``0`` or ``None`` skips this step.
    :param obj_batch_shape, fresnel_approx: unused; kept for interface compatibility.
    :param type: ``'delta_beta'`` or ``'real_imag'``.
    :param normalize_fft: forwarded to the far-field FFT helpers.
    :param sign_convention: 1 for the exp(ikz) (Goodman) convention, -1 for exp(-ikz).
    :param scale_ri_by_k: if True, delta/beta are multiplied by ``2*pi*dz/lambda`` (dz = voxel size).
                          Otherwise they are used as-is.
    :return: ``(probe_real, probe_imag)`` of the exit (or far-field) wave.
    :raises ValueError: if ``type`` is neither ``'delta_beta'`` nor ``'real_imag'``.
    """
    # `list * 1e7` raises TypeError, so accept plain sequences by converting them first.
    if isinstance(slice_pos_cm_ls, (list, tuple)):
        slice_pos_cm_ls = np.asarray(slice_pos_cm_ls, dtype=float)

    voxel_nm = np.array([psize_cm] * 3) * 1.e7
    lmbda_nm = 1240. / energy_ev
    slice_pos_nm_ls = slice_pos_cm_ls * 1e7
    n_slices = grid_batch.shape[-2]
    delta_nm = voxel_nm[-1]
    # The refractive-index scaling is identical for every slice.
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    for i in range(n_slices):
        delta_slice = grid_batch[:, :, :, i, 0]
        beta_slice = grid_batch[:, :, :, i, 1]
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            c_real, c_imag = delta_slice, beta_slice
        else:
            raise ValueError('unknown_type must be delta_beta or real_imag.')
        # Complex product probe * c.
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                  probe_real * c_imag + probe_imag * c_real)
        if i < n_slices - 1:
            # Propagate across the (possibly non-uniform) gap to the next slice.
            probe_real, probe_imag = fresnel_propagate_wrapped(
                u, v, probe_real, probe_imag, slice_pos_nm_ls[i + 1] - slice_pos_nm_ls[i],
                lmbda_nm, voxel_nm, device=device, sign_convention=sign_convention)

    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(probe_real, probe_imag, axes=[1, 2],
                                                          normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(probe_real, probe_imag, axes=[1, 2],
                                                           normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            probe_real, probe_imag = fresnel_propagate(probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                                                       device=device, sign_convention=sign_convention)
    return probe_real, probe_imag
def alt_reconstruction_epie(obj_real, obj_imag, probe_real, probe_imag, probe_pos, probe_pos_correction, prj,
                            device_obj=None, minibatch_size=1, alpha=1., n_epochs=100, **kwargs):
    """
    Reconstruct a 2D object and probe function using ePIE.

    For each probe position the exit wave ``psi = P * O`` is formed. Its Fourier magnitude is replaced
    by the measured one, and the difference ``d`` between the corrected and original exit waves drives
    the updates
        O += alpha * conj(P) * d / max|P|^2
        P += alpha * mean_over_minibatch(conj(O) * d) / max|O|^2
    The probe update uses the object *after* its update. Positions are processed in minibatches of
    ``minibatch_size``. After every epoch the object's magnitude and phase are written as TIFF files
    ``obj_mag_{epoch}_0`` / ``obj_phase_{epoch}_0`` into ``kwargs['output_folder']``.

    :param obj_real, obj_imag: real/imaginary parts of the object. They are stacked along axis 3,
                               so presumably ``(y, x, 1)`` arrays; confirm with callers.
    :param probe_real, probe_imag: only entry ``[0]`` is used.
    :param probe_pos: integer ``(row, col)`` of each probe's top-left corner in the unpadded object.
                      It is not modified.
    :param probe_pos_correction: shifts applied to the probe, ``probe_pos_correction[0, j]`` for
                                 position j. Used only when ``w.nonzero(probe_pos_correction > 1e-3)``
                                 is non-empty.
    :param prj: measured data. ``prj[0, k]`` is the pattern of position k, so ``prj.shape[1]`` is
                assumed to equal ``len(probe_pos)``.
    :param device_obj: device passed to backend helpers.
    :param minibatch_size: number of positions per update. The last minibatch may be smaller.
    :param alpha: update step size.
    :param n_epochs: number of passes over all positions.
    :param kwargs: must contain ``output_folder`` and ``raw_data_type``. ``'intensity'`` data is
                   square-rooted before use. Other keys are ignored.
    :return: None; results are only written to disk.
    """
    with w.no_grad():
        probe_real = probe_real[0]
        probe_imag = probe_imag[0]
        probe_pos = probe_pos.astype(int)
        output_folder = kwargs['output_folder']
        raw_data_type = kwargs['raw_data_type']
        probe_size = probe_real.shape
        n_pos = len(probe_pos)

        obj_stack = w.stack([obj_real, obj_imag], axis=3)
        # Pad if needed
        obj_stack, pad_arr = pad_object(obj_stack, obj_real.shape, probe_pos, probe_size,
                                        unknown_type='real_imag')

        # Loop-invariant: whether the probe must be shifted per position.
        apply_pos_correction = len(w.nonzero(probe_pos_correction > 1e-3)) > 0

        def _padded_corner(k):
            # Top-left corner of position k in padded-object coordinates. It is computed without
            # writing into probe_pos (a numpy view), so the pad offset is not re-applied each time.
            return int(probe_pos[k][0] + pad_arr[0, 0]), int(probe_pos[k][1] + pad_arr[1, 0])

        i_batch = 0
        subobj_ls = []
        probe_real_ls = []
        probe_imag_ls = []
        for i_epoch in range(n_epochs):
            for j in range(n_pos):
                print('Batch {}/{}; Epoch {}/{}.'.format(j, n_pos, i_epoch, n_epochs))
                row, col = _padded_corner(j)
                subobj_ls.append(obj_stack[row:row + probe_size[0], col:col + probe_size[1], :, :])
                if apply_pos_correction:
                    this_shift = probe_pos_correction[0, j]
                    probe_real_shifted, probe_imag_shifted = realign_image_fourier(
                        probe_real, probe_imag, this_shift, axes=(0, 1), device=device_obj)
                else:
                    probe_real_shifted = probe_real
                    probe_imag_shifted = probe_imag
                probe_real_ls.append(probe_real_shifted)
                probe_imag_ls.append(probe_imag_shifted)
                i_batch += 1

                # Keep accumulating until the minibatch is full or the last position is reached.
                if i_batch < minibatch_size and j < n_pos - 1:
                    continue

                # Positions batch_start .. j (inclusive) form this minibatch.
                batch_start = j - i_batch + 1
                this_prj_batch = prj[0, batch_start:batch_start + i_batch, :, :]
                this_prj_batch = w.create_variable(this_prj_batch, requires_grad=False, device=device_obj)
                if raw_data_type == 'intensity':
                    this_prj_batch = w.sqrt(this_prj_batch)
                subobj_batch = w.stack(subobj_ls)
                probe_real_batch = w.stack(probe_real_ls)
                probe_imag_batch = w.stack(probe_imag_ls)

                # Exit wave psi = P * O.
                c_real, c_imag = subobj_batch[:, :, :, 0, 0], subobj_batch[:, :, :, 0, 1]
                ex_real = probe_real_batch * c_real - probe_imag_batch * c_imag
                ex_imag = probe_real_batch * c_imag + probe_imag_batch * c_real

                # Replace the far-field magnitude by the measurement and keep the phase. The small
                # epsilon avoids 0/0 -> NaN where the far field vanishes.
                dp_real, dp_imag = w.fft2_and_shift(ex_real, ex_imag)
                mag_replace_factor = this_prj_batch / w.sqrt(dp_real ** 2 + dp_imag ** 2 + 1e-20)
                dp_real = dp_real * mag_replace_factor
                dp_imag = dp_imag * mag_replace_factor
                phi_real, phi_imag = w.ishift_and_ifft2(dp_real, dp_imag)
                d_real = phi_real - ex_real
                d_imag = phi_imag - ex_imag

                # Object update: O += alpha * conj(P) * d / max|P|^2.
                norm = w.max(probe_real_batch ** 2 + probe_imag_batch ** 2)
                o_up_real = (probe_real_batch * d_real + probe_imag_batch * d_imag) / norm
                o_up_imag = (probe_real_batch * d_imag - probe_imag_batch * d_real) / norm
                o_up = w.stack([o_up_real, o_up_imag], axis=-1)
                o_up = w.reshape(o_up, [i_batch, probe_size[0], probe_size[1], 1, 2])
                subobj_batch = subobj_batch + alpha * o_up

                # Probe update: P += alpha * mean(conj(O) * d) / max|O|^2. The real and imaginary
                # parts update their own component; previously the stacked [re, im] tensor was
                # added to both.
                obj_r = subobj_batch[:, :, :, 0, 0]
                obj_i = subobj_batch[:, :, :, 0, 1]
                norm = w.max(obj_r ** 2 + obj_i ** 2)
                p_up_real = (obj_r * d_real + obj_i * d_imag) / norm
                p_up_imag = (obj_r * d_imag - obj_i * d_real) / norm
                probe_real = probe_real + alpha * w.mean(p_up_real, axis=0)
                probe_imag = probe_imag + alpha * w.mean(p_up_imag, axis=0)

                # Put back.
                for i_in_batch, k in enumerate(range(batch_start, batch_start + i_batch)):
                    row, col = _padded_corner(k)
                    obj_stack[row:row + probe_size[0], col:col + probe_size[1], :, :] = subobj_batch[i_in_batch]
                i_batch = 0
                subobj_ls = []
                probe_real_ls = []
                probe_imag_ls = []

            # Snapshot of the object after each epoch (i_batch is 0 here because every minibatch is flushed).
            fname0 = 'obj_mag_{}_{}'.format(i_epoch, i_batch)
            fname1 = 'obj_phase_{}_{}'.format(i_epoch, i_batch)
            obj0, obj1 = w.split_channel(obj_stack)
            obj0 = w.to_numpy(obj0)
            obj1 = w.to_numpy(obj1)
            dxchange.write_tiff(np.sqrt(obj0 ** 2 + obj1 ** 2), os.path.join(output_folder, fname0),
                                dtype='float32', overwrite=True)
            dxchange.write_tiff(np.arctan2(obj1, obj0), os.path.join(output_folder, fname1),
                                dtype='float32', overwrite=True)
def multislice_backpropagate_batch(grid_batch, probe_real, probe_imag, energy_ev, psize_cm, delta_cm=None,
                                   free_prop_cm=None, obj_batch_shape=None, kernel=None, fresnel_approx=True,
                                   pure_projection=False, binning=1, device=None, type='delta_beta',
                                   normalize_fft=False, sign_convention=1, optimize_free_prop=False,
                                   u_free=None, v_free=None, scale_ri_by_k=True, is_minus_logged=False,
                                   pure_projection_return_sqrt=False, kappa=None, repeating_slice=None,
                                   return_fft_time=False, shift_exit_wave=None,
                                   return_intermediate_wavefields=False):
    """
    Propagate a batch of wavefields backwards through a multislice object, last slice first.

    Slices are visited from the last one down to slice 0 in bins of ``binning`` slices. The first bin
    holds the ``n_slices % binning`` remainder slices, if any. For each bin the wavefield is
    multiplied by ``exp_complex(-k1 * sum(beta), +sign_convention * k1 * sum(delta))``; the phase sign
    is opposite to the one used by ``multislice_propagate_batch``. Then, except after the last bin,
    the wavefield is convolved with the transfer function for ``-binning * dz``. Afterwards the wave
    is optionally shifted (``shift_exit_wave``) and free-space propagated (``free_prop_cm``).

    :param grid_batch: indexed as ``[batch, y, x, slice, channel]``. Channel 0 is delta and
                       channel 1 is beta.
    :param probe_real, probe_imag: real and imaginary parts of the wavefield to back-propagate.
    :param energy_ev: photon energy in eV. The wavelength is ``1240 / energy_ev`` nm.
    :param psize_cm: lateral voxel size in cm.
    :param delta_cm: slice thickness in cm. Defaults to ``psize_cm``.
    :param free_prop_cm: extra propagation distance in cm. ``'inf'`` means far field via (i)FFT;
                         ``0`` or ``None`` skips this step.
    :param obj_batch_shape: unused; kept for interface compatibility.
    :param kernel: precomputed transfer function used instead of ``get_kernel(-binning * dz, ...)``.
                   Presumably it must correspond to that distance; confirm with callers.
    :param fresnel_approx: forwarded to ``get_kernel``.
    :param pure_projection: if True, the slices are summed and applied once, with no propagation
                            between slices.
    :param binning: number of slices merged into one modulation/propagation step.
    :param device: device passed to backend helpers.
    :param type: only ``'delta_beta'`` is implemented.
    :param normalize_fft: forwarded to the far-field FFT helpers.
    :param sign_convention: 1 for exp(ikz) (Goodman), -1 for exp(-ikz).
    :param optimize_free_prop, u_free, v_free: if ``optimize_free_prop`` is set, free propagation uses
                                               ``fresnel_propagate_wrapped(u_free, v_free, ...)``.
    :param scale_ri_by_k: scale delta/beta by ``2*pi*dz/lambda`` (otherwise by 1).
    :param is_minus_logged, pure_projection_return_sqrt: pure-projection output options (see comments).
    :param kappa: if given, beta is taken as ``kappa * delta``.
    :param repeating_slice: if given, slice 0 is repeated this many times instead of using all slices.
    :param return_fft_time: append the accumulated modulation/propagation time to the result.
    :param shift_exit_wave: shift passed to ``realign_image_fourier`` after the slice loop.
    :param return_intermediate_wavefields: append lists of the wavefields captured before each bin.
    :return: list ``[probe_real, probe_imag]``. The time is appended if ``return_fft_time``, then the
             intermediate real and imaginary wavefield lists if ``return_intermediate_wavefields``.
    :raises NotImplementedError: for ``type='real_imag'``.
    :raises ValueError: for any other unknown ``type``.
    """
    intermediate_wavefield_real_ls = []
    intermediate_wavefield_imag_ls = []
    # Initialised up front so that return_fft_time also works with pure_projection. Previously
    # t_tot existed only in the multislice branch, which caused a NameError.
    t_tot = 0

    grid_shape = grid_batch.shape[1:-1]
    if delta_cm is not None:
        voxel_nm = np.array([psize_cm, psize_cm, delta_cm]) * 1.e7
    else:
        voxel_nm = np.array([psize_cm] * 3) * 1.e7
    lmbda_nm = 1240. / energy_ev
    delta_nm = voxel_nm[-1]
    n_slices = grid_batch.shape[-2] if repeating_slice is None else repeating_slice
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    if pure_projection:
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            p = w.sum(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            if kappa is not None:
                beta_slice = delta_slice * kappa
            else:
                beta_slice = p[:, :, :, 1]
            # In conventional tomography beta is interpreted as mu. If projection data is minus-logged,
            # the line sum of beta (mu) directly equals image intensity. If raw_data_type is set to 'intensity',
            # measured data will be taken square root at the loss calculation step. To match this, the summed
            # beta must be square-rooted as well. Otherwise, set raw_data_type to 'magnitude' to avoid square-rooting
            # the measured data, and skip sqrt to summed beta here accordingly.
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(beta_slice + 1e-10), delta_slice * 0
                else:
                    c_real, c_imag = beta_slice, delta_slice * 0
            else:
                # exp(-ikn*)
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            raise NotImplementedError('Backprop not done for real_imag.')
        else:
            raise ValueError('unknown_type must be real_imag or delta_beta.')
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                  probe_real * c_imag + probe_imag * c_real)
    else:
        if kernel is not None:
            h = kernel
        else:
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            # Negative distance (one full bin) for backpropagation.
            h = get_kernel(-delta_nm * binning, lmbda_nm, voxel_nm, grid_shape,
                           fresnel_approx=fresnel_approx, sign_convention=sign_convention)
        h_real = w.create_variable(np.real(h), requires_grad=False, device=device)
        h_imag = w.create_variable(np.imag(h), requires_grad=False, device=device)

        n_steps = int(np.ceil(n_slices / binning))
        remainder = n_slices % binning
        i_slice = n_slices  # exclusive upper slice index of the current bin
        for i_step in range(n_steps):
            if return_intermediate_wavefields:
                intermediate_wavefield_real_ls.append(probe_real)
                intermediate_wavefield_imag_ls.append(probe_imag)

            # The forward pass leaves remainder slices in its last bin, so the first backward
            # step handles that (possibly partial) bin and every later step is a full bin.
            this_step = remainder if (i_step == 0 and remainder != 0) else binning

            # ---------------- Sampling ----------------
            if repeating_slice is not None:
                # Every slice equals slice 0. Index it with a scalar so no extra axis is left over.
                delta_slice = grid_batch[:, :, :, 0, 0]
                if kappa is not None:
                    beta_slice = delta_slice * kappa
                else:
                    beta_slice = grid_batch[:, :, :, 0, 1]
            elif this_step > 1:
                delta_slice = grid_batch[:, :, :, i_slice - this_step:i_slice, 0]
                if kappa is not None:
                    # In sign = +1 convention, phase (delta) should be positive, and kappa is positive too.
                    beta_slice = delta_slice * kappa
                else:
                    beta_slice = grid_batch[:, :, :, i_slice - this_step:i_slice, 1]
            else:
                delta_slice = grid_batch[:, :, :, i_slice - 1, 0]
                if kappa is not None:
                    beta_slice = delta_slice * kappa
                else:
                    beta_slice = grid_batch[:, :, :, i_slice - 1, 1]

            t0 = time.time()
            # ---------------- Modulation ----------------
            if type == 'delta_beta':
                # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
                # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
                if this_step > 1:
                    if repeating_slice is not None:
                        # A bin of `this_step` identical slices sums to `this_step` times one slice.
                        delta_slice = delta_slice * this_step
                        beta_slice = beta_slice * this_step
                    else:
                        delta_slice = w.sum(delta_slice, axis=3)
                        beta_slice = w.sum(beta_slice, axis=3)
                # exp(-ikn*)
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, sign_convention * k1 * delta_slice)
            elif type == 'real_imag':
                raise NotImplementedError('Backprop not done for real_imag.')
            else:
                raise ValueError('unknown_type must be delta_beta or real_imag.')
            probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                      probe_real * c_imag + probe_imag * c_real)

            # ---------------- Backpropagation ----------------
            if i_step < n_steps - 1:
                # Every propagation in the forward pass spans a full bin (binning * dz), including
                # the one before its trailing partial bin. So a full bin is always undone here.
                probe_real, probe_imag = w.convolve_with_transfer_function(probe_real, probe_imag,
                                                                           h_real, h_imag)
            i_slice -= this_step
            t_tot += (time.time() - t0)

    if shift_exit_wave is not None:
        probe_real, probe_imag = realign_image_fourier(probe_real, probe_imag, shift_exit_wave,
                                                       axes=(1, 2), device=device)

    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(probe_real, probe_imag, axes=[1, 2],
                                                          normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(probe_real, probe_imag, axes=[1, 2],
                                                           normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            if optimize_free_prop:
                probe_real, probe_imag = fresnel_propagate_wrapped(
                    u_free, v_free, probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                    device=device, sign_convention=sign_convention)
            else:
                probe_real, probe_imag = fresnel_propagate(
                    probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                    device=device, sign_convention=sign_convention)

    return_ls = [probe_real, probe_imag]
    if return_fft_time:
        return_ls.append(t_tot)
    if return_intermediate_wavefields:
        return_ls = return_ls + [intermediate_wavefield_real_ls, intermediate_wavefield_imag_ls]
    return return_ls
def multislice_propagate_batch(grid_batch, probe_real, probe_imag, energy_ev, psize_cm, delta_cm=None,
                               free_prop_cm=None, obj_batch_shape=None, kernel=None, fresnel_approx=True,
                               pure_projection=False, binning=1, device=None, type='delta_beta',
                               normalize_fft=False, sign_convention=1, optimize_free_prop=False,
                               u_free=None, v_free=None, scale_ri_by_k=True, is_minus_logged=False,
                               pure_projection_return_sqrt=False, kappa=None, repeating_slice=None,
                               return_fft_time=False):
    """
    Propagate a batch of wavefields forward through a multislice object.

    Every slice multiplies the wavefield by its modulation. After every ``binning`` slices (but not
    after the final slice) the wavefield is convolved with the transfer function for
    ``binning * dz``. If ``pure_projection`` is set, the slices are instead combined (summed for
    delta/beta, multiplied for real/imag) and applied once. Finally the wave is optionally
    free-space propagated by ``free_prop_cm``.

    :param grid_batch: indexed as ``[batch, y, x, slice, channel]``. Channel 0/1 hold delta/beta
                       (``type='delta_beta'``) or the real/imaginary transmission (``type='real_imag'``).
    :param probe_real, probe_imag: real and imaginary parts of the incident wavefield.
    :param energy_ev: photon energy in eV. The wavelength is ``1240 / energy_ev`` nm.
    :param psize_cm: lateral voxel size in cm.
    :param delta_cm: slice thickness in cm. Defaults to ``psize_cm``.
    :param free_prop_cm: extra propagation distance in cm. ``'inf'`` means far field via (i)FFT;
                         ``0`` or ``None`` skips this step.
    :param obj_batch_shape: unused; kept for interface compatibility.
    :param kernel: precomputed transfer function used instead of ``get_kernel(binning * dz, ...)``.
    :param fresnel_approx: forwarded to ``get_kernel``.
    :param pure_projection: apply the combined slices once, without propagation.
    :param binning: number of slices per propagation step.
    :param device: device passed to backend helpers.
    :param type: ``'delta_beta'`` or ``'real_imag'``.
    :param normalize_fft: forwarded to the far-field FFT helpers.
    :param sign_convention: 1 for exp(ikz) (Goodman), -1 for exp(-ikz).
    :param optimize_free_prop, u_free, v_free: if ``optimize_free_prop`` is set, free propagation uses
                                               ``fresnel_propagate_wrapped(u_free, v_free, ...)``.
    :param scale_ri_by_k: scale delta/beta by ``2*pi*dz/lambda`` (otherwise by 1).
    :param is_minus_logged, pure_projection_return_sqrt: pure-projection output options (see comments).
    :param kappa: if given, beta is taken as ``kappa * delta``.
    :param repeating_slice: if given, slice 0 is applied this many times instead of using all slices.
    :param return_fft_time: also return the time spent in the slice loop (0 for pure projection).
    :return: ``(probe_real, probe_imag)``, or ``(probe_real, probe_imag, t_tot)`` if ``return_fft_time``.
    :raises ValueError: if ``type`` is unknown.
    """
    # Initialised up front so that return_fft_time also works with pure_projection. Previously
    # t_tot existed only in the multislice branch, which caused a NameError.
    t_tot = 0

    grid_shape = grid_batch.shape[1:-1]
    if delta_cm is not None:
        voxel_nm = np.array([psize_cm, psize_cm, delta_cm]) * 1.e7
    else:
        voxel_nm = np.array([psize_cm] * 3) * 1.e7
    lmbda_nm = 1240. / energy_ev
    delta_nm = voxel_nm[-1]
    n_slices = grid_batch.shape[-2] if repeating_slice is None else repeating_slice
    k1 = 2. * PI * delta_nm / lmbda_nm if scale_ri_by_k else 1.

    if pure_projection:
        if type == 'delta_beta':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            p = w.sum(grid_batch, axis=-2)
            delta_slice = p[:, :, :, 0]
            if kappa is not None:
                beta_slice = delta_slice * kappa
            else:
                beta_slice = p[:, :, :, 1]
            # In conventional tomography beta is interpreted as mu. If projection data is minus-logged,
            # the line sum of beta (mu) directly equals image intensity. If raw_data_type is set to 'intensity',
            # measured data will be taken square root at the loss calculation step. To match this, the summed
            # beta must be square-rooted as well. Otherwise, set raw_data_type to 'magnitude' to avoid square-rooting
            # the measured data, and skip sqrt to summed beta here accordingly.
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(beta_slice + 1e-10), delta_slice * 0
                else:
                    c_real, c_imag = beta_slice, delta_slice * 0
            else:
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
        elif type == 'real_imag':
            p = w.prod(grid_batch, axis=-2)
            c_real, c_imag = p[:, :, :, 0], p[:, :, :, 1]
            if is_minus_logged:
                if pure_projection_return_sqrt:
                    c_real, c_imag = w.sqrt(-w.log(c_real ** 2 + c_imag ** 2) + 1e-10), 0
                else:
                    c_real, c_imag = -w.log(c_real ** 2 + c_imag ** 2), 0
        else:
            raise ValueError('unknown_type must be real_imag or delta_beta.')
        probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                  probe_real * c_imag + probe_imag * c_real)
    else:
        if kernel is not None:
            h = kernel
        else:
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            h = get_kernel(delta_nm * binning, lmbda_nm, voxel_nm, grid_shape,
                           fresnel_approx=fresnel_approx, sign_convention=sign_convention)
        h_real = w.create_variable(np.real(h), requires_grad=False, device=device)
        h_imag = w.create_variable(np.imag(h), requires_grad=False, device=device)

        i_bin = 0
        for i in range(n_slices):
            # With repeating_slice, slice 0 is reused for every step.
            i_src = 0 if repeating_slice is not None else i
            delta_slice = grid_batch[:, :, :, i_src, 0]
            if kappa is not None:
                # In sign = +1 convention, phase (delta) should be positive, and kappa is positive too.
                beta_slice = delta_slice * kappa
            else:
                beta_slice = grid_batch[:, :, :, i_src, 1]

            t0 = time.time()
            if type == 'delta_beta':
                # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
                # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
                c_real, c_imag = w.exp_complex(-k1 * beta_slice, -sign_convention * k1 * delta_slice)
            elif type == 'real_imag':
                c_real, c_imag = delta_slice, beta_slice
            else:
                raise ValueError('unknown_type must be delta_beta or real_imag.')
            probe_real, probe_imag = (probe_real * c_real - probe_imag * c_imag,
                                      probe_real * c_imag + probe_imag * c_real)

            i_bin += 1
            # At the end of each full bin, but never after the final slice, propagate by binning * dz.
            # A trailing partial bin is always the last one, so it is never followed by propagation.
            if i_bin == binning:
                if i < n_slices - 1:
                    probe_real, probe_imag = w.convolve_with_transfer_function(probe_real, probe_imag,
                                                                               h_real, h_imag)
                i_bin = 0
            t_tot += (time.time() - t0)

    if free_prop_cm not in [0, None]:
        if isinstance(free_prop_cm, str) and free_prop_cm == 'inf':
            # Use sign_convention = 1 for Goodman convention: exp(ikz); n = 1 - delta + i * beta
            # Use sign_convention = -1 for opposite convention: exp(-ikz); n = 1 - delta - i * beta
            if sign_convention == 1:
                probe_real, probe_imag = w.fft2_and_shift(probe_real, probe_imag, axes=[1, 2],
                                                          normalize=normalize_fft)
            else:
                probe_real, probe_imag = w.ifft2_and_shift(probe_real, probe_imag, axes=[1, 2],
                                                           normalize=normalize_fft)
        else:
            dist_nm = free_prop_cm * 1e7
            if optimize_free_prop:
                probe_real, probe_imag = fresnel_propagate_wrapped(
                    u_free, v_free, probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                    device=device, sign_convention=sign_convention)
            else:
                probe_real, probe_imag = fresnel_propagate(
                    probe_real, probe_imag, dist_nm, lmbda_nm, voxel_nm,
                    device=device, sign_convention=sign_convention)

    if return_fft_time:
        return probe_real, probe_imag, t_tot
    return probe_real, probe_imag