def getFlopsPerIter(self) -> int:
    """Estimate the floating-point operations of a single optimization iteration.

    When no ``session`` attribute exists yet, a TensorFlow session is created on
    ``self.graph`` (with ``gpu_options.allow_growth`` enabled) and the global
    variables are initialized. Otherwise a warning is logged. Every optimizer's
    ``minimize_op`` is then run once, and the graph's cost is queried.

    Returns
    -------
    int
        The value returned by ``getComputationalCostInFlops(self.graph)``.
    """
    if self.n_validation > 0:
        logger.warning(RuntimeWarning("n_validation > 0 can give misleading flops."))

    if not hasattr(self, "session"):
        with self.graph.as_default():
            session_config = tf.ConfigProto()
            session_config.gpu_options.allow_growth = True
            logger.info("Initializing the session.")
            self.session = tf.Session(config=session_config)
            self.session.run(tf.global_variables_initializer())
    else:
        # Per the warning text, a previously initialized graph may include
        # inessential ops that inflate the count.
        logger.warning(RuntimeWarning(
            "Calculating computational cost after previously initializing the static graph can "
            + "include inessential calculations (like training and/or validation loss value) "
            + "and thus give misleading results."))

    # Executed in both branches: with a pre-existing session this performs one
    # extra optimizer step on the current state.
    for optimizer in self.optimizers:
        self.session.run(optimizer.minimize_op)
    return getComputationalCostInFlops(self.graph)
def attachTensorflowOptimizerForVariable(
        self,
        variable_name: str,
        optimizer_type: str,
        optimizer_init_args: dict = None,
        optimizer_minimize_args: dict = None,
        initial_update_delay: int = 0,
        update_frequency: int = 1,
        checkpoint_frequency: int = 100):
    """Attach an optimizer for the specified variable.

    Parameters
    ----------
    variable_name : str
        Name (string) associated with the chosen variable (to be optimized) in the forward model.
    optimizer_type : str
        Name of the standard optimizer chosen from available options in options.Options.
    optimizer_init_args : dict
        Dictionary containing the key-value pairs required for the initialization of the
        desired optimizer. Passed through unchanged.
    optimizer_minimize_args : dict
        Dictionary containing the key-value pairs required to define the minimize operation for
        the desired optimizer. The caller's dictionary is not modified. Its entries are copied,
        and ``"loss"`` / ``"var_list"`` are always set to ``self._train_loss_t`` and
        ``[variable]`` in the copy (a warning is logged if the caller supplied either key).
    initial_update_delay : int
        Number of iterations to wait before the minimizer is first applied. Defaults to 0.
    update_frequency : int
        Number of iterations in between minimization calls. Defaults to 1.
    checkpoint_frequency : int
        Number of iterations between creation of checkpoints of the optimizer. Not implemented
        (this method does not use it).

    Raises
    ------
    ValueError
        If ``variable_name`` is not a key of ``self.fwd_model.model_vars``.
    """
    optimization_all = ptychoSampling.reconstruction.options.OPTIONS["tf_optimization_methods"]
    self._checkConfigProperty(optimization_all, optimizer_type)
    self._checkAttr("_train_loss_t", "optimizer")

    if variable_name not in self.fwd_model.model_vars:
        e = ValueError(f"{variable_name} is not a supported variable in {self.fwd_model}")
        logger.error(e)
        raise e
    var = self.fwd_model.model_vars[variable_name]["variable"]

    # Copy so the assignments below never leak into the dictionary the caller passed in.
    optimizer_minimize_args = {} if optimizer_minimize_args is None else dict(optimizer_minimize_args)
    if ("loss" in optimizer_minimize_args) or ("var_list" in optimizer_minimize_args):
        warning = (
            "Target loss and optimization variable are assigned by default. "
            + "If custom processing is desired, use _attachCustomOptimizerForVariable directly.")
        logger.warning(warning)
    optimizer_minimize_args["loss"] = self._train_loss_t
    optimizer_minimize_args["var_list"] = [var]

    self._attachCustomOptimizerForVariable(optimization_all[optimizer_type],
                                           optimizer_init_args,
                                           optimizer_minimize_args,
                                           initial_update_delay,
                                           update_frequency)
def createObj(obj_params_dict: dict,
              det_3d_shape: Tuple[int, int, int],
              obj_pixel_size: Tuple[float, float, float],
              upsampling_factor: float = 1.0) -> Obj:
    """Generate a simulated 3D crystal cell and pad it to the upsampled detector x-width.

    A ``Simulated3DCrystalCell`` is regenerated until its bordered array's axis-1 width
    fits within ``det_3d_shape[1] * upsampling_factor`` pixels. The array is then zero-padded
    along axis 1 so that its width equals that target exactly. It is also padded along
    axis 2 by the border returned from ``Simulation.calculateObjBorderAfterRotation``.

    Parameters
    ----------
    obj_params_dict : dict
        Keyword arguments forwarded to ``Simulated3DCrystalCell``.
    det_3d_shape : tuple of int
        Three-element shape. Entry 1 (times ``upsampling_factor``) sets the target axis-1
        width. Entry 2 is passed to the axis-2 border calculation.
    obj_pixel_size : tuple of float
        Pixel size forwarded to ``CustomObjFromArray``.
    upsampling_factor : float
        Factor applied to ``det_3d_shape[1]`` to obtain the target axis-1 width.

    Returns
    -------
    Obj
        A ``CustomObjFromArray`` with per-axis (before, after) border widths.
    """
    logger.info("Creating new crystal cell.")
    target_x_npix = det_3d_shape[1] * upsampling_factor
    while True:
        obj_crystal = Simulated3DCrystalCell(**obj_params_dict)
        if target_x_npix >= obj_crystal.bordered_array.shape[1]:
            break
        logger.warning("Generated object width is larger than detector x-width. Trying again.")

    logger.info('Adding x and z borders to crystal cell based on detector parameters.')
    array = obj_crystal.bordered_array

    # Make the x-width match the (upsampled) detector exactly. Use the same upsampled target
    # as the acceptance test above, so the padding is never negative. Split the padding
    # unevenly when the difference is odd, so no pixel is lost.
    padx_total = int(target_x_npix) - array.shape[1]
    padx_before = padx_total // 2
    padx_after = padx_total - padx_before

    # Padding needed to accommodate all feasible probe translations (with probe-obj interaction).
    # NOTE(review): this passes the non-upsampled det_3d_shape[2] and the axis-1 width of the
    # object; confirm that is intended when upsampling_factor != 1.
    padz, _ = Simulation.calculateObjBorderAfterRotation(0, det_3d_shape[2], array.shape[1], 0)

    border_shape = np.array([[0, 0], [padx_before, padx_after], [padz, padz]])
    return CustomObjFromArray(array=array,
                              border_shape=border_shape,
                              border_const=0,
                              pixel_size=obj_pixel_size)
def _checkFinalized(self):
    """Warn when an item is added to the log after it has been finalized.

    The warning text says items cannot be added after the optimization has started. This
    method treats the presence of a ``dataframe`` attribute as that state. The
    ``AttributeError`` is only logged as a warning, not raised, so the caller keeps running.
    """
    # Warn only once the dataframe exists; before that, adding items is legitimate.
    if hasattr(self, "dataframe"):
        e = AttributeError(
            "Cannot add item to the log file after starting the optimization. " +
            "The log file remains unchanged. Only the print output is affected.")
        logger.warning(e)
def _calculateWavefront(self) -> None:
    """Calculate the airy pattern, then propagate it by `defocus_dist`.

    Steps:

    1. Sample a circular aperture on a square grid of ``npix_oversampled`` pixels
       (1 inside, 0.5 on, 0 outside the radius).
    2. Fourier transform it with ``norm='ortho'``.
    3. Scale it so the full, uncropped field has ``sum(|field|**2) == self.n_photons``.
    4. Crop a centered ``self.shape[0] x self.shape[1]`` window.
    5. Store the result as ``self.wavefront`` and call ``self._propagateWavefront()``.

    The aperture is defined either by ``self.width_dist[0]`` or by both
    ``self.aperture_radius`` and ``self.focal_length``. With ``width_dist[0]``, a radius is
    derived assuming a 10 m focal length, and the stored aperture_radius/focal_length are
    reset to None.

    Raises
    ------
    ValueError
        If ``width_dist[0]`` is not positive and aperture_radius or focal_length is None.
    """
    if self.width_dist[0] > 0:
        logger.warning("Warning: if width at the focus is supplied, " +
                       "any supplied focal_length and aperture radius are ignored.")
        self.aperture_radius = None
        self.focal_length = None
        # Assuming resolution equal to pixel pitch and focal length of 10.0 m.
        focal_length = 10.0
        focus_radius = self.width_dist[0] / 2
        # note that jinc(1,1) = 1.22 * np.pi
        aperture_radius = (self.wavelength * 1.22 * np.pi * focal_length / 2 / np.pi / focus_radius)
    else:
        if None in [self.aperture_radius, self.focal_length]:
            e = ValueError(
                'Either focus_radius_npix or BOTH aperture_radius and focal_length must be supplied.')
            logger.error(e)
            raise e
        aperture_radius = self.aperture_radius
        focal_length = self.focal_length

    nrows, ncols = self.shape[0], self.shape[1]
    # The simulation grid must be at least as large as the output in both directions;
    # otherwise the centered crop below would start at a negative index.
    if self.oversampling:
        npix_oversampled = max(self.oversampling_npix, self.shape[-1], nrows, ncols)
    else:
        npix_oversampled = max(nrows, ncols)

    pixel_pitch_aperture = self.wavelength * focal_length / (npix_oversampled * self.pixel_size[0])
    x = np.arange(-npix_oversampled // 2, npix_oversampled // 2) * pixel_pitch_aperture
    r = np.sqrt(x**2 + x[:, np.newaxis]**2).astype('float32')
    circ_wavefront = np.zeros(r.shape)
    circ_wavefront[r < aperture_radius] = 1.0
    circ_wavefront[r == aperture_radius] = 0.5
    probe_vals = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(circ_wavefront), norm='ortho'))

    # Normalization uses the full (uncropped) field.
    scaling_factor = np.sqrt(self.n_photons / np.sum(np.abs(probe_vals)**2))

    # Centered crop of exactly nrows x ncols pixels. Building the end index as
    # center + size // 2 would drop a row/column for odd sizes.
    row_start = npix_oversampled // 2 - nrows // 2
    col_start = npix_oversampled // 2 - ncols // 2
    cropped = probe_vals[row_start:row_start + nrows, col_start:col_start + ncols]
    wavefront_array = np.fft.fftshift(cropped.astype('complex64') * scaling_factor)

    self.wavefront = Wavefront(wavefront_array,
                               wavelength=self.wavelength,
                               pixel_size=self.pixel_size)
    self._propagateWavefront()
def getFlopsPerIter(self) -> dict:
    """Estimate per-iteration flops, split into conjugate-gradient and line-search parts.

    If no ``session`` attribute exists yet, ``self.finalizeSetup()`` is called. Otherwise a
    warning is logged. A new training batch is then loaded (``self._new_train_batch_op``),
    and every optimizer's ``minimize_op`` is run once. After that,
    ``getComputationalCostInFlops`` is queried for the whole graph and for op subsets
    selected by name keywords.

    Returns
    -------
    dict
        - ``total_flops``: cost of the whole graph.
        - ``obj_cg_flops``: ops selected by ("opt_minimize_step", "conjugate_gradient",
          "cg_while").
        - ``obj_proj_ls_flops``: ops selected by ("opt_minimize_step",
          "proj_ls_linesearch").
        - ``obj_only_flops``: ops selected by "opt".
        - ``flops_without_cg_ls``: ``total_flops - obj_cg_flops - obj_proj_ls_flops``.
        - The ``probe_*`` entries are always 0.
    """
    if self.n_validation > 0:
        e = RuntimeWarning("n_validation > 0 can give misleading flops.")
        logger.warning(e)

    if hasattr(self, "session"):
        e = RuntimeWarning(
            "Calculating computational cost after previously initializing the static graph can " +
            "include inessential calculations (like training and/or validation loss value) " +
            "and thus give misleading results.")
        logger.warning(e)
    else:
        self.finalizeSetup()

    self.session.run(self._new_train_batch_op)
    for opt in self.optimizers:
        self.session.run(opt.minimize_op)

    total_flops = getComputationalCostInFlops(self.graph)
    # NOTE(review): the "cg_init" part of the CG solve is neither reported nor subtracted from
    # flops_without_cg_ls. The previous, never-used cg_init profiling pass was removed;
    # confirm that is intended.
    cg_while_flops = getComputationalCostInFlops(
        self.graph,
        keywords=[("opt_minimize_step", "conjugate_gradient", "cg_while")],
        exclude_keywords=False)
    proj_ls_flops = getComputationalCostInFlops(
        self.graph,
        keywords=[("opt_minimize_step", "proj_ls_linesearch")],
        exclude_keywords=False)
    opt_only_flops = getComputationalCostInFlops(self.graph, keywords=["opt"], exclude_keywords=False)
    flops_without_cg_ls = total_flops - cg_while_flops - proj_ls_flops

    d = {
        "total_flops": total_flops,
        "obj_cg_flops": cg_while_flops,
        "probe_cg_flops": 0,
        "obj_proj_ls_flops": proj_ls_flops,
        "probe_proj_ls_flops": 0,
        "obj_only_flops": opt_only_flops,
        "probe_only_flops": 0,
        "flops_without_cg_ls": flops_without_cg_ls,
    }
    return d
def getFlopsPerIter(self) -> dict:
    """Estimate per-iteration flops, separating the backtracking line-search cost.

    A TensorFlow session is created on ``self.graph`` (with GPU memory growth enabled,
    variables initialized) unless a ``session`` attribute already exists. In that case a
    warning is logged instead. Each optimizer's ``minimize_op`` is run once, and then
    ``getComputationalCostInFlops`` is queried three times.

    Returns
    -------
    dict
        - ``total_flops``: cost of the whole graph.
        - ``obj_ls_flops``: ops selected by ("opt_minimize_step", "backtracking_linesearch").
        - ``obj_only_flops``: ops selected by "opt".
        - ``flops_without_ls``: ``total_flops - obj_ls_flops``.
        - The ``probe_*`` entries are always 0.
    """
    if self.n_validation > 0:
        logger.warning(RuntimeWarning("n_validation > 0 can give misleading flops."))

    if not hasattr(self, "session"):
        with self.graph.as_default():
            session_config = tf.ConfigProto()
            session_config.gpu_options.allow_growth = True
            logger.info("Initializing the session.")
            self.session = tf.Session(config=session_config)
            self.session.run(tf.global_variables_initializer())
    else:
        logger.warning(RuntimeWarning(
            "Calculating computational cost after previously initializing the static graph can "
            + "include inessential calculations (like training and/or validation loss value) "
            + "and thus give misleading results."))

    # Executed in both branches: with a pre-existing session this performs one
    # extra optimizer step on the current state.
    for optimizer in self.optimizers:
        self.session.run(optimizer.minimize_op)

    total = getComputationalCostInFlops(self.graph)
    linesearch = getComputationalCostInFlops(
        self.graph,
        keywords=[("opt_minimize_step", "backtracking_linesearch")],
        exclude_keywords=False)
    optimizer_only = getComputationalCostInFlops(self.graph, keywords=["opt"], exclude_keywords=False)

    return {
        "total_flops": total,
        "obj_ls_flops": linesearch,
        "probe_ls_flops": 0,
        "obj_only_flops": optimizer_only,
        "probe_only_flops": 0,
        "flops_without_ls": total - linesearch,
    }
def __init__(self,
             wavelength: float,
             pixel_size: Tuple[float, ...],
             shape: Tuple[int, ...],
             n_photons: float,
             defocus_dist: float = 0,
             center_dist: Tuple[float, float] = (0, 0),
             width_dist: Tuple[float, float] = (0, 0),
             center_npix: Tuple[int, int] = None,
             width_npix: Tuple[int, int] = None,
             check_propagation_with_gaussian_fit: bool = False,
             apodize: bool = False) -> None:
    """Store the probe parameters and allocate an all-zero placeholder wavefront.

    Pixel-based arguments take precedence over distance-based ones. When ``center_npix``
    (or ``width_npix``) is given, a warning is logged, the value is stored as
    ``self.center_npix`` (or ``self.width_npix``), and ``self.center_dist`` (or
    ``self.width_dist``) is recomputed as the pixel count times ``pixel_size``.
    ``self.center_npix`` / ``self.width_npix`` are only set when the argument is supplied.
    """
    self.wavelength, self.shape, self.pixel_size = wavelength, shape, pixel_size
    self.n_photons = n_photons
    self.defocus_dist = defocus_dist
    self.center_dist, self.width_dist = center_dist, width_dist
    self.apodize = apodize

    if center_npix is not None:
        logger.warning('If center_npix is supplied, then any supplied center_dist is ignored.')
        self.center_npix = center_npix
        self.center_dist = np.array(center_npix) * np.array(self.pixel_size)

    if width_npix is not None:
        logger.warning('If width_npix is supplied, then any supplied width_dist is ignored.')
        self.width_npix = width_npix
        self.width_dist = np.array(width_npix) * np.array(self.pixel_size)

    # Zero-valued placeholder; presumably overwritten when the probe is computed — TODO confirm.
    self.wavefront = Wavefront(np.zeros(shape), wavelength=wavelength, pixel_size=pixel_size)
    # n_photons divided by the number of pixels in the last two axes of `shape`.
    self.photons_flux = n_photons / (shape[-1] * shape[-2])
    self.check_propagation_with_gaussian_fit = check_propagation_with_gaussian_fit
def __init__(self,
             wavelength: float = 1.5e-10,
             obj: Obj = None,
             obj_params: dict = None,
             probe: Probe = None,
             probe_params: dict = None,
             scan_grid: ScanGrid = None,
             scan_grid_params: dict = None,
             detector: Detector = None,
             detector_params: dict = None,
             poisson_noise: bool = True,
             random_shift_pix: Tuple[int, int] = None,
             upsampling_factor: int = 1,
             background_scaling: float = 1e-8,
             background_constant: float = 0) -> None:
    """Set up a far-field ptychography simulation and compute its diffraction patterns.

    The detector, object, probe and scan grid are each either used as supplied or built from
    the matching ``*_params`` dict (a missing dict is treated as empty). The object pixel
    size is derived from the detector as
    ``wavelength * detector.obj_dist / (pixel_size * shape * upsampling_factor)``. The probe
    shape is ``detector.shape * upsampling_factor``.

    Raises
    ------
    ValueError
        If a supplied ``obj`` has a pixel size different from the derived one, or a supplied
        ``probe`` disagrees with the derived shape, the wavelength, or the derived pixel size.
    """
    self.wavelength = wavelength
    self.upsampling_factor = upsampling_factor
    self.poisson_noise = poisson_noise
    self.background_scaling = background_scaling
    self.background_constant = background_constant

    # Explicit None checks, so the result does not depend on how these objects define truthiness.
    if any(item is not None for item in (obj, probe, scan_grid, detector)):
        logger.warning("If one (or all) of obj, probe, scan_grid, or detector is supplied, " +
                       "then the corresponding _params parameter is ignored.")

    if detector is not None:
        self.detector = detector
        self._detector_params = {}
    else:
        detector_params = {} if detector_params is None else detector_params
        self._detector_params = DetectorParams(**detector_params)
        self.detector = Detector(**dt.asdict(self._detector_params))

    detector_support_size = np.asarray(self.detector.pixel_size) * self.detector.shape
    obj_pixel_size = self.wavelength * self.detector.obj_dist / (detector_support_size * self.upsampling_factor)
    probe_shape = np.array(self.detector.shape) * self.upsampling_factor

    if obj is not None:
        if obj.pixel_size is not None and np.any(obj.pixel_size != obj_pixel_size):
            e = ValueError("Mismatch between the provided pixel size and the pixel size calculated from scan " +
                           "parameters.")
            logger.error(e)
            raise e
        obj.pixel_size = obj_pixel_size
        # The supplied object must be stored: the scan-grid construction below (and later code)
        # reads self.obj.
        self.obj = obj
    else:
        obj_params = {} if obj_params is None else obj_params
        self._obj_params = ObjParams(**obj_params)
        self.obj = Simulated2DObj(**dt.asdict(self._obj_params))

    if probe is not None:
        check = (np.any(probe.shape != probe_shape)
                 or (probe.wavelength != self.wavelength)
                 or np.any(probe.pixel_size != obj_pixel_size))
        if check:
            e = ValueError("Supplied probe parameters do not match with supplied scan and detector parameters.")
            logger.error(e)
            raise e
        self.probe = probe
    else:
        probe_params = {} if probe_params is None else probe_params
        self._probe_params = ProbeParams(**probe_params)
        self.probe = FocusCircularProbe(wavelength=wavelength,
                                        pixel_size=obj_pixel_size,
                                        shape=probe_shape,
                                        **dt.asdict(self._probe_params))

    if scan_grid is not None:
        self.scan_grid = scan_grid
        self._scan_grid_params = None
    else:
        scan_grid_params = {} if scan_grid_params is None else scan_grid_params
        self._scan_grid_params = GridParams(**scan_grid_params)
        self.scan_grid = RectangleGrid(obj_w_border_shape=self.obj.bordered_array.shape,
                                       probe_shape=self.probe.shape,
                                       obj_pixel_size=obj_pixel_size,
                                       **dt.asdict(self._scan_grid_params))
        # NOTE(review): random pixel shifts are applied only to a grid built here, not to a
        # supplied scan_grid — confirm.
        if random_shift_pix is not None:
            self.scan_grid = self._updateGridWithRandomPixShifts(self.scan_grid, random_shift_pix)

    self.scan_grid.checkOverlap()
    self._calculateDiffractionPatterns()
def __init__(self,
             wavelength: float = 1.5e-10,
             obj: Obj = None,
             probe_3d: Probe3D = None,
             scan_grid: BraggPtychoGrid = None,
             detector: Detector = None,
             poisson_noise: bool = True,
             upsampling_factor: int = 1) -> None:
    """Set up a Bragg-ptychography simulation and compute its diffraction patterns.

    Components are used as supplied or built from the default parameter dataclasses, in this
    order: detector, object, 3D probe, scan grid. The 2θ angle comes from ``scan_grid`` if it
    is given, else from ``AngleGridParams()``. If the object's bordered y-extent is smaller
    than what ``calculateObjBorderAfterRotation`` requires, its y-border is replaced with
    ``(pady, pady)``.

    Raises
    ------
    ValueError
        If a supplied ``probe_3d`` disagrees with the derived x/z shape, the wavelength, or the
        derived pixel size.
    """
    self.wavelength = wavelength
    self.upsampling_factor = upsampling_factor
    self.poisson_noise = poisson_noise

    if scan_grid is None:
        two_theta = AngleGridParams().two_theta
    else:
        two_theta = scan_grid.two_theta

    if detector is None:
        logger.info("Creating new detector.")
        self.detector = Detector(**dt.asdict(DetectorParams()))
    else:
        logger.info("Using supplied detector info.")
        self.detector = detector

    det_3d_shape = (1, *self.detector.shape)
    nyquist_support_xz = self.wavelength * self.detector.obj_dist / np.array(self.detector.pixel_size)
    pixel_size_xz = nyquist_support_xz / (np.array(self.detector.shape) * self.upsampling_factor)
    # The y pixel size is very ad-hoc
    pixel_size_y = nyquist_support_xz[1] / self.detector.shape[1]
    obj_pixel_size = (pixel_size_y, *pixel_size_xz)

    if obj is None:
        self.obj = self.createObj(dt.asdict(ObjParams()), det_3d_shape, obj_pixel_size, upsampling_factor)
    else:
        logger.info("Using supplied object.")
        self.obj = obj

    rotate_angle = np.pi / 2 - two_theta
    probe_xz_shape = np.array(self.detector.shape) * self.upsampling_factor
    if probe_3d is None:
        self.probe_3d = self.createProbe3D(dt.asdict(Probe2DParams()), wavelength, rotate_angle,
                                           probe_xz_shape, obj_pixel_size)
    else:
        logger.info("Using supplied 3d probe values.")
        self.probe_3d = probe_3d
        # Note that the probe y shape cannot be determined from the other supplied parameters (I think).
        mismatch = (np.any(probe_3d.shape[1:] != probe_xz_shape)
                    or (probe_3d.wavelength != self.wavelength)
                    or np.any(probe_3d.pixel_size != obj_pixel_size))
        if mismatch:
            e = ValueError("Mismatch between the supplied probe and the supplied scan parameters.")
            logger.error(e)
            raise e

    probe_y_pix_before_rotation = self.probe_3d.shape[0] * np.cos(rotate_angle) // 1
    pady, bordered_obj_ypix = self.calculateObjBorderAfterRotation(
        rotate_angle, probe_y_pix_before_rotation, self.obj.shape[0], self.obj.shape[1])
    if self.obj.bordered_array.shape[0] < bordered_obj_ypix:
        logger.warning(
            "Adding zero padding to the object in the y-direction so that the overall object y-width " +
            "covers the entirety of the feasible probe positions.")
        self.obj.border_shape = ((pady, pady), self.obj.border_shape[1], self.obj.border_shape[2])

    if scan_grid is None:
        def yz(values):
            # Keep entries 0 and 2 (y and z) of a 3-element sequence.
            return np.array(values)[[0, 2]]

        grid_params = dt.asdict(Scan2DGridParams(obj_pixel_size=yz(self.obj.pixel_size)))
        self.scan_grid = self.createScanGrid(grid_params,
                                             dt.asdict(AngleGridParams()),
                                             yz(self.obj.bordered_array.shape),
                                             yz(self.probe_3d.shape))
    else:
        logger.info("Using supplied scan grid.")
        self.scan_grid = scan_grid

    self._calculateDiffractionPatterns()
def __init__(self,
             wavelength: float = 1.5e-10,
             obj: Obj = None,
             obj_params: dict = None,
             probe: Probe = None,
             probe_params: dict = None,
             scan_grid: ScanGrid = None,
             scan_grid_params: dict = None,
             detector: Detector = None,
             detector_params: dict = None,
             poisson_noise: bool = True,
             upsampling_factor: int = 1) -> None:
    """Set up a ptychography simulation with a Gaussian speckled probe and compute its patterns.

    The detector, object, probe and scan grid are each either used as supplied or built from
    the matching ``*_params`` dict. ``None`` (the default) is treated as an empty dict. The
    object pixel size is ``detector.pixel_size * upsampling_factor``, and the probe shape is
    ``detector.shape * upsampling_factor``.

    Raises
    ------
    ValueError
        If a supplied ``obj`` has a pixel size different from the derived one, or a supplied
        ``probe`` disagrees with the derived shape, the wavelength, or the derived pixel size.
    """
    self.wavelength = wavelength
    self.upsampling_factor = upsampling_factor
    self.poisson_noise = poisson_noise

    # Explicit None checks, so the result does not depend on how these objects define truthiness.
    if any(item is not None for item in (obj, probe, scan_grid, detector)):
        logger.warning(
            "If one (or all) of obj, probe, scan_grid, or detector is supplied, " +
            "then the corresponding _params parameter is ignored.")

    if detector is not None:
        self.detector = detector
        self._detector_params = {}
    else:
        detector_params = {} if detector_params is None else detector_params
        self._detector_params = DetectorParams(**detector_params)
        self.detector = Detector(**dt.asdict(self._detector_params))

    # NOTE(review): with probe_shape = detector.shape * upsampling_factor, multiplying (rather
    # than dividing) the pixel size by upsampling_factor enlarges the field of view by
    # upsampling_factor**2 — confirm this is intended.
    obj_pixel_size = np.array(self.detector.pixel_size) * self.upsampling_factor
    probe_shape = np.array(self.detector.shape) * self.upsampling_factor

    if obj is not None:
        if obj.pixel_size is not None and np.any(obj.pixel_size != obj_pixel_size):
            e = ValueError(
                "Mismatch between the provided pixel size and the pixel size calculated from scan " +
                "parameters.")
            logger.error(e)
            raise e
        obj.pixel_size = obj_pixel_size
        # The supplied object must be stored: the scan-grid construction below (and later code)
        # reads self.obj.
        self.obj = obj
    else:
        obj_params = {} if obj_params is None else obj_params
        self._obj_params = ObjParams(**obj_params)
        self.obj = Simulated2DObj(**dt.asdict(self._obj_params))

    if probe is not None:
        check = (np.any(probe.shape != probe_shape)
                 or (probe.wavelength != self.wavelength)
                 or np.any(probe.pixel_size != obj_pixel_size))
        if check:
            e = ValueError(
                "Supplied probe parameters do not match with supplied scan and detector parameters.")
            logger.error(e)
            raise e
        self.probe = probe
    else:
        probe_params = {} if probe_params is None else probe_params
        self._probe_params = ProbeParams(**probe_params)
        self.probe = GaussianSpeckledProbe(wavelength=wavelength,
                                           pixel_size=obj_pixel_size,
                                           shape=probe_shape,
                                           **dt.asdict(self._probe_params))

    if scan_grid is not None:
        self.scan_grid = scan_grid
        self._scan_grid_params = None
    else:
        scan_grid_params = {} if scan_grid_params is None else scan_grid_params
        self._scan_grid_params = GridParams(**scan_grid_params)
        self.scan_grid = RectangleGrid(
            obj_w_border_shape=self.obj.bordered_array.shape,
            probe_shape=self.probe.shape,
            obj_pixel_size=obj_pixel_size,
            **dt.asdict(self._scan_grid_params))

    self.scan_grid.checkOverlap()
    self._calculateDiffractionPatterns()