def _checkPropGaussianFit(self) -> bool:
    """Experimental check that the defocus propagation regime is supported.

    To be used primarily for debugging. A Gaussian is fitted to the probe via
    ``self._calculateGaussianFit()``; twice the fitted FWHM is used as the
    maximum feature size passed to ``checkPropagationType`` together with the
    wavelength, ``self.defocus_dist`` and ``self.pixel_size``.

    Returns
    -------
    bool
        ``True`` when the returned propagation type is not one of the
        unsupported codes listed below.

    Raises
    ------
    ValueError
        If ``checkPropagationType`` returns -1 (x/y regimes differ), 0 (defocus
        too small for the transfer-function method) or 2 (defocus too large).
    """
    from ptychoSampling.wavefront import checkPropagationType

    # ``Logger.warn`` is a deprecated alias; ``Logger.warning`` is the supported name.
    logger.warning(
        "the check for propagation type is an experimental feature.")
    self._calculateGaussianFit()
    feature_size = np.array(self._gaussian_fwhm) * 2
    propagation_type = checkPropagationType(
        wavelength=self.wavelength,
        prop_dist=self.defocus_dist,
        source_pixel_size=self.pixel_size,
        max_feature_size=feature_size)

    # Propagation-type codes this probe cannot handle, with the reason for each.
    errors = {
        -1: "Propagation type is different along the x and y directions. This is not supported.",
        0: "Defocus distance is too small (Fresnel number too high) for transfer function method.",
        2: "Defocus distance too large. Only near field defocus supported."
    }
    if propagation_type in errors:
        e = ValueError(errors[propagation_type])
        logger.error(e)
        raise e
    return True
def _calculateDiffractionPatterns(self):
    """Simulate the near-field diffraction intensities for every scan position.

    For each ``(py, px)`` in ``self.scan_grid.positions_pix`` the object's
    ``bordered_array`` is multiplied into the ``[py:py + ny, px:px + nx]``
    window of the ifftshift-ed probe wavefront. The result is shifted back and
    propagated to the detector with ``propTF``, and its intensities are
    collected. ``self.intensities`` is set to the collected intensities, passed
    through ``np.random.poisson`` when ``self.poisson_noise`` is true.

    Raises
    ------
    ValueError
        If the scan grid uses subpixel positions.
    """
    wv = self.probe.wavefront
    intensities_all = []
    # Reset the cached transfer function; it is filled on the first scan position below.
    self._transfer_function = None
    ny, nx = self.obj.bordered_array.shape
    if self.scan_grid.subpixel_scan:
        e = ValueError("Subpixel scan not supported for nearfield.")
        logger.error(e)
        raise e
    for i, (py, px) in enumerate(self.scan_grid.positions_pix):
        # NOTE(review): presumably ``ifftshift`` returns a new shifted Wavefront, so the
        # in-place multiply below does not alter ``wv`` itself -- confirm in Wavefront.
        exit_wave = wv.ifftshift
        exit_wave[py:py + ny, px:px + nx] *= self.obj.bordered_array
        exit_wave = exit_wave.fftshift
        if self._transfer_function is None:
            # First position: allocate a buffer and pass it to propTF together with the
            # propagation distance (assumption: propTF fills it in place -- verify).
            self._transfer_function = np.zeros_like(exit_wave)
            det_wave = exit_wave.propTF(
                prop_dist=self.detector.obj_dist,
                transfer_function=self._transfer_function)
        else:
            # Later positions ask propTF to reuse the stored transfer function.
            det_wave = exit_wave.propTF(
                reuse_transfer_function=True,
                transfer_function=self._transfer_function)
        intensities_all.append(det_wave.intensities)
    self.intensities = np.random.poisson(
        intensities_all) if self.poisson_noise else np.array(
            intensities_all)
def _checkAttr(self, attr_to_check, attr_this):
    """Ensure a prerequisite attribute exists before attaching another component.

    Parameters
    ----------
    attr_to_check : str
        Name of the attribute that must already be present on ``self``.
    attr_this : str
        Name of the component being attached; only used in the error message.

    Raises
    ------
    AttributeError
        If ``self`` has no attribute called ``attr_to_check``.
    """
    if hasattr(self, attr_to_check):
        return
    missing = AttributeError(
        f"First attach a {attr_to_check} before attaching {attr_this}.")
    logger.error(missing)
    raise missing
def attachTensorflowOptimizerForVariable(
        self,
        variable_name: str,
        optimizer_type: str,
        optimizer_init_args: dict = None,
        optimizer_minimize_args: dict = None,
        initial_update_delay: int = 0,
        update_frequency: int = 1,
        checkpoint_frequency: int = 100):
    """Attach an optimizer for the specified variable.

    Parameters
    ----------
    variable_name : str
        Name (string) associated with the chosen variable (to be optimized) in the forward model.
    optimizer_type : str
        Name of the standard optimizer chosen from available options in options.Options.
    optimizer_init_args : dict
        Dictionary containing the key-value pairs required for the initialization of the desired
        optimizer.
    optimizer_minimize_args : dict
        Dictionary containing the key-value pairs required to define the minimize operation for
        the desired optimizer. The ``loss`` and ``var_list`` entries are always set here (to the
        training loss and the chosen variable). The caller's dict is copied, not modified.
    initial_update_delay : int
        Number of iterations to wait before the minimizer is first applied. Defaults to 0.
    update_frequency : int
        Number of iterations in between minimization calls. Defaults to 1.
    checkpoint_frequency : int
        Number of iterations between creation of checkpoints of the optimizer. Not implemented.

    Raises
    ------
    ValueError
        If ``variable_name`` is not one of ``self.fwd_model.model_vars`` (or, via
        ``_checkConfigProperty``, if ``optimizer_type`` is not a configured option).
    """
    optimization_all = ptychoSampling.reconstruction.options.OPTIONS[
        "tf_optimization_methods"]
    self._checkConfigProperty(optimization_all, optimizer_type)
    self._checkAttr("_train_loss_t", "optimizer")

    if variable_name not in self.fwd_model.model_vars:
        e = ValueError(
            f"{variable_name} is not a supported variable in {self.fwd_model}"
        )
        logger.error(e)
        raise e
    var = self.fwd_model.model_vars[variable_name]["variable"]

    if optimizer_minimize_args is None:
        optimizer_minimize_args = {}
    else:
        # Copy so the "loss"/"var_list" assignments below do not leak into the
        # caller's dict (reusing it for another variable would otherwise trigger
        # the warning below spuriously).
        optimizer_minimize_args = dict(optimizer_minimize_args)
        if ("loss" in optimizer_minimize_args) or ("var_list" in optimizer_minimize_args):
            warning = (
                "Target loss and optimization variable are assigned by default. " +
                "If custom processing is desired, use _attachCustomOptimizerForVariable directly."
            )
            logger.warning(warning)
    optimizer_minimize_args["loss"] = self._train_loss_t
    optimizer_minimize_args["var_list"] = [var]

    self._attachCustomOptimizerForVariable(
        optimization_all[optimizer_type], optimizer_init_args,
        optimizer_minimize_args, initial_update_delay, update_frequency)
def addCustomMetricToDataLog(self,
                             title: str,
                             tensor: tf.Tensor,
                             log_epoch_frequency: int = 1,
                             registration_ground_truth: np.ndarray = None,
                             registration: bool = True,
                             normalized_lse: bool = False):
    """Register an additional tensor to be tracked by ``self.datalog``.

    Parameters
    ----------
    title : str
        Name under which the metric is logged.
    tensor : tf.Tensor
        Tensor whose value is logged.
    log_epoch_frequency : int
        Logging frequency passed through to the data log. Defaults to 1.
    registration_ground_truth : np.ndarray, optional
        Ground-truth array. When ``None`` the tensor is logged as-is and the
        ``registration`` / ``normalized_lse`` flags are ignored.
    registration, normalized_lse : bool
        Metric type used when a ground truth is supplied; at most one may be true.

    Raises
    ------
    ValueError
        If a ground truth is supplied and both ``registration`` and
        ``normalized_lse`` are true.
    """
    if registration_ground_truth is None:
        self.datalog.addCustomTensorMetric(
            title=title,
            tensor=tensor,
            log_epoch_frequency=log_epoch_frequency)
        return

    if registration and normalized_lse:
        e = ValueError(
            "Only one of 'registration' or 'normalized lse' should be true."
        )
        logger.error(e)
        # Previously the error was only logged and execution continued with
        # both flags set; fail fast instead.
        raise e
    self.datalog.addCustomTensorMetric(
        title=title,
        tensor=tensor,
        registration=registration,
        normalized_lse=normalized_lse,
        log_epoch_frequency=log_epoch_frequency,
        true=registration_ground_truth)
def __init__(self,
             *args: Any,
             focal_length: Optional[float] = None,
             aperture_radius: Optional[float] = None,
             oversampling: bool = True,
             oversampling_npix: int = 1024,
             **kwargs: Any) -> None:
    """Create a circularly symmetric focused probe centred at the origin.

    Positional and remaining keyword arguments are forwarded to the parent
    ``__init__``. The keyword-only parameters are stored on the instance for
    ``self._calculateWavefront()``, which is called at the end.

    Parameters
    ----------
    focal_length, aperture_radius : float, optional
        Focusing optics parameters.
    oversampling : bool
        Whether to compute the focal field on an oversampled grid.
    oversampling_npix : int
        Minimum grid size used when ``oversampling`` is true.

    Raises
    ------
    ValueError
        If widths, pixel sizes or shape differ between the two axes, or if any
        component of ``center_dist`` is non-zero.
    """
    super().__init__(*args, **kwargs)
    if ((self.width_dist[0] != self.width_dist[1])
            or (self.pixel_size[0] != self.pixel_size[1])
            or (self.shape[0] != self.shape[1])):
        e = ValueError(
            "Only a circularly symmetric focussed probe is supported. Require equal x and y widths."
        )
        logger.error(e)
        raise e
    # Any non-zero centre component (negative values and the second axis
    # included) means the probe is not at the origin. Checking only
    # ``center_dist[0] > 0`` let such probes through.
    if np.any(np.asarray(self.center_dist) != 0):
        e = ValueError(
            "Only a centrally located focussed probe is supported (i.e. center at origin). If " +
            "necessary, we can simply customize the wavefront manually (pad and roll)."
        )
        logger.error(e)
        raise e
    self.focal_length = focal_length
    self.aperture_radius = aperture_radius
    self.oversampling = oversampling
    self.oversampling_npix = oversampling_npix
    self._calculateWavefront()
def _checkConfigProperty(options: dict, key_to_check: str):
    """Validate that ``key_to_check`` is one of the configured ``options``.

    Parameters
    ----------
    options : dict
        Mapping of supported option names.
    key_to_check : str
        Option name requested by the caller.

    Raises
    ------
    ValueError
        If ``key_to_check`` is not a key of ``options``.
    """
    if key_to_check in options:
        return
    unsupported = ValueError(
        f"{key_to_check} is not currently supported. " +
        f"Check if {key_to_check} exists as an option among {options} in options.py"
    )
    logger.error(unsupported)
    raise unsupported
def __post_init__(self):
    """Validate the registration setup and derive the logged column names.

    Sets ``self.columns`` to ``[title]`` for a plain metric, or to
    ``[title + '_err']`` for a registration metric.

    Raises
    ------
    ValueError
        If ``registration`` is true but no ``true`` data was supplied.
    """
    if self.registration:
        if self.true is None:
            e = ValueError("Require true data for registration.")
            logger.error(e)
            raise e
        # Only the '_err' suffix is produced; '_shift' / '_phase' columns are disabled.
        suffixes = ['_err']
        self.columns = [self.title + suffix for suffix in suffixes]
    else:
        self.columns = [self.title]
def _calculateWavefront(self) -> None:
    """Calculate the focused (Airy) probe and propagate it by ``defocus_dist``.

    A circular aperture of radius ``aperture_radius`` is sampled on an
    ``npix_oversampled`` x ``npix_oversampled`` grid. Its orthonormal 2D FFT
    gives the focal-plane field. The central ``self.shape`` region is cropped
    and scaled so that the full (uncropped) field carries ``self.n_photons``.
    The result is stored as ``self.wavefront`` and passed on via
    ``self._propagateWavefront()``.

    If ``self.width_dist[0] > 0``, the aperture radius is derived from that
    focal width with a nominal focal length of 10 m. Any stored
    ``focal_length`` / ``aperture_radius`` are then reset to ``None``.

    Raises
    ------
    ValueError
        If no focal width is given and either ``aperture_radius`` or
        ``focal_length`` is ``None``.
    """
    if self.width_dist[0] > 0:
        logger.warning(
            "Warning: if width at the focus is supplied, " +
            "any supplied focal_length and aperture radius are ignored.")
        self.aperture_radius = None
        self.focal_length = None
        # Assuming resolution equal to pixel pitch and focal length of 10.0 m.
        focal_length = 10.0
        focus_radius = self.width_dist[0] / 2
        # note that jinc(1,1) = 1.22 * np.pi
        aperture_radius = (self.wavelength * 1.22 * np.pi * focal_length /
                           2 / np.pi / focus_radius)
    else:
        if None in [self.aperture_radius, self.focal_length]:
            e = ValueError(
                'Either focus_radius_npix or BOTH aperture_radius and focal_length must be supplied.'
            )
            logger.error(e)
            raise e
        aperture_radius = self.aperture_radius
        focal_length = self.focal_length

    npix_oversampled = max(
        self.oversampling_npix,
        self.shape[-1]) if self.oversampling else self.shape[0]
    pixel_pitch_aperture = self.wavelength * focal_length / (
        npix_oversampled * self.pixel_size[0])

    x = np.arange(-npix_oversampled // 2,
                  npix_oversampled // 2) * pixel_pitch_aperture
    r = np.sqrt(x**2 + x[:, np.newaxis]**2).astype('float32')
    circ_wavefront = np.zeros(r.shape)
    circ_wavefront[r < aperture_radius] = 1.0
    circ_wavefront[r == aperture_radius] = 0.5

    probe_vals = np.fft.fftshift(
        np.fft.fft2(np.fft.fftshift(circ_wavefront), norm='ortho'))

    # Crop exactly ``self.shape`` pixels around the centre. Ending at
    # ``npix // 2 + shape // 2`` (as before) dropped a row/column for odd
    # shapes; for even shapes the window is unchanged.
    row_start = npix_oversampled // 2 - self.shape[0] // 2
    col_start = npix_oversampled // 2 - self.shape[1] // 2

    # Normalise using the power of the full (uncropped) focal field.
    scaling_factor = np.sqrt(self.n_photons / np.sum(np.abs(probe_vals)**2))
    wavefront_array = probe_vals[row_start:row_start + self.shape[0],
                                 col_start:col_start + self.shape[1]].astype(
                                     'complex64') * scaling_factor
    # NOTE(review): for odd shapes ``fftshift`` and ``ifftshift`` differ; confirm
    # which one matches the Wavefront storage convention.
    wavefront_array = np.fft.fftshift(wavefront_array)
    self.wavefront = Wavefront(wavefront_array,
                               wavelength=self.wavelength,
                               pixel_size=self.pixel_size)
    self._propagateWavefront()
def __init__(self,
             shape=(128, 128),
             border_shape=((32, 32), (32, 32)),
             border_const=1.0,
             mod_range=1.0,
             phase_range=np.pi):
    """Create a simulated 2D object.

    All arguments are forwarded unchanged to the parent ``__init__``, after
    which ``self._createObj()`` builds the object values.

    Raises
    ------
    ValueError
        If ``shape`` does not have exactly two entries.
    """
    is_two_dimensional = len(shape) == 2
    if not is_two_dimensional:
        shape_error = ValueError('Supplied shape is not 2d.')
        logger.error(shape_error)
        raise shape_error
    super().__init__(shape, border_shape, border_const, mod_range,
                     phase_range)
    self._createObj()
def _sanityChecks(self):
    """Check that probe, detector, object and scan grid parameters agree.

    Compares:

    * probe npix vs detector npix times ``upsampling_factor`` (``probe_det``)
    * probe vs object pixel size (``probe_obj``)
    * object vs scan-grid pixel size (``obj_grid``)
    * probe npix vs scan-grid probe npix (``probe_grid``)

    Raises
    ------
    ValueError
        Naming the first mismatching pair, in the order above.
    """
    # ``np.array_equal`` yields one bool for scalars, tuples and arrays alike.
    # A bare ``==`` on numpy pixel sizes returns an element-wise array whose
    # truth value is ambiguous. ``np.asarray`` makes the upsampling multiply
    # element-wise instead of repeating a tuple.
    sanity_checks = {
        'probe_det': np.array_equal(
            self._probe.npix,
            np.asarray(self._detector.npix) * self.upsampling_factor),
        'probe_obj': np.array_equal(self._probe.pixel_size,
                                    self._obj.pixel_size),
        'obj_grid': np.array_equal(self._obj.pixel_size,
                                   self._scan_grid.obj_pixel_size),
        'probe_grid': np.array_equal(self._probe.npix,
                                     self._scan_grid.probe_npix)
    }
    for s, v in sanity_checks.items():
        if not v:
            e = ValueError(f"Mismatch in supplied {s} parameters.")
            logger.error(e)
            raise e
def __post_init__(self) -> None:
    """Validate the background level and attach any additional parameters.

    The background level must be combinable with ``intensity_patterns`` by
    numpy broadcasting (checked by actually adding them). Every key/value in
    ``additional_params`` is then set as an attribute via
    ``object.__setattr__``.

    Raises
    ------
    ValueError
        If adding ``background_intensity_level`` to the intensity patterns fails;
        the original exception is chained as the cause.
    """
    try:
        np.array(self.intensity_patterns) + self.background_intensity_level
    except Exception as broadcast_err:
        wrapped = ValueError(
            'Invalid format for background intensity level. Should be compatible to the ' +
            'supplied intensity_patterns parameter value through a numpy broadcasting operation.'
        )
        logger.error([broadcast_err, wrapped])
        raise wrapped from broadcast_err

    # ``object.__setattr__`` is used so this also works when ``__setattr__`` is blocked.
    for name, value in self.additional_params.items():
        object.__setattr__(self, name, value)
def _register(test: np.ndarray, true: np.ndarray) -> float:
    """Return the subpixel registration error between ``test`` and ``true``.

    Uses ``_register_translation_2d`` or ``_register_translation_3d``
    (upsample factor 10) according to the number of dimensions of ``test``.
    The first call's phase output is used to remove a global phase from
    ``test``, and the error from a second call on the phase-corrected array is
    returned.

    Raises
    ------
    ValueError
        If ``test`` is neither 2D nor 3D.
    """
    ndim = len(test.shape)
    if ndim == 3:
        register = _register_translation_3d
    elif ndim == 2:
        register = _register_translation_2d
    else:
        dim_error = ValueError(
            "Subpixel registration only available for 2d and 3d objects.")
        logger.error(dim_error)
        raise dim_error

    # The registration function returns (shift, error, phase); only the phase is kept here.
    _, _, phase = register(test, true, upsample_factor=10)
    phase_corrected = test * np.exp(-1j * phase)
    _, err, _ = register(phase_corrected, true, upsample_factor=10)
    return err
def checkOverlap(self, overlap_ratio: float = 0.5) -> bool:
    """Verify that neighbouring scan positions overlap sufficiently.

    For every scan position, the smallest non-zero y and x distances (in
    pixels) to the other positions must not exceed ``overlap_ratio`` times the
    probe shape. For a full-field probe, the bordered object shape is used
    instead. An axis along which no other position differs (a single
    position, row or column) places no constraint on that axis.

    Parameters
    ----------
    overlap_ratio : float
        Fraction of the probe (or object) extent used as the maximum allowed step.

    Returns
    -------
    bool
        ``True`` if the overlap is sufficient.

    Raises
    ------
    ValueError
        If some position's nearest neighbour along y or x is too far away.
    """
    reference_shape = (self.obj_w_border_shape
                       if self.full_field_probe else self.probe_shape)
    overlap_ny = overlap_ratio * reference_shape[0]
    overlap_nx = overlap_ratio * reference_shape[1]

    # ``positions_pix`` may be a plain list; convert once so subtraction is element-wise.
    positions = np.asarray(self.positions_pix)
    for p in positions:
        differences = np.abs(positions - p)
        ydiffs = differences[:, 0]
        xdiffs = differences[:, 1]
        ydiffs_nonzero = ydiffs[ydiffs > 0]
        xdiffs_nonzero = xdiffs[xdiffs > 0]
        # np.min raises on an empty selection; no neighbour along an axis means
        # nothing to check along that axis.
        too_far_x = xdiffs_nonzero.size > 0 and np.min(xdiffs_nonzero) > overlap_nx
        too_far_y = ydiffs_nonzero.size > 0 and np.min(ydiffs_nonzero) > overlap_ny
        if too_far_x or too_far_y:
            e = ValueError(
                "Insufficient overlap between adjacent scan positions.")
            logger.error(e)
            raise e
    return True
def __init__(self,
             mesh_shape: Tuple[int, int, int] = (128, 128, 128),
             mod_const: float = 0.5,
             border_shape=((0, 0), (0, 0), (0, 0)),
             border_const=0.0,
             pixel_size=None):
    """Create a 3D object whose values are scaled by a constant modulus.

    Stores the border and pixel-size settings, sets ``mod_range`` and
    ``phase_range`` to ``None``, calls ``self._createObj()`` and finally
    multiplies ``self.array`` in place by ``mod_const``.

    Raises
    ------
    ValueError
        If ``mesh_shape`` does not have exactly three entries.
    """
    if len(mesh_shape) != 3:
        shape_error = ValueError('Supplied shape is not 3d.')
        logger.error(shape_error)
        raise shape_error

    # NOTE(review): ``mesh_shape`` is validated but never stored on ``self``;
    # confirm how ``_createObj`` obtains the object shape.
    self.mod_constant = mod_const
    self.border_const = border_const
    self._border_shape = border_shape
    self.pixel_size = pixel_size
    self.mod_range, self.phase_range = None, None

    self._createObj()
    self.array *= self.mod_constant
def __init__(self,
             obj: Obj,
             probe: Probe3D,
             scan_grid: BraggPtychoGrid,
             exit_wave_axis: str = "y",
             upsampling_factor: int = 1,
             setup_second_order: bool = False,
             dtype: str = "float32"):
    """Bragg ptychography forward model for scans on the yz-plane.

    After validating the geometry, initialisation is delegated to the parent
    class (``exit_wave_axis`` is not forwarded). Two constants are then created
    on "/gpu:0": the stacked probe phase modulations (complex64) and the
    scan grid's ``full_rc_positions_indices`` (int64).

    Raises
    ------
    ValueError
        For a full-field probe, a scan not on the ("y", "z") axes, or an
        ``exit_wave_axis`` other than "y".
    """
    def _reject(message):
        # Log and raise the same ValueError instance.
        err = ValueError(message)
        logger.error(err)
        raise err

    if scan_grid.full_field_probe:
        _reject("Full field probe not supported for Bragg ptychography.")
    if scan_grid.grid2d_axes != ("y", "z"):
        _reject(
            "Only supports the case where the ptychographic scan is on the yz-plane."
        )
    if exit_wave_axis != 'y':
        _reject(
            "Only supports the case where the exit waves are output along the y-direction."
        )

    super().__init__(obj, probe, scan_grid, upsampling_factor,
                     setup_second_order, dtype)

    logger.info("Creating the phase modulations for the scan angles.")
    with tf.device("/gpu:0"):
        phase_modulations = self._getProbePhaseModulationsStack()
        self._probe_phase_modulations_all_t = tf.constant(
            phase_modulations, dtype='complex64')
        self._full_rc_positions_indices_t = tf.constant(
            scan_grid.full_rc_positions_indices, dtype='int64')
def _setObjArrayValues(self, values: Optional[np.ndarray] = None) -> None:
    """Set the obj transmission function and add the border.

    When ``values`` is ``None`` an array of zeros of shape ``self.shape`` is
    used. The values are padded with ``np.pad`` using ``self._border_shape``
    and the constant ``self.border_const``. The result is stored as
    ``self.bordered_array``. ``self._array`` is set to the interior (unpadded)
    region.

    Parameters
    ----------
    values : array_like, optional
        Obj array values. Defaults to zeros of shape ``self.shape``.

    Raises
    ------
    ValueError
        "Error in input obj shape." if ``np.zeros(self.shape)`` fails, or
        "Error in border specifications." if padding fails. The underlying
        exception is chained as the cause.
    """
    if values is None:
        try:
            values = np.zeros(self.shape)
        except Exception as shape_err:
            wrapped = ValueError("Error in input obj shape.")
            logger.error([shape_err, wrapped])
            raise wrapped from shape_err

    try:
        self.bordered_array = np.pad(values,
                                     self._border_shape,
                                     mode='constant',
                                     constant_values=self.border_const)
    except Exception as pad_err:
        wrapped = ValueError("Error in border specifications.")
        logger.error([pad_err, wrapped])
        raise wrapped from pad_err

    # Basic slicing returns a view: writes through ``_array`` show up in
    # ``bordered_array`` and vice versa.
    interior = tuple(
        slice(pad[0], pad[0] + self.shape[axis])
        for axis, pad in enumerate(self._border_shape))
    self._array = self.bordered_array[interior]
def __init__(self,
             obj_w_border_shape: Tuple[int, int],
             probe_shape: Tuple[int, int],
             obj_pixel_size: Tuple[float, float],
             step_dist: Tuple[float, float],
             subpixel_scan: bool = False,
             scan_grid_boundary_pix: np.ndarray = None,
             full_field_probe: bool = False):
    """Store the scan-grid geometry and initialise empty position lists.

    If ``scan_grid_boundary_pix`` is not supplied, it defaults to
    ``[[0, ny], [0, nx]]`` of the bordered object. For a full-field probe the
    probe shape is used instead. A supplied boundary is only checked to be
    reshapeable to (2, 2) and is stored as given.

    Raises
    ------
    ValueError
        If a supplied ``scan_grid_boundary_pix`` cannot be reshaped to (2, 2).
    """
    self.obj_w_border_shape = obj_w_border_shape
    self.obj_pixel_size = obj_pixel_size
    self.probe_shape = probe_shape
    self.step_dist = np.array(step_dist)
    self.subpixel_scan = subpixel_scan
    self.full_field_probe = full_field_probe

    if scan_grid_boundary_pix is not None:
        try:
            np.reshape(scan_grid_boundary_pix, (2, 2))
        except Exception as reshape_err:
            wrapped = ValueError(
                "scan_grid_boundary_pix should contain integers and have shape [[y_min, y_max], " +
                "[x_min, x_max]]")
            logger.error([reshape_err, wrapped])
            raise wrapped from reshape_err
        self.scan_grid_boundary_pix = scan_grid_boundary_pix
    else:
        extent = probe_shape if self.full_field_probe else self.obj_w_border_shape
        self.scan_grid_boundary_pix = np.array([[0, extent[0]],
                                                [0, extent[1]]])

    self.positions_pix = []
    self.positions_subpix = []
    self.positions_dist = []
def __init__(self,
             *args: int,
             loss_type: str = "least_squared",
             obj_array_true: np.ndarray = None,
             probe_wavefront_true: np.ndarray = None,
             update_delay_probe: int = 0,
             update_delay_obj: int = 0,
             reconstruct_probe: bool = True,
             registration_log_frequency: int = 10,
             opt_init_extra_kwargs: dict = None,
             obj_abs_proj: bool = True,
             loss_init_extra_kwargs: dict = None,
             r_factor_log: bool = True,
             apply_precond: bool = False,
             **kwargs: int):
    """Set up a joint object/probe reconstruction with ``ConjugateGradientOptimizer``.

    Attaches ``JointFarfieldForwardModelT``, a second-order loss function and a
    single ``ConjugateGradientOptimizer`` acting on ``self.fwd_model.joint_v``.
    Line-search metrics are logged every epoch. Object/probe error metrics
    are logged when ground truths are given, plus an R-factor log.

    Parameters
    ----------
    loss_type : str
        Passed to ``attachLossFunctionSecondOrder``.
    obj_array_true, probe_wavefront_true : np.ndarray, optional
        Ground truths used for the "obj_error" / "probe_error" metrics.
    update_delay_probe, update_delay_obj : int
        NOTE(review): not read anywhere in this method -- confirm whether they are meant to be used.
    reconstruct_probe : bool
        In this method it only gates the "probe_error" metric.
    registration_log_frequency : int
        Logging frequency for the error metrics and the R-factor log.
    opt_init_extra_kwargs : dict, optional
        Extra entries merged (last, so they override) into the optimizer init arguments.
    obj_abs_proj : bool
        Passed to the forward model.
    loss_init_extra_kwargs : dict, optional
        Passed to ``attachLossFunctionSecondOrder``.
    r_factor_log : bool
        Passed to ``_addRFactorLog``.
    apply_precond : bool
        If true, build a diagonal preconditioner from the object/probe scalings.

    Raises
    ------
    ValueError
        If ``self.training_batch_size != self.n_train`` (minibatches unsupported).
    """
    logger.info('initializing...')
    super().__init__(*args, **kwargs)
    if self.training_batch_size != self.n_train:
        e = ValueError(
            "Conjugate gradient reconstruction does not support minibatch reconstruction."
        )
        logger.error(e)
        raise e

    logger.info('attaching fwd model...')
    self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                   obj_abs_proj=obj_abs_proj)

    logger.info('creating loss fn...')
    self.attachLossFunctionSecondOrder(
        loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

    logger.info("creating optimizers...")
    if opt_init_extra_kwargs is None:
        opt_init_extra_kwargs = {}

    # Stays None (no preconditioning) unless apply_precond is set.
    self._preconditioner = None
    if apply_precond:
        with self.graph.as_default():
            loss_data_type = self._loss_method.data_type
            # Broadcast the loss Hessian to the prediction shape, then reshape to
            # a stack of probe-shaped arrays.
            hessian_t = (
                tf.ones_like(self._batch_train_predictions_t) *
                self._train_hessian_fn(self._batch_train_predictions_t))
            hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])
            obj_scaling = self._getObjScaling(loss_data_type, hessian_t)
            probe_scaling = self._getProbeScaling(loss_data_type, hessian_t)
            scaling_both = tf.concat((obj_scaling, probe_scaling), axis=0)
            # NOTE(review): duplicated presumably to cover both real and imaginary
            # parts of the joint variable -- confirm against JointFarfieldForwardModelT.
            scaling_both = tf.concat((scaling_both, scaling_both), axis=0)
            # Entries below 1e-10 * max are treated as zero and replaced with
            # 1 / (1e-8 * max); all others become 1 / scaling.
            zero_condition = tf.less(scaling_both,
                                     1e-10 * tf.reduce_max(scaling_both))
            zero_case = tf.ones_like(scaling_both) / (
                1e-8 * tf.reduce_max(scaling_both))
            self._preconditioner = tf.where(zero_condition, zero_case,
                                            1 / scaling_both)

    opt_init_args = {
        "input_var": self.fwd_model.joint_v,
        "predictions_fn": self._predictions_fn,
        "loss_fn": self._train_loss_fn,
        "name": "opt",
        "diag_precondition_t": self._preconditioner
    }
    opt_init_args.update(opt_init_extra_kwargs)
    self._attachCustomOptimizerForVariable(
        ConjugateGradientOptimizer, optimizer_init_args=opt_init_args)

    # Per-epoch line-search diagnostics from the attached optimizer.
    self.addCustomMetricToDataLog(
        title="ls_iters",
        tensor=self.optimizers[0]._optimizer._linesearch_steps,
        log_epoch_frequency=1)
    self.addCustomMetricToDataLog(
        title="alpha",
        tensor=self.optimizers[0]._optimizer._linesearch._alpha,
        log_epoch_frequency=1)

    if obj_array_true is not None:
        self.addCustomMetricToDataLog(
            title="obj_error",
            tensor=self.fwd_model.obj_cmplx_t,
            log_epoch_frequency=registration_log_frequency,
            registration_ground_truth=obj_array_true)
    if reconstruct_probe and (probe_wavefront_true is not None):
        self.addCustomMetricToDataLog(
            title="probe_error",
            tensor=self.fwd_model.probe_cmplx_t,
            log_epoch_frequency=registration_log_frequency,
            registration_ground_truth=probe_wavefront_true)
    self._addRFactorLog(r_factor_log, registration_log_frequency)
def __init__(self,
             *args: int,
             loss_type: str = "least_squared",
             obj_array_true: np.ndarray = None,
             probe_wavefront_true: np.ndarray = None,
             max_cg_iter: int = 100,
             min_cg_tol: float = 0.1,
             registration_log_frequency: int = 10,
             both_registration_nlse: bool = True,
             opt_init_extra_kwargs: dict = None,
             obj_abs_proj: bool = True,
             obj_abs_max: float = 1.0,
             probe_abs_max: float = None,
             loss_init_extra_kwargs: dict = None,
             r_factor_log: bool = True,
             update_delay_probe: int = 0,
             reconstruct_probe: bool = True,
             apply_diag_mu_scaling: bool = True,
             apply_precond: bool = False,
             apply_precond_and_scaling: bool = False,
             stochastic_diag_estimator_type: str = None,
             stochastic_diag_estimator_iters: int = 1,
             **kwargs: int):
    """Set up a joint object/probe Levenberg-Marquardt (LMA) reconstruction.

    Optimizer selection:

    * neither ``apply_diag_mu_scaling`` nor ``apply_precond``: ``LMAOptimizer``
    * only ``apply_diag_mu_scaling``: ``ScaledLMAOptimizer``
    * only ``apply_precond``: ``PCGLMAOptimizer``
    * both (or ``apply_precond_and_scaling``): ``ScaledPCGLMAOptimizer``

    CG/line-search/projection iteration counts and ``mu`` are logged every
    epoch. Object/probe registration metrics are added when ground truths are
    given, plus an R-factor log.

    Parameters
    ----------
    opt_init_extra_kwargs : dict, optional
        Extra optimizer init arguments, merged last. The caller's dict is
        copied, not modified.
    update_delay_probe, reconstruct_probe
        NOTE(review): not read anywhere in this method -- confirm intent.

    Raises
    ------
    ValueError
        If a stochastic diagonal estimator is combined with analytical
        scaling/preconditioning, or if minibatches are requested.
    """
    logger.info('initializing...')
    super().__init__(*args, **kwargs)

    if apply_precond_and_scaling:
        logger.info(
            "Overriding any set values for apply_precond and apply_diag_mu_scaling"
        )
        apply_precond = True
        apply_diag_mu_scaling = True

    if stochastic_diag_estimator_type is not None:
        if apply_diag_mu_scaling or apply_precond:
            e = ValueError(
                "Cannot use analytical precond and diag ggn elems if stochastic calculation is enabled."
            )
            logger.error(e)
            raise e

    if self.training_batch_size != self.n_train:
        e = ValueError(
            "LMA reconstruction does not support minibatch reconstruction."
        )
        logger.error(e)
        raise e

    logger.info('attaching fwd model...')
    self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                   obj_abs_proj=obj_abs_proj,
                                   obj_abs_max=obj_abs_max,
                                   probe_abs_max=probe_abs_max)

    logger.info('creating loss fn...')
    # Debug output goes through the module logger instead of stdout.
    logger.debug('Loss init args: %s', loss_init_extra_kwargs)
    self.attachLossFunctionSecondOrder(
        loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

    logger.info("creating optimizer...")
    # Copy so the scaling/preconditioner tensors added below do not leak into
    # the caller's dict.
    opt_init_extra_kwargs = ({} if opt_init_extra_kwargs is None else
                             dict(opt_init_extra_kwargs))

    if apply_diag_mu_scaling or apply_precond:
        with self.graph.as_default():
            loss_data_type = self._loss_method.data_type
            hessian_t = (
                tf.ones_like(self._batch_train_predictions_t) *
                self._train_hessian_fn(self._batch_train_predictions_t))
            hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])
            logger.debug('Hessian tensor: %s', hessian_t)
            self._obj_scaling = self._getObjScaling(loss_data_type, hessian_t)
            self._probe_scaling = self._getProbeScaling(
                loss_data_type, hessian_t)
            scaling_both = tf.concat((self._obj_scaling, self._probe_scaling),
                                     axis=0)
            if apply_diag_mu_scaling:
                self._joint_scaling_t = tf.concat(
                    (scaling_both, scaling_both), axis=0)
                optimizer = ScaledLMAOptimizer
                opt_init_extra_kwargs['diag_mu_scaling_t'] = self._joint_scaling_t
            if apply_precond:
                self._joint_precond_t = tf.concat(
                    (scaling_both, scaling_both), axis=0)
                optimizer = PCGLMAOptimizer
                opt_init_extra_kwargs['diag_precond_t'] = self._joint_precond_t
                if apply_diag_mu_scaling:
                    optimizer = ScaledPCGLMAOptimizer
    else:
        # Plain LMA when neither analytical scaling nor preconditioning is requested;
        # without this branch ``optimizer`` would be unbound below.
        optimizer = LMAOptimizer

    opt_init_args = {
        "input_var": self.fwd_model.joint_v,
        "predictions_fn": self._predictions_fn,
        "loss_fn": self._train_loss_fn,
        "diag_hessian_fn": self._train_hessian_fn,
        "name": "opt",
        "max_cg_iter": max_cg_iter,
        "min_cg_tol": min_cg_tol,
        "stochastic_diag_estimator_type": stochastic_diag_estimator_type,
        "stochastic_diag_estimator_iters": stochastic_diag_estimator_iters,
        "assert_tolerances": False
    }
    opt_init_args.update(opt_init_extra_kwargs)
    self._attachCustomOptimizerForVariable(
        optimizer, optimizer_init_args=opt_init_args)

    # Per-epoch optimizer diagnostics.
    self.addCustomMetricToDataLog(
        title="cg_iters",
        tensor=self.optimizers[0]._optimizer._total_cg_iterations,
        log_epoch_frequency=1)
    self.addCustomMetricToDataLog(
        title="ls_iters",
        tensor=self.optimizers[0]._optimizer._total_proj_ls_iterations,
        log_epoch_frequency=1)
    self.addCustomMetricToDataLog(
        title="proj_iters",
        tensor=self.optimizers[0]._optimizer._projected_gradient_iterations,
        log_epoch_frequency=1)
    self.addCustomMetricToDataLog(title="mu",
                                  tensor=self.optimizers[0]._optimizer._mu,
                                  log_epoch_frequency=1)

    self._registration_log_frequency = registration_log_frequency
    if obj_array_true is not None:
        self._setObjRegistration(
            obj_array_true, both_registration_nlse=both_registration_nlse)
    if probe_wavefront_true is not None:
        self._setProbeRegistration(
            probe_wavefront_true,
            both_registration_nlse=both_registration_nlse)
    self._addRFactorLog(r_factor_log, registration_log_frequency)
def __init__(self,
             wavelength: float = 1.5e-10,
             obj: Obj = None,
             probe_3d: Probe3D = None,
             scan_grid: BraggPtychoGrid = None,
             detector: Detector = None,
             poisson_noise: bool = True,
             upsampling_factor: int = 1) -> None:
    """Simulate a Bragg ptychography dataset.

    Each of ``detector``, ``obj``, ``probe_3d`` and ``scan_grid`` is used as
    supplied, or created from the corresponding default ``*Params``
    dataclass. Object pixel sizes are derived from the detector geometry. The
    object may be zero-padded along y so that all probe positions fit after
    rotation. Finally ``self._calculateDiffractionPatterns()`` is called.

    Parameters
    ----------
    wavelength : float
        Wavelength; stored and used for the object pixel size and probe.
    obj, probe_3d, scan_grid, detector : optional
        Pre-built components; ``None`` means "create a default one".
    poisson_noise : bool
        Stored on the instance.
    upsampling_factor : int
        Multiplies the detector shape to give the probe xz shape.

    Raises
    ------
    ValueError
        If a supplied ``probe_3d`` disagrees with the derived xz shape, the
        wavelength or the object pixel size.
    """
    self.wavelength = wavelength
    self.upsampling_factor = upsampling_factor
    self.poisson_noise = poisson_noise
    # Scattering angle from the supplied grid, otherwise from the default angle parameters.
    two_theta = scan_grid.two_theta if scan_grid is not None else AngleGridParams(
    ).two_theta
    if detector is not None:
        logger.info("Using supplied detector info.")
        self.detector = detector
    else:
        logger.info("Creating new detector.")
        self.detector = Detector(**dt.asdict(DetectorParams()))
    det_3d_shape = (1, *self.detector.shape)
    # Support per detector pixel: wavelength * obj_dist / detector pixel size.
    obj_xz_nyquist_support = self.wavelength * self.detector.obj_dist / np.array(
        self.detector.pixel_size)
    obj_xz_pixel_size = obj_xz_nyquist_support / (
        np.array(self.detector.shape) * self.upsampling_factor)
    # The y pixel size is very ad-hoc
    obj_y_pixel_size = obj_xz_nyquist_support[1] / self.detector.shape[1]
    obj_pixel_size = (obj_y_pixel_size, *obj_xz_pixel_size)
    if obj is not None:
        logger.info("Using supplied object.")
        self.obj = obj
    else:
        self.obj = self.createObj(dt.asdict(ObjParams()), det_3d_shape,
                                  obj_pixel_size, upsampling_factor)
    probe_xz_shape = np.array(self.detector.shape) * self.upsampling_factor
    if probe_3d is not None:
        logger.info("Using supplied 3d probe values.")
        self.probe_3d = probe_3d
        # Note that the probe y shape cannot be determined from the other supplied parameters (I think).
        if (np.any(probe_3d.shape[1:] != probe_xz_shape) or
            (probe_3d.wavelength != self.wavelength) or
                np.any(probe_3d.pixel_size != obj_pixel_size)):
            e = ValueError(
                f"Mismatch between the supplied probe and the supplied scan parameters."
            )
            logger.error(e)
            raise e
    else:
        self.probe_3d = self.createProbe3D(dt.asdict(Probe2DParams()),
                                           wavelength, np.pi / 2 - two_theta,
                                           probe_xz_shape, obj_pixel_size)
    rotate_angle = np.pi / 2 - two_theta
    # Floor (``// 1``) of the probe y extent projected by cos(rotate_angle).
    probe_y_pix_before_rotation = self.probe_3d.shape[0] * np.cos(
        rotate_angle) // 1
    pady, bordered_obj_ypix = self.calculateObjBorderAfterRotation(
        rotate_angle, probe_y_pix_before_rotation, self.obj.shape[0],
        self.obj.shape[1])
    if self.obj.bordered_array.shape[0] < bordered_obj_ypix:
        logger.warning(
            "Adding zero padding to the object in the y-direction so that the overall object y-width "
            + "covers the entirety of the feasible probe positions.")
        # Replace only the y border; the other two axes keep their existing borders.
        self.obj.border_shape = ((pady, pady), self.obj.border_shape[1],
                                 self.obj.border_shape[2])
    if scan_grid is not None:
        logger.info("Using supplied scan grid.")
        self.scan_grid = scan_grid
    else:
        # Pick elements 0 and 2 (the y and z entries) of a 3-tuple.
        elems_yz = lambda tup: np.array(tup)[[0, 2]]
        scan_grid_params_dict = dt.asdict(
            Scan2DGridParams(obj_pixel_size=elems_yz(self.obj.pixel_size)))
        self.scan_grid = self.createScanGrid(
            scan_grid_params_dict, dt.asdict(AngleGridParams()),
            elems_yz(self.obj.bordered_array.shape),
            elems_yz(self.probe_3d.shape))
    self._calculateDiffractionPatterns()
def __init__(self,
             wavelength: float = 1.5e-10,
             obj: Obj = None,
             obj_params: dict = None,
             probe: Probe = None,
             probe_params: dict = None,
             scan_grid: ScanGrid = None,
             scan_grid_params: dict = None,
             detector: Detector = None,
             detector_params: dict = None,
             poisson_noise: bool = True,
             upsampling_factor: int = 1) -> None:
    """Simulate a ptychography dataset with a Gaussian speckled probe.

    Each of ``detector``, ``obj``, ``probe`` and ``scan_grid`` is used as
    supplied, or built from the matching ``*_params`` dict (``None`` means
    defaults). A supplied component makes its ``*_params`` argument ignored.
    The scan grid's overlap is checked, then
    ``self._calculateDiffractionPatterns()`` is called.

    Raises
    ------
    ValueError
        If a supplied object's pixel size, or a supplied probe's shape /
        wavelength / pixel size, disagrees with the derived values.
    """
    self.wavelength = wavelength
    self.upsampling_factor = upsampling_factor
    self.poisson_noise = poisson_noise

    # Compare against None rather than relying on the objects' truthiness.
    if any(item is not None for item in (obj, probe, scan_grid, detector)):
        logger.warning(
            "If one (or all) of obj, probe, scan_grid, or detector is supplied, " +
            "then the corresponding _params parameter is ignored.")

    if detector is not None:
        self.detector = detector
        self._detector_params = {}
    else:
        detector_params = {} if detector_params is None else detector_params
        self._detector_params = DetectorParams(**detector_params)
        self.detector = Detector(**dt.asdict(self._detector_params))

    # NOTE(review): the object pixel size is the detector pixel size *times* the
    # upsampling factor; confirm the intended sampling when upsampling_factor != 1.
    obj_pixel_size = np.array(
        self.detector.pixel_size) * self.upsampling_factor
    probe_shape = np.array(self.detector.shape) * self.upsampling_factor

    if obj is not None:
        if obj.pixel_size is not None and np.any(
                obj.pixel_size != obj_pixel_size):
            e = ValueError(
                "Mismatch between the provided pixel size and the pixel size calculated from scan " +
                "parameters.")
            logger.error(e)
            raise e
        obj.pixel_size = obj_pixel_size
        # Store the supplied object; the scan-grid construction and the
        # diffraction simulation below read ``self.obj``.
        self.obj = obj
        self._obj_params = None
    else:
        obj_params = {} if obj_params is None else obj_params
        self._obj_params = ObjParams(**obj_params)
        self.obj = Simulated2DObj(**dt.asdict(self._obj_params))

    if probe is not None:
        check = (np.any(probe.shape != probe_shape) or
                 (probe.wavelength != self.wavelength) or
                 np.any(probe.pixel_size != obj_pixel_size))
        if check:
            e = ValueError(
                "Supplied probe parameters do not match with supplied scan and detector parameters."
            )
            logger.error(e)
            raise e
        self.probe = probe
        self._probe_params = None
    else:
        probe_params = {} if probe_params is None else probe_params
        self._probe_params = ProbeParams(**probe_params)
        self.probe = GaussianSpeckledProbe(wavelength=wavelength,
                                           pixel_size=obj_pixel_size,
                                           shape=probe_shape,
                                           **dt.asdict(self._probe_params))

    if scan_grid is not None:
        self.scan_grid = scan_grid
        self._scan_grid_params = None
    else:
        scan_grid_params = {} if scan_grid_params is None else scan_grid_params
        self._scan_grid_params = GridParams(**scan_grid_params)
        self.scan_grid = RectangleGrid(
            obj_w_border_shape=self.obj.bordered_array.shape,
            probe_shape=self.probe.shape,
            obj_pixel_size=obj_pixel_size,
            **dt.asdict(self._scan_grid_params))

    self.scan_grid.checkOverlap()
    self._calculateDiffractionPatterns()
def __init__(self,
             wavelength: float = 1.5e-10,
             obj: Obj = None,
             obj_params: dict = None,
             probe: Probe = None,
             probe_params: dict = None,
             scan_grid: ScanGrid = None,
             scan_grid_params: dict = None,
             detector: Detector = None,
             detector_params: dict = None,
             poisson_noise: bool = True,
             random_shift_pix: Tuple[int, int] = None,
             upsampling_factor: int = 1,
             background_scaling: float = 1e-8,
             background_constant: float = 0) -> None:
    """Simulate a far-field ptychography dataset with a focused circular probe.

    Each of ``detector``, ``obj``, ``probe`` and ``scan_grid`` is used as
    supplied, or built from the matching ``*_params`` dict (``None`` means
    defaults). The object pixel size is
    ``wavelength * obj_dist / (detector support * upsampling_factor)``. For an
    internally created scan grid, ``random_shift_pix`` (if given) is applied via
    ``_updateGridWithRandomPixShifts``. The grid overlap is checked, then
    ``self._calculateDiffractionPatterns()`` is called.

    Raises
    ------
    ValueError
        If a supplied object's pixel size, or a supplied probe's shape /
        wavelength / pixel size, disagrees with the derived values.
    """
    self.wavelength = wavelength
    self.upsampling_factor = upsampling_factor
    self.poisson_noise = poisson_noise
    self.background_scaling = background_scaling
    self.background_constant = background_constant

    # Compare against None rather than relying on the objects' truthiness.
    if any(item is not None for item in (obj, probe, scan_grid, detector)):
        logger.warning(
            "If one (or all) of obj, probe, scan_grid, or detector is supplied, " +
            "then the corresponding _params parameter is ignored.")

    if detector is not None:
        self.detector = detector
        self._detector_params = {}
    else:
        detector_params = {} if detector_params is None else detector_params
        self._detector_params = DetectorParams(**detector_params)
        self.detector = Detector(**dt.asdict(self._detector_params))

    detector_support_size = np.asarray(
        self.detector.pixel_size) * self.detector.shape
    obj_pixel_size = self.wavelength * self.detector.obj_dist / (
        detector_support_size * self.upsampling_factor)
    probe_shape = np.array(self.detector.shape) * self.upsampling_factor

    if obj is not None:
        if obj.pixel_size is not None and np.any(
                obj.pixel_size != obj_pixel_size):
            e = ValueError(
                "Mismatch between the provided pixel size and the pixel size calculated from scan " +
                "parameters.")
            logger.error(e)
            raise e
        obj.pixel_size = obj_pixel_size
        # Store the supplied object; the scan-grid construction and the
        # diffraction simulation below read ``self.obj``.
        self.obj = obj
        self._obj_params = None
    else:
        obj_params = {} if obj_params is None else obj_params
        self._obj_params = ObjParams(**obj_params)
        self.obj = Simulated2DObj(**dt.asdict(self._obj_params))

    if probe is not None:
        check = (np.any(probe.shape != probe_shape) or
                 (probe.wavelength != self.wavelength) or
                 np.any(probe.pixel_size != obj_pixel_size))
        if check:
            e = ValueError(
                "Supplied probe parameters do not match with supplied scan and detector parameters."
            )
            logger.error(e)
            raise e
        self.probe = probe
        self._probe_params = None
    else:
        probe_params = {} if probe_params is None else probe_params
        self._probe_params = ProbeParams(**probe_params)
        self.probe = FocusCircularProbe(wavelength=wavelength,
                                        pixel_size=obj_pixel_size,
                                        shape=probe_shape,
                                        **dt.asdict(self._probe_params))

    if scan_grid is not None:
        self.scan_grid = scan_grid
        self._scan_grid_params = None
    else:
        scan_grid_params = {} if scan_grid_params is None else scan_grid_params
        self._scan_grid_params = GridParams(**scan_grid_params)
        self.scan_grid = RectangleGrid(
            obj_w_border_shape=self.obj.bordered_array.shape,
            probe_shape=self.probe.shape,
            obj_pixel_size=obj_pixel_size,
            **dt.asdict(self._scan_grid_params))
        # NOTE(review): random pixel shifts are only applied to internally
        # created grids; confirm this is intended for user-supplied grids too.
        if random_shift_pix is not None:
            self.scan_grid = self._updateGridWithRandomPixShifts(
                self.scan_grid, random_shift_pix)

    self.scan_grid.checkOverlap()
    self._calculateDiffractionPatterns()
def __init__(self,
             *args: int,
             obj_array_true: np.ndarray = None,
             probe_wavefront_true: np.ndarray = None,
             obj_abs_proj: bool = True,
             update_delay_probe: int = 0,
             update_delay_obj: int = 0,
             update_delay_aux: int = 0,
             update_frequency_probe: int = 1,
             update_frequency_obj: int = 1,
             update_frequency_aux: int = 1,
             reconstruct_probe: bool = True,
             registration_log_frequency: int = 10,
             r_factor_log: bool = True,
             loss_type: str = "gaussian",
             learning_rate_scaling_probe: float = 1.0,
             learning_rate_scaling_obj: float = 1.0,
             loss_init_extra_kwargs: dict = None,
             **kwargs: int):
    """Set up a PALM reconstruction with auxiliary, object and probe updates.

    Attaches ``FarfieldPALMForwardModel`` and ``AuxLossFunctionT``, then three
    optimizers: an ``AuxAssignOptimizer`` that assigns ``self._getAuxUpdate()``
    to the auxiliary variable, and "gradient" optimizers for "obj" and (if
    ``reconstruct_probe``) "probe". The learning rates come from
    ``_getObjLearningRate`` / ``_getProbeLearningRate`` scaled by the
    ``learning_rate_scaling_*`` arguments.

    Parameters
    ----------
    update_delay_* / update_frequency_* : int
        Initial delay and frequency for the aux, obj and probe optimizers.
        ``None`` delays for obj and probe are treated as 0.
    obj_array_true, probe_wavefront_true : np.ndarray, optional
        Ground truths for the "obj_error" / "probe_error" metrics, logged
        every ``registration_log_frequency`` epochs.
    loss_type : str
        Must be "gaussian".

    Raises
    ------
    ValueError
        If minibatches are requested or ``loss_type`` is not "gaussian".
    """
    logger.info('initializing...')
    super().__init__(*args, **kwargs)
    if self.training_batch_size != self.n_train:
        e = ValueError("PALM reconstruction does not support minibatch reconstruction.")
        logger.error(e)
        raise e
    if loss_type != "gaussian":
        e = ValueError("PALM reconstruction does not support other loss functions.")
        logger.error(e)
        raise e

    logger.info('attaching fwd model...')
    self._attachCustomForwardModel(FarfieldPALMForwardModel, obj_abs_proj=obj_abs_proj)

    logger.info('creating loss fn...')
    self._attachCustomLossFunction(AuxLossFunctionT, loss_init_extra_kwargs)

    logger.info("create learning rates")
    self._learning_rate_scaling_probe = learning_rate_scaling_probe
    self._learning_rate_scaling_obj = learning_rate_scaling_obj
    self._lr_obj = self._getObjLearningRate() * learning_rate_scaling_obj
    self._lr_probe = self._getProbeLearningRate() * learning_rate_scaling_probe

    logger.info('creating optimizers...')
    aux_new_t = self._getAuxUpdate()
    self._attachCustomOptimizerForVariable(AuxAssignOptimizer,
                                           optimizer_init_args={"aux_v": self.fwd_model.aux_v,
                                                                "aux_new_t": aux_new_t},
                                           initial_update_delay=update_delay_aux,
                                           update_frequency=update_frequency_aux)

    # A None delay is treated as "no delay".
    update_delay_obj = update_delay_obj if update_delay_obj is not None else 0
    self.attachTensorflowOptimizerForVariable("obj",
                                              optimizer_type="gradient",
                                              optimizer_init_args={"learning_rate": self._lr_obj},
                                              initial_update_delay=update_delay_obj,
                                              update_frequency=update_frequency_obj)

    update_delay_probe = update_delay_probe if update_delay_probe is not None else 0
    if reconstruct_probe:
        self.attachTensorflowOptimizerForVariable("probe",
                                                  optimizer_type="gradient",
                                                  optimizer_init_args={"learning_rate": self._lr_probe},
                                                  initial_update_delay=update_delay_probe,
                                                  update_frequency=update_frequency_probe)

    if obj_array_true is not None:
        self.addCustomMetricToDataLog(title="obj_error",
                                      tensor=self.fwd_model.obj_cmplx_t,
                                      log_epoch_frequency=registration_log_frequency,
                                      registration_ground_truth=obj_array_true)
    if reconstruct_probe and (probe_wavefront_true is not None):
        self.addCustomMetricToDataLog(title="probe_error",
                                      tensor=self.fwd_model.probe_cmplx_t,
                                      log_epoch_frequency=registration_log_frequency,
                                      registration_ground_truth=probe_wavefront_true)
    # NOTE(review): the R-factor log frequency is hard-coded to 1 here rather than
    # ``registration_log_frequency`` -- confirm this is intentional.
    self._addRFactorLog(r_factor_log, 1)