def multidistance_ctf_wrapped(this_prj_batch, free_prop_cm, energy_ev, psize_cm,
                              kappa=50, safe_zone_width=0, prj_affine_ls=None,
                              device=None):
    """Retrieve a phase map from multi-distance images with a CTF filter.

    Each image ``i`` of ``this_prj_batch`` is paired with propagation distance
    ``free_prop_cm[i]``. The images are optionally affine-registered, padded by
    ``safe_zone_width`` pixels (edge mode), and Fourier transformed after
    subtracting 1. With ``xi = PI * lambda * z * (u**2 + v**2)`` and filter
    ``c = sin(xi) + cos(xi) / kappa``, the result is::

        ifft2( sum_i c_i * FT(I_i - 1) / (sum_i 2 * c_i**2 + 1e-10) )

    Only the real part is kept, with the safe zone cropped off.

    Parameters
    ----------
    this_prj_batch : array
        3D stack of images (image index first). One image per distance is
        expected, since image ``i`` is paired with distance ``i``.
    free_prop_cm : array, sequence or scalar
        Propagation distances in cm. Python lists/tuples/scalars are converted
        to a 1-D float array; arrays/tensors are used as given.
    energy_ev : float
        Photon energy in eV. Wavelength in nm is approximated as
        ``1240 / energy_ev``.
    psize_cm : float
        Pixel size in cm (converted to nm for the frequency mesh).
    kappa : float
        Weight of the ``cos`` term (``1 / kappa``). Presumably the
        delta/beta ratio -- confirm with callers.
    safe_zone_width : int
        Padding added on every side before filtering, cropped off afterwards.
    prj_affine_ls : sequence, optional
        Per-image affine matrices passed to ``w.affine_transform``.
    device : optional
        Device for the created tensors.

    Returns
    -------
    2D array
        Real part of the filtered inverse transform (the retrieved phase).
    """
    # Accept plain Python sequences/scalars: ``list * 1e7`` would raise and a
    # bare float has no ``len``. Arrays/tensors are left untouched so device
    # placement is preserved.
    if isinstance(free_prop_cm, (int, float, list, tuple)):
        free_prop_cm = np.atleast_1d(np.asarray(free_prop_cm, dtype=float))

    # Frequency mesh sized for the padded images.
    u_free, v_free = gen_freq_mesh(
        np.array([psize_cm * 1e7] * 3),
        [this_prj_batch.shape[i + 1] + 2 * safe_zone_width for i in range(2)])
    u_free = w.create_variable(u_free, requires_grad=False, device=device)
    v_free = w.create_variable(v_free, requires_grad=False, device=device)

    this_prj_batch = w.create_variable(this_prj_batch, requires_grad=False, device=device)
    if prj_affine_ls is not None:
        for i in range(len(prj_affine_ls)):
            this_prj_batch[i] = w.affine_transform(this_prj_batch[i:i + 1], prj_affine_ls[i])
    if safe_zone_width > 0:
        this_prj_batch = w.pad(
            this_prj_batch,
            [(0, 0), (safe_zone_width, safe_zone_width), (safe_zone_width, safe_zone_width)],
            mode='edge')

    # Transform of the contrast (I - 1); the imaginary input is all zeros.
    prj_ft_r, prj_ft_i = w.fft2(
        this_prj_batch - 1,
        w.zeros_like(this_prj_batch, requires_grad=False, device=device),
        normalize=True)

    dist_nm_ls = free_prop_cm * 1e7
    lmbda_nm = 1240. / energy_ev
    # Loop-invariant squared spatial frequency.
    freq_sq = u_free ** 2 + v_free ** 2

    # Numerator terms and denominator terms share the same per-distance filter,
    # so compute it once per distance.
    num_real_ls = []
    num_imag_ls = []
    osc_ls = []
    for i in range(len(dist_nm_ls)):
        xi = PI * lmbda_nm * dist_nm_ls[i] * freq_sq
        ctf = w.sin(xi) + 1. / kappa * w.cos(xi)
        num_real_ls.append(ctf * prj_ft_r[i])
        num_imag_ls.append(ctf * prj_ft_i[i])
        osc_ls.append(2 * ctf ** 2)

    num_real = w.sum(w.stack(num_real_ls), axis=0)
    num_imag = w.sum(w.stack(num_imag_ls), axis=0)
    # Small constant keeps the division finite where all filters vanish.
    osc = w.sum(w.stack(osc_ls), axis=0) + 1e-10

    phase, _ = w.ifft2(num_real / osc, num_imag / osc, normalize=True)
    return phase[safe_zone_width:phase.shape[0] - safe_zone_width,
                 safe_zone_width:phase.shape[1] - safe_zone_width]
def update_parameter_gradients(opt_ls, grads):
    """Accumulate freshly computed gradients into each optimizer's ``grads``.

    Optimizers named ``'obj'`` are left alone. The ``'probe'`` optimizer
    receives ``grads[1]`` and ``grads[2]`` stacked along a new last axis.
    Every other optimizer receives ``grads[opt.index_in_grad_returns]``.

    Returns ``opt_ls`` (the same list, whose optimizers were updated in place).
    """
    for optimizer in opt_ls:
        name = optimizer.name
        if name == 'probe':
            optimizer.grads += w.stack(grads[1:3], axis=-1)
        elif name != 'obj':
            optimizer.grads += grads[optimizer.index_in_grad_returns]
    return opt_ls
def save_param_arrays_to_checkpoint(self):
    """Write all optimizable parameter arrays into one checkpoint ``.npy`` file.

    The checkpoint directory ``<self.output_folder>/checkpoint`` is always
    created via ``create_directory_multirank``. If ``self.params_list`` is
    non-empty, each named array is looked up in
    ``self.params_whole_array_dict``. The arrays are stacked along a new
    leading axis in ``params_list`` order and saved as
    ``opt_params_checkpoint.npy``. Returns ``None``.
    """
    checkpoint_dir = os.path.join(self.output_folder, 'checkpoint')
    create_directory_multirank(checkpoint_dir)
    if len(self.params_list) == 0:
        return
    stacked = w.stack([self.params_whole_array_dict[name] for name in self.params_list])
    np.save(os.path.join(checkpoint_dir, 'opt_params_checkpoint.npy'), w.to_numpy(stacked))
def save_distributed_param_arrays_to_checkpoint(self):
    """Write this rank's optimizable parameter arrays to a per-rank checkpoint.

    The directory ``<self.output_folder>/checkpoint`` is created if needed. If
    ``self.params_list`` is non-empty, each named array is looked up in
    ``self.params_whole_array_dict``. The arrays are stacked along a new
    leading axis in ``params_list`` order and saved as
    ``opt_params_checkpoint_rank_<rank>.npy``, where ``rank`` is a
    module-level name. Returns ``None``.
    """
    path = os.path.join(self.output_folder, 'checkpoint')
    # The filename embeds ``rank``, so this is presumably called by several
    # ranks at once. ``exist_ok`` removes the exists-then-create race in which
    # another rank creates the directory in between and ``makedirs`` raises
    # FileExistsError.
    os.makedirs(path, exist_ok=True)
    if len(self.params_list) > 0:
        arr = w.stack([self.params_whole_array_dict[name] for name in self.params_list])
        np.save(
            os.path.join(path, 'opt_params_checkpoint_rank_{}.npy'.format(rank)),
            w.to_numpy(arr))
def apply_finite_support_mask_to_array(self, mask, unknown_type='delta_beta', device=None):
    """Apply a finite-support mask to ``self.arr`` in place.

    ``self.arr`` is indexed as ``[:, :, :, c]``, with channel 0 and channel 1
    treated as the two parts of the unknown.

    - ``'delta_beta'``: both channels are multiplied by ``mask.mask``.
    - ``'real_imag'``: inside the mask the values are kept. Outside it,
      channel 0 is filled with 1 and channel 1 with 0 (via ``1 - mask.mask``,
      so mask values are presumably in {0, 1} -- confirm).

    The masked array is rebuilt without gradient tracking and then
    re-attached with ``w.reattach``.

    Parameters
    ----------
    mask : Mask
        Support mask; its ``.mask`` array is used.
    unknown_type : {'delta_beta', 'real_imag'}
        Parameterization of ``self.arr``.
    device : optional
        Device for the fill arrays (used only for ``'real_imag'``).

    Raises
    ------
    ValueError
        If ``unknown_type`` is not one of the supported values. Previously
        this surfaced as an UnboundLocalError on ``delta``.
    """
    assert isinstance(mask, Mask)
    if unknown_type not in ('delta_beta', 'real_imag'):
        raise ValueError(
            "unknown_type must be 'delta_beta' or 'real_imag', got {!r}".format(unknown_type))
    with w.no_grad():
        if unknown_type == 'delta_beta':
            delta = self.arr[:, :, :, 0] * mask.mask
            beta = self.arr[:, :, :, 1] * mask.mask
        else:  # 'real_imag'
            ones_arr = w.ones(self.arr.shape[:-1], requires_grad=False, device=device)
            zeros_arr = w.zeros(self.arr.shape[:-1], requires_grad=False, device=device)
            # Outside the support the unknown is reset to 1 + 0j.
            delta = self.arr[:, :, :, 0] * mask.mask + ones_arr * (1 - mask.mask)
            beta = self.arr[:, :, :, 1] * mask.mask + zeros_arr * (1 - mask.mask)
        self.arr = w.stack([delta, beta], -1)
    w.reattach(self.arr)
def alt_reconstruction_epie(obj_real, obj_imag, probe_real, probe_imag, probe_pos,
                            probe_pos_correction, prj, device_obj=None,
                            minibatch_size=1, alpha=1., n_epochs=100, **kwargs):
    """
    Reconstruct a 2D object and probe function using ePIE.

    Scan positions are visited in order and grouped into minibatches of up to
    ``minibatch_size``; the last group of each epoch may be smaller and is
    always processed within that epoch. For each group:

    1. The exit waves are formed from slice 0 of the object patches and the
       (optionally shifted) probe.
    2. Their Fourier magnitudes are replaced by the measured ones.
    3. The object patches and the probe are updated with the ePIE rules,
       using step size ``alpha``.

    After every epoch, the magnitude and phase of the padded object are
    written as TIFFs into ``output_folder``.

    Parameters
    ----------
    obj_real, obj_imag : array
        Object parts, stacked along a new axis 3. Presumably (y, x, slice) --
        TODO confirm. Only slice 0 enters the forward model.
    probe_real, probe_imag : array
        Probe parts. Only element ``[0]`` is used (presumably the first probe
        mode -- confirm).
    probe_pos : numpy array
        Scan positions, cast to int. Columns 0 and 1 index axes 0 and 1 of the
        object.
    probe_pos_correction : array
        ``[0, j]`` is the shift applied to the probe for position ``j``.
        Shifts are applied only if some entry has magnitude > 1e-3.
    prj : array
        Measured data, read as ``prj[0, start:end, :, :]``.
    device_obj : optional
        Device forwarded to ``w`` helpers.
    minibatch_size : int
        Maximum number of positions per update.
    alpha : float
        Step size for both the object and the probe updates.
    n_epochs : int
        Number of passes over all positions.
    **kwargs
        Must contain ``output_folder`` and ``raw_data_type``. If
        ``raw_data_type == 'intensity'``, the data are square-rooted before
        magnitude replacement.
    """
    output_folder = kwargs['output_folder']
    raw_data_type = kwargs['raw_data_type']
    # Guards the magnitude-replacement division where the modelled amplitude is 0.
    eps = 1e-10

    with w.no_grad():
        probe_real = probe_real[0]
        probe_imag = probe_imag[0]
        probe_pos = probe_pos.astype(int)
        this_obj_size = obj_real.shape
        probe_size = probe_real.shape

        obj_stack = w.stack([obj_real, obj_imag], axis=3)
        # Pad if needed
        obj_stack, pad_arr = pad_object(obj_stack, this_obj_size, probe_pos, probe_size,
                                        unknown_type='real_imag')
        # Positions in the padded object. Computed once into a new array so
        # that the padding offset is never accumulated into ``probe_pos``.
        pad_offset = np.array([pad_arr[0, 0], pad_arr[1, 0]], dtype=int)
        padded_pos = probe_pos + pad_offset

        n_pos = len(probe_pos)
        # Test magnitude (not just the positive side) so negative corrections
        # are applied too. ``probe_pos_correction`` is constant in this
        # function, so the check is done once.
        apply_pos_correction = len(w.nonzero(probe_pos_correction ** 2 > 1e-6)) > 0

        i_batch = 0
        subobj_ls = []
        probe_real_ls = []
        probe_imag_ls = []
        for i_epoch in range(n_epochs):
            for j in range(n_pos):
                print('Batch {}/{}; Epoch {}/{}.'.format(
                    j, n_pos, i_epoch, n_epochs))
                py, px = padded_pos[j]
                subobj_ls.append(
                    obj_stack[py:py + probe_size[0], px:px + probe_size[1], :, :])
                if apply_pos_correction:
                    this_shift = probe_pos_correction[0, j]
                    probe_real_shifted, probe_imag_shifted = realign_image_fourier(
                        probe_real, probe_imag, this_shift, axes=(0, 1), device=device_obj)
                else:
                    probe_real_shifted = probe_real
                    probe_imag_shifted = probe_imag
                probe_real_ls.append(probe_real_shifted)
                probe_imag_ls.append(probe_imag_shifted)
                i_batch += 1

                # Keep accumulating until the minibatch is full, but always
                # flush at the last position so nothing leaks into the next epoch.
                if i_batch < minibatch_size and j < n_pos - 1:
                    continue

                # Positions j - i_batch + 1 .. j form the current minibatch.
                batch_start = j - i_batch + 1
                batch_end = j + 1

                this_prj_batch = prj[0, batch_start:batch_end, :, :]
                this_prj_batch = w.create_variable(this_prj_batch, requires_grad=False,
                                                   device=device_obj)
                if raw_data_type == 'intensity':
                    this_prj_batch = w.sqrt(this_prj_batch)

                subobj_batch = w.stack(subobj_ls)
                probe_real_batch = w.stack(probe_real_ls)
                probe_imag_batch = w.stack(probe_imag_ls)

                # Exit wave = probe * object (slice 0), complex product.
                c_real = subobj_batch[:, :, :, 0, 0]
                c_imag = subobj_batch[:, :, :, 0, 1]
                ex_real = probe_real_batch * c_real - probe_imag_batch * c_imag
                ex_imag = probe_real_batch * c_imag + probe_imag_batch * c_real

                # Replace the modelled Fourier magnitude by the measured one.
                dp_real, dp_imag = w.fft2_and_shift(ex_real, ex_imag)
                mag_replace_factor = this_prj_batch / (
                    w.sqrt(dp_real ** 2 + dp_imag ** 2) + eps)
                dp_real = dp_real * mag_replace_factor
                dp_imag = dp_imag * mag_replace_factor
                phi_real, phi_imag = w.ishift_and_ifft2(dp_real, dp_imag)
                d_real = phi_real - ex_real
                d_imag = phi_imag - ex_imag

                # Object update: conj(P) * d / max|P|^2.
                norm = w.max(probe_real_batch ** 2 + probe_imag_batch ** 2)
                o_up_real = (probe_real_batch * d_real + probe_imag_batch * d_imag) / norm
                o_up_imag = (probe_real_batch * d_imag - probe_imag_batch * d_real) / norm
                o_up = w.stack([o_up_real, o_up_imag], axis=-1)
                o_up = w.reshape(o_up, [i_batch, probe_size[0], probe_size[1], 1, 2])
                subobj_batch = subobj_batch + alpha * o_up

                # Probe update: conj(O) * d / max|O|^2, averaged over the batch.
                # The already-updated object is used, as in the original code.
                o_real = subobj_batch[:, :, :, 0, 0]
                o_imag = subobj_batch[:, :, :, 0, 1]
                norm = w.max(o_real ** 2 + o_imag ** 2)
                p_up_real = (o_real * d_real + o_imag * d_imag) / norm
                p_up_imag = (o_real * d_imag - o_imag * d_real) / norm
                # Real and imaginary parts are updated separately. Adding the
                # stacked (H, W, 1, 2) tensor to both parts would mismatch
                # shapes and mix the components.
                probe_real = probe_real + alpha * w.mean(p_up_real, axis=0)
                probe_imag = probe_imag + alpha * w.mean(p_up_imag, axis=0)

                # Put back.
                for i, k in enumerate(range(batch_start, batch_end)):
                    py, px = padded_pos[k]
                    obj_stack[py:py + probe_size[0], px:px + probe_size[1], :, :] = subobj_batch[i]
                i_batch = 0
                subobj_ls = []
                probe_real_ls = []
                probe_imag_ls = []

            fname0 = 'obj_mag_{}_{}'.format(i_epoch, i_batch)
            fname1 = 'obj_phase_{}_{}'.format(i_epoch, i_batch)
            obj0, obj1 = w.split_channel(obj_stack)
            obj0 = w.to_numpy(obj0)
            obj1 = w.to_numpy(obj1)
            dxchange.write_tiff(np.sqrt(obj0 ** 2 + obj1 ** 2),
                                os.path.join(output_folder, fname0),
                                dtype='float32', overwrite=True)
            dxchange.write_tiff(np.arctan2(obj1, obj0),
                                os.path.join(output_folder, fname1),
                                dtype='float32', overwrite=True)
def update_parameters(opt_ls, optimizable_params, kwargs):
    """Apply each optimizer's accumulated gradient to its parameter.

    All keys below are read from ``kwargs`` up front; a missing key raises
    KeyError. The global step is ``i_batch + i_epoch * n_batch``.

    - ``'obj'`` optimizers are skipped.
    - ``'probe'``: updated only when
      ``probe_update_delay <= step < probe_update_limit``. A limit of ``None``
      means no upper bound. The gradients are all-reduced with ``comm``, and
      the step is applied to the stacked real/imag probe, which is split back
      afterwards. ``opt.grads`` is deleted after use.
    - Any other optimizer: updated only when
      ``step >= other_params_update_delay``. The gradients are all-reduced and
      the step is applied to ``optimizable_params[opt.name]``. Extra
      post-processing by name:

      - ``probe_pos_correction``: subtract the mean over all but the last
        axis.
      - ``slice_pos_cm_ls``: subtract element 0.
      - ``prj_affine_ls``: reset matrix 0 to the identity ``[[1,0,0],[0,1,0]]``.

    Updated parameters are re-attached with ``w.reattach``. Skipped updates
    print a message through ``print_flush``. Returns ``optimizable_params``
    (the same dict, modified in place).
    """
    i_epoch = kwargs['i_epoch']
    i_batch = kwargs['i_batch']
    n_batch = kwargs['n_batch']
    other_params_update_delay = kwargs['other_params_update_delay']
    probe_update_delay = kwargs['probe_update_delay']
    probe_update_limit = kwargs['probe_update_limit']
    i_full_angle = kwargs['i_full_angle']
    stdout_options = kwargs['stdout_options']
    if probe_update_limit is None:
        probe_update_limit = np.inf

    global_step = i_batch + i_epoch * n_batch

    for opt in opt_ls:
        if opt.name == 'obj':
            continue

        if opt.name == 'probe':
            if not (probe_update_delay <= global_step < probe_update_limit):
                print_flush(
                    'Probe is not updated because current batch is out of the specified range ({}, {}).'
                    .format(probe_update_delay, probe_update_limit), 0, rank, **stdout_options)
                continue
            with w.no_grad():
                opt.grads = comm.allreduce(opt.grads)
                stacked_probe = w.stack([optimizable_params['probe_real'],
                                         optimizable_params['probe_imag']], axis=-1)
                probe_temp = opt.apply_gradient(stacked_probe, opt.grads, i_full_angle,
                                                **opt.options_dict)
                (optimizable_params['probe_real'],
                 optimizable_params['probe_imag']) = w.split_channel(probe_temp)
                del opt.grads, probe_temp
            w.reattach(optimizable_params['probe_real'])
            w.reattach(optimizable_params['probe_imag'])
            continue

        if global_step < other_params_update_delay:
            print_flush(
                'Params are not updated because current epoch is smaller than specified delay ({}).'
                .format(other_params_update_delay), 0, rank, **stdout_options)
            continue

        with w.no_grad():
            opt.grads = comm.allreduce(opt.grads)
            updated = opt.apply_gradient(optimizable_params[opt.name], opt.grads,
                                         i_full_angle, **opt.options_dict)
            if opt.name == 'probe_pos_correction':
                # Prevent position drifting: remove the common-mode shift.
                mean_axes = tuple(range(len(updated.shape) - 1))
                updated = updated - w.mean(updated, axis=mean_axes)
            elif opt.name == 'slice_pos_cm_ls':
                # Prevent position drifting: pin the first slice at 0.
                updated = updated - updated[0]
            elif opt.name == 'prj_affine_ls':
                # Regularize transformation of image 0 (force identity).
                for r, row in enumerate(((1., 0., 0.), (0., 1., 0.))):
                    for c, value in enumerate(row):
                        updated[0, r, c] = value
            optimizable_params[opt.name] = updated
        w.reattach(optimizable_params[opt.name])

    return optimizable_params