Example #1
    def _checkPropGaussianFit(self) -> bool:
        """Experimental feature. To be used primarily for any debugging."""
        from ptychoSampling.wavefront import checkPropagationType

        logger.warning(
            "The check for propagation type is an experimental feature.")
        self._calculateGaussianFit()
        feature_size = np.array(self._gaussian_fwhm) * 2

        propagation_type = checkPropagationType(
            wavelength=self.wavelength,
            prop_dist=self.defocus_dist,
            source_pixel_size=self.pixel_size,
            max_feature_size=feature_size)
        errors = {
            -1:
            "Propagation type is different along the x and y directions. This is not supported.",
            0:
            "Defocus distance is too small (Fresnel number too high) for transfer function method.",
            2: "Defocus distance too large. Only near field defocus supported."
        }
        if propagation_type in errors:
            e = ValueError(errors[propagation_type])
            logger.error(e)
            raise e
        return True
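
The propagation-type check above delegates to checkPropagationType, whose internals are not shown here. The usual criterion behind such a check is the Fresnel number N_F = a^2 / (lambda * z). The sketch below illustrates that criterion in plain numpy; the thresholds are illustrative assumptions, not the cut-offs used by ptychoSampling.wavefront.checkPropagationType.

import numpy as np

def classify_propagation(wavelength, prop_dist, max_feature_size,
                         far_field_max=0.1, near_field_max=100.0):
    # Fresnel number N_F = a^2 / (lambda * z) for feature size a and distance z.
    fresnel_number = max_feature_size**2 / (wavelength * abs(prop_dist))
    if fresnel_number > near_field_max:
        return "distance too small for the transfer-function method"
    if fresnel_number < far_field_max:
        return "far field"
    return "near field"

# Example: a 2 um feature, 1.5 A wavelength, 1 mm defocus distance.
print(classify_propagation(wavelength=1.5e-10, prop_dist=1e-3,
                           max_feature_size=2e-6))  # -> "near field"
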
Example #2
    def _calculateDiffractionPatterns(self):
        wv = self.probe.wavefront
        intensities_all = []
        self._transfer_function = None
        ny, nx = self.obj.bordered_array.shape

        if self.scan_grid.subpixel_scan:
            e = ValueError("Subpixel scan not supported for nearfield.")
            logger.error(e)
            raise e

        for i, (py, px) in enumerate(self.scan_grid.positions_pix):
            exit_wave = wv.ifftshift
            exit_wave[py:py + ny, px:px + nx] *= self.obj.bordered_array
            exit_wave = exit_wave.fftshift
            if self._transfer_function is None:
                self._transfer_function = np.zeros_like(exit_wave)
                det_wave = exit_wave.propTF(
                    prop_dist=self.detector.obj_dist,
                    transfer_function=self._transfer_function)
            else:
                det_wave = exit_wave.propTF(
                    reuse_transfer_function=True,
                    transfer_function=self._transfer_function)
            intensities_all.append(det_wave.intensities)

        self.intensities = np.random.poisson(
            intensities_all) if self.poisson_noise else np.array(
                intensities_all)
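
The propTF calls above rely on the library's Wavefront class, which caches and reuses the transfer function across scan positions. For orientation, the textbook Fresnel transfer-function propagator that such a method typically implements can be sketched in plain numpy as follows; this is a generic sketch, not the library's propTF.

import numpy as np

def prop_tf(wavefront, wavelength, pixel_size, prop_dist):
    # Fresnel propagation via the transfer-function method.
    n = wavefront.shape[-1]
    fx = np.fft.fftfreq(n, d=pixel_size)           # spatial frequencies (1/m)
    fx2 = fx**2 + fx[:, np.newaxis]**2
    # Fresnel transfer function H(fx, fy); a real propTF would cache this.
    transfer_function = np.exp(-1j * np.pi * wavelength * prop_dist * fx2)
    return np.fft.ifft2(np.fft.fft2(wavefront) * transfer_function)

# Example: propagate a 128x128 plane wave by 1 mm at 1.5 A with 50 nm pixels.
out = prop_tf(np.ones((128, 128), dtype=np.complex64), 1.5e-10, 50e-9, 1e-3)
print(np.allclose(np.abs(out), 1.0))   # a plane wave keeps unit intensity
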
Example #3
 def _checkAttr(self, attr_to_check, attr_this):
     if not hasattr(self, attr_to_check):
         e = AttributeError(
             f"First attach a {attr_to_check} before attaching {attr_this}."
         )
         logger.error(e)
         raise e
Example #4
    def attachTensorflowOptimizerForVariable(
            self,
            variable_name: str,
            optimizer_type: str,
            optimizer_init_args: dict = None,
            optimizer_minimize_args: dict = None,
            initial_update_delay: int = 0,
            update_frequency: int = 1,
            checkpoint_frequency: int = 100):
        """Attach an optimizer for the specified variable.

        Parameters
        ----------
        variable_name : str
            Name (string) associated with the chosen variable (to be optimized) in the forward model.
        optimizer_type : str
            Name of the standard optimizer chosen from available options in options.Options.
        optimizer_init_args : dict
            Dictionary containing the key-value pairs required for the initialization of the desired optimizer.
        optimizer_minimize_args : dict
            Dictionary containing the key-value pairs required to define the minimize operation for the desired
            optimizer.
        initial_update_delay : int
            Number of iterations to wait before the minimizer is first applied. Defaults to 0.
        update_frequency : int
            Number of iterations in between minimization calls. Defaults to 1.
        checkpoint_frequency : int
            Number of iterations between creation of checkpoints of the optimizer. Not implemented.
        """
        optimization_all = ptychoSampling.reconstruction.options.OPTIONS[
            "tf_optimization_methods"]
        self._checkConfigProperty(optimization_all, optimizer_type)
        self._checkAttr("_train_loss_t", "optimizer")
        if variable_name not in self.fwd_model.model_vars:
            e = ValueError(
                f"{variable_name} is not a supported variable in {self.fwd_model}"
            )
            logger.error(e)
            raise e

        var = self.fwd_model.model_vars[variable_name]["variable"]

        if optimizer_minimize_args is None:
            optimizer_minimize_args = {}
        elif ("loss"
              in optimizer_minimize_args) or ("var_list"
                                              in optimizer_minimize_args):
            warning = (
                "Target loss and optimization variable are assigned by default. "
                +
                "If custom processing is desired, use _attachCustomOptimizerForVariable directly."
            )
            logger.warning(warning)

        optimizer_minimize_args["loss"] = self._train_loss_t
        optimizer_minimize_args["var_list"] = [var]

        self._attachCustomOptimizerForVariable(
            optimization_all[optimizer_type], optimizer_init_args,
            optimizer_minimize_args, initial_update_delay, update_frequency)
Example #5
    def addCustomMetricToDataLog(self,
                                 title: str,
                                 tensor: tf.Tensor,
                                 log_epoch_frequency: int = 1,
                                 registration_ground_truth: np.ndarray = None,
                                 registration: bool = True,
                                 normalized_lse: bool = False):
        """Registration metric type only applies if registration ground truth is not none."""
        if registration_ground_truth is None:
            self.datalog.addCustomTensorMetric(
                title=title,
                tensor=tensor,
                log_epoch_frequency=log_epoch_frequency)
        else:
            if registration and normalized_lse:
                e = ValueError(
                    "Only one of 'registration' or 'normalized_lse' should be true."
                )
                logger.error(e)
                raise e

            self.datalog.addCustomTensorMetric(
                title=title,
                tensor=tensor,
                registration=registration,
                normalized_lse=normalized_lse,
                log_epoch_frequency=log_epoch_frequency,
                true=registration_ground_truth)
Example #6
    def __init__(self,
                 *args: Any,
                 focal_length: Optional[float] = None,
                 aperture_radius: Optional[float] = None,
                 oversampling: bool = True,
                 oversampling_npix: int = 1024,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        if ((self.width_dist[0] != self.width_dist[1])
                or (self.pixel_size[0] != self.pixel_size[1])
                or (self.shape[0] != self.shape[1])):
            e = ValueError(
                "Only a circularly symmetric focussed probe is supported. Require equal x and y widths."
            )
            logger.error(e)
            raise e

        if self.center_dist[0] > 0:
            e = ValueError(
                "Only a centrally located focussed probe is supported (i.e. center at origin). If "
                +
                "necessary, we can simply customize the wavefront manually (pad and roll)."
            )
            logger.error(e)
            raise e

        self.focal_length = focal_length
        self.aperture_radius = aperture_radius
        self.oversampling = oversampling
        self.oversampling_npix = oversampling_npix
        self._calculateWavefront()
Example #7
 def _checkConfigProperty(options: dict, key_to_check: str):
     if key_to_check not in options:
         e = ValueError(
             f"{key_to_check} is not currently supported. " +
             f"Check if {key_to_check} exists as an option among {options} in options.py"
         )
         logger.error(e)
         raise e
Example #8
 def __post_init__(self):
     if self.registration and self.true is None:
         e = ValueError("Require true data for registration.")
         logger.error(e)
         raise e
     if not self.registration:
         self.columns = [self.title]
     else:
         appends = ['_err']  # , '_shift', '_phase']
         self.columns = [self.title + string for string in appends]
Example #9
    def _calculateWavefront(self) -> None:
        """Calculating the airy pattern then propagating it by `defocus_dist`."""
        if self.width_dist[0] > 0:
            logger.warning(
                "Warning: if width at the focus is supplied, " +
                "any supplied focal_length and aperture radius are ignored.")
            self.aperture_radius = None
            self.focal_length = None
            # Assuming resolution equal to pixel pitch and focal length of 10.0 m.
            focal_length = 10.0
            focus_radius = self.width_dist[0] / 2
            # note that jinc(1,1) = 1.22 * np.pi
            aperture_radius = (self.wavelength * 1.22 * np.pi * focal_length /
                               2 / np.pi / focus_radius)
        else:
            if None in [self.aperture_radius, self.focal_length]:
                e = ValueError(
                    'Either the width at focus (width_dist) or BOTH aperture_radius and focal_length must be supplied.'
                )
                logger.error(e)
                raise e
            aperture_radius = self.aperture_radius
            focal_length = self.focal_length

        npix_oversampled = max(
            self.oversampling_npix,
            self.shape[-1]) if self.oversampling else self.shape[0]
        pixel_pitch_aperture = self.wavelength * focal_length / (
            npix_oversampled * self.pixel_size[0])

        x = np.arange(-npix_oversampled // 2,
                      npix_oversampled // 2) * pixel_pitch_aperture

        r = np.sqrt(x**2 + x[:, np.newaxis]**2).astype('float32')
        circ_wavefront = np.zeros(r.shape)
        circ_wavefront[r < aperture_radius] = 1.0
        circ_wavefront[r == aperture_radius] = 0.5

        probe_vals = np.fft.fftshift(
            np.fft.fft2(np.fft.fftshift(circ_wavefront), norm='ortho'))

        n1 = npix_oversampled // 2 - self.shape[0] // 2
        n2 = npix_oversampled // 2 + self.shape[1] // 2

        scaling_factor = np.sqrt(self.n_photons /
                                 np.sum(np.abs(probe_vals)**2))
        wavefront_array = probe_vals[n1:n2, n1:n2].astype(
            'complex64') * scaling_factor
        wavefront_array = np.fft.fftshift(wavefront_array)
        #self.wavefront = self.wavefront.update(array=wavefront_array)
        self.wavefront = Wavefront(wavefront_array,
                                   wavelength=self.wavelength,
                                   pixel_size=self.pixel_size)
        self._propagateWavefront()
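
The aperture_radius expression above inverts the Airy first-zero relation (the jinc zero at 1.22*pi), i.e. focus_radius = 1.22 * wavelength * focal_length / (2 * aperture_radius). A short worked example with assumed values (1.5 A wavelength, the 10 m default focal length, a 100 nm focus width); the numbers are illustrative only.

import numpy as np

wavelength = 1.5e-10    # m (assumed)
focal_length = 10.0     # m, the default assumed in _calculateWavefront
focus_radius = 50e-9    # m, half of an assumed 100 nm focus width

# Same formula as in the method above: first Airy zero at 1.22 * pi.
aperture_radius = (wavelength * 1.22 * np.pi * focal_length
                   / 2 / np.pi / focus_radius)
print(aperture_radius)  # ~1.83e-2 m, i.e. an aperture radius of about 18 mm
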
Example #10
 def __init__(self, shape=(128, 128),
              border_shape=((32, 32), (32, 32)),
              border_const=1.0,
              mod_range=1.0,
              phase_range=np.pi):
     if len(shape) != 2:
         e = ValueError('Supplied shape is not 2d.')
         logger.error(e)
         raise e
     super().__init__(shape, border_shape, border_const, mod_range, phase_range)
     self._createObj()
Example #11
    def _sanityChecks(self):
        sanity_checks = {
            'probe_det':
            self._probe.npix == self._detector.npix * self.upsampling_factor,
            'probe_obj': self._probe.pixel_size == self._obj.pixel_size,
            'obj_grid': self._obj.pixel_size == self._scan_grid.obj_pixel_size,
            'probe_grid': self._probe.npix == self._scan_grid.probe_npix
        }

        for s, v in sanity_checks.items():
            if not v:
                e = ValueError(f"Mismatch in supplied {s} parameters.")
                logger.error(e)
                raise e
Example #12
    def __post_init__(self) -> None:
        try:
            _ = np.array(
                self.intensity_patterns) + self.background_intensity_level
        except Exception as e:
            e2 = ValueError(
                'Invalid format for background_intensity_level. Should be compatible with the '
                +
                'supplied intensity_patterns parameter value through a numpy broadcasting operation.'
            )
            logger.error([e, e2])
            raise e2 from e

        for key in self.additional_params:
            object.__setattr__(self, key, self.additional_params[key])
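
The try/except above only probes whether background_intensity_level broadcasts against intensity_patterns. With numpy >= 1.20 the same compatibility can be checked explicitly via np.broadcast_shapes, as in this small sketch (the shapes are illustrative assumptions):

import numpy as np

patterns = np.zeros((10, 64, 64))         # a stack of 10 diffraction patterns
for background in (0.5, np.zeros(64), np.zeros((10, 1, 1))):
    # np.broadcast_shapes raises ValueError when the shapes are incompatible.
    shape = np.broadcast_shapes(patterns.shape, np.shape(background))
    print(shape)                          # (10, 64, 64) in all three cases

try:
    np.broadcast_shapes(patterns.shape, (3,))   # incompatible trailing axis
except ValueError as err:
    print("not broadcastable:", err)
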
Example #13
 def _register(test: np.ndarray, true: np.ndarray) -> float:
     if len(test.shape) == 2:
         registration_fn = _register_translation_2d
     elif len(test.shape) == 3:
         registration_fn = _register_translation_3d
     else:
         e = ValueError(
             "Subpixel registration only available for 2d and 3d objects.")
         logger.error(e)
         raise e
     shift, err, phase = registration_fn(test, true, upsample_factor=10)
     shift, err, phase = registration_fn(test * np.exp(-1j * phase),
                                         true,
                                         upsample_factor=10)
     return err
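
The helpers _register_translation_2d and _register_translation_3d are not shown; scikit-image's phase_cross_correlation provides equivalent subpixel registration and also returns the global phase difference. A self-contained sketch of the same two-pass idea, with the argument order and normalization treated as assumptions rather than the library's exact behaviour:

import numpy as np
from skimage.registration import phase_cross_correlation

def register_error(test: np.ndarray, true: np.ndarray) -> float:
    # First pass recovers the shift and the global phase offset between the arrays.
    shift, err, phase = phase_cross_correlation(test, true, upsample_factor=10)
    # Second pass after removing that global phase gives the reported error metric.
    shift, err, phase = phase_cross_correlation(test * np.exp(-1j * phase), true,
                                                upsample_factor=10)
    return err

rng = np.random.default_rng(0)
true = rng.random((64, 64)) + 1j * rng.random((64, 64))
test = np.exp(1j * 0.3) * true            # same object up to a global phase
print(register_error(test, true))          # close to zero once the phase is removed
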
Example #14
    def checkOverlap(self, overlap_ratio: float = 0.5) -> bool:
        if not self.full_field_probe:
            overlap_nx = overlap_ratio * self.probe_shape[1]
            overlap_ny = overlap_ratio * self.probe_shape[0]
        else:
            overlap_nx = overlap_ratio * self.obj_w_border_shape[1]
            overlap_ny = overlap_ratio * self.obj_w_border_shape[0]

        for p in self.positions_pix:
            differences = np.abs(self.positions_pix - p)
            ydiffs = differences[:, 0]
            xdiffs = differences[:, 1]
            xmin = np.min(xdiffs[xdiffs > 0])
            ymin = np.min(ydiffs[ydiffs > 0])

            if xmin > overlap_nx or ymin > overlap_ny:
                e = ValueError(
                    "Insufficient overlap between adjacent scan positions.")
                logger.error(e)
                raise e
        return True
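
The loop above rejects a grid whenever some position's nearest neighbour along x or y is farther away than overlap_ratio times the probe (or full-field object) extent. The same nearest-neighbour computation on a small synthetic raster grid, purely for illustration (the sizes are assumed):

import numpy as np

probe_shape = (64, 64)
overlap_ratio = 0.5
overlap_ny = overlap_ratio * probe_shape[0]    # 32-pixel bound along y
overlap_nx = overlap_ratio * probe_shape[1]    # 32-pixel bound along x

# 4x4 raster grid with a 20-pixel step.
ys, xs = np.mgrid[0:80:20, 0:80:20]
positions_pix = np.stack([ys.ravel(), xs.ravel()], axis=1)

for p in positions_pix:
    differences = np.abs(positions_pix - p)
    ymin = np.min(differences[:, 0][differences[:, 0] > 0])
    xmin = np.min(differences[:, 1][differences[:, 1] > 0])
    # A 20-pixel nearest-neighbour spacing stays within the 32-pixel bound.
    assert xmin <= overlap_nx and ymin <= overlap_ny
print("sufficient overlap for all positions")
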
Example #15
    def __init__(self, mesh_shape: Tuple[int, int, int] = (128, 128, 128),
                 mod_const: float = 0.5,
                 border_shape=((0, 0), (0, 0), (0, 0)),
                 border_const=0.0,
                 pixel_size=None):
        if len(mesh_shape) != 3:
            e = ValueError('Supplied shape is not 3d.')
            logger.error(e)
            raise e

        self.mod_constant = mod_const

        self.border_const = border_const
        self._border_shape = border_shape
        self.pixel_size = pixel_size

        self.mod_range = None
        self.phase_range = None

        self._createObj()
        self.array *= self.mod_constant
Example #16
    def __init__(self,
                 obj: Obj,
                 probe: Probe3D,
                 scan_grid: BraggPtychoGrid,
                 exit_wave_axis: str = "y",
                 upsampling_factor: int = 1,
                 setup_second_order: bool = False,
                 dtype: str = "float32"):
        if scan_grid.full_field_probe:
            e = ValueError(
                "Full field probe not supported for Bragg ptychography.")
            logger.error(e)
            raise e
        if scan_grid.grid2d_axes != ("y", "z"):
            e = ValueError(
                "Only supports the case where the ptychographic scan is on the yz-plane."
            )
            logger.error(e)
            raise e
        if exit_wave_axis != 'y':
            e = ValueError(
                "Only supports the case where the exit waves are output along the y-direction."
            )
            logger.error(e)
            raise e

        super().__init__(obj, probe, scan_grid, upsampling_factor,
                         setup_second_order, dtype)
        logger.info("Creating the phase modulations for the scan angles.")

        with tf.device("/gpu:0"):
            self._probe_phase_modulations_all_t = tf.constant(
                self._getProbePhaseModulationsStack(), dtype='complex64')
            self._full_rc_positions_indices_t = tf.constant(
                scan_grid.full_rc_positions_indices, dtype='int64')
Example #17
    def _setObjArrayValues(self, values: Optional[np.ndarray] = None) -> None:
        """Set the obj transmission function and add the border.

        Performs sanity checks for the 'shape' and '_border_shape' parameters supplied when the class is created. The
        'shape' parameter should be tuple-like and composed of integers, formatted so that 'numpy.zeros' accepts it
        as an argument for the 'shape' parameter. The '_border_shape' should be formatted so that 'numpy.pad' accepts it
        as an argument for the 'pad' parameter.

        Sets the values for the 'array' and 'bordered_array' attributes.
        Parameters
        ----------
        values : array_like, optional
            Obj array values. For the default value 'None', the function creates an array of zeros.
        """

        if values is None:
            try:
                values = np.zeros(self.shape)
            except Exception as e:
                e2 = ValueError("Error in input obj shape.")
                logger.error([e, e2])
                raise e2 from e

        #self._array = values
        try:
            self.bordered_array = np.pad(
                values,  #self._array,
                self._border_shape,
                mode='constant',
                constant_values=self.border_const)
        except Exception as e:
            e2 = ValueError("Error in border specifications.")
            logger.error([e, e2])
            raise e2 from e
        array_slices = tuple(
            slice(b[0], self.shape[i] + b[0])
            for i, b in enumerate(self._border_shape))

        # This is only a view to bordered_array. Changing one changes the other.
        self._array = self.bordered_array[array_slices]
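
The docstring above ties 'shape' to numpy.zeros and '_border_shape' to numpy.pad, and the final slice is only a view into the padded array. A compact illustration of that pad-then-view pattern with assumed sizes:

import numpy as np

shape = (4, 4)
border_shape = ((2, 2), (1, 1))   # ((top, bottom), (left, right)) pad widths
border_const = 1.0

values = np.zeros(shape)
bordered_array = np.pad(values, border_shape, mode='constant',
                        constant_values=border_const)
print(bordered_array.shape)       # (8, 6): 4 + 2 + 2 rows, 4 + 1 + 1 columns

# Interior slice; basic slicing returns a view, so writes propagate both ways.
array_slices = tuple(slice(b[0], shape[i] + b[0])
                     for i, b in enumerate(border_shape))
array = bordered_array[array_slices]
array[:] = 5.0
print(bordered_array[2, 1], bordered_array[0, 0])   # 5.0 (interior), 1.0 (border)
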
Example #18
    def __init__(self,
                 obj_w_border_shape: Tuple[int, int],
                 probe_shape: Tuple[int, int],
                 obj_pixel_size: Tuple[float, float],
                 step_dist: Tuple[float, float],
                 subpixel_scan: bool = False,
                 scan_grid_boundary_pix: np.ndarray = None,
                 full_field_probe: bool = False):
        self.obj_w_border_shape = obj_w_border_shape
        self.obj_pixel_size = obj_pixel_size
        self.probe_shape = probe_shape
        self.step_dist = np.array(step_dist)
        self.subpixel_scan = subpixel_scan
        self.full_field_probe = full_field_probe

        if scan_grid_boundary_pix is None:
            if not self.full_field_probe:
                self.scan_grid_boundary_pix = np.array(
                    [[0, self.obj_w_border_shape[0]],
                     [0, self.obj_w_border_shape[1]]])
            else:
                self.scan_grid_boundary_pix = np.array([[0, probe_shape[0]],
                                                        [0, probe_shape[1]]])
        else:
            try:
                _ = np.reshape(scan_grid_boundary_pix, (2, 2))
            except Exception as e:
                e2 = ValueError(
                    "scan_grid_boundary_pix should contain integers and have shape [[y_min, y_max], "
                    + "[x_min, x_max]]")
                logger.error([e, e2])
                raise e2 from e
            self.scan_grid_boundary_pix = scan_grid_boundary_pix

        self.positions_pix = []
        self.positions_subpix = []
        self.positions_dist = []
Example #19
    def __init__(self,
                 *args: int,
                 loss_type: str = "least_squared",
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 update_delay_probe: int = 0,
                 update_delay_obj: int = 0,
                 reconstruct_probe: bool = True,
                 registration_log_frequency: int = 10,
                 opt_init_extra_kwargs: dict = None,
                 obj_abs_proj: bool = True,
                 loss_init_extra_kwargs: dict = None,
                 r_factor_log: bool = True,
                 apply_precond: bool = False,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        if self.training_batch_size != self.n_train:
            e = ValueError(
                "Conjugate gradient reconstruction does not support minibatch reconstruction."
            )
            logger.error(e)
            raise e

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                       obj_abs_proj=obj_abs_proj)

        logger.info('creating loss fn...')
        self.attachLossFunctionSecondOrder(
            loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

        logger.info("creating optimizers...")
        if opt_init_extra_kwargs is None:
            opt_init_extra_kwargs = {}

        self._preconditioner = None
        if apply_precond:
            with self.graph.as_default():
                loss_data_type = self._loss_method.data_type
                hessian_t = (
                    tf.ones_like(self._batch_train_predictions_t) *
                    self._train_hessian_fn(self._batch_train_predictions_t))
                hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])

                obj_scaling = self._getObjScaling(loss_data_type, hessian_t)
                probe_scaling = self._getProbeScaling(loss_data_type,
                                                      hessian_t)
                scaling_both = tf.concat((obj_scaling, probe_scaling), axis=0)
                scaling_both = tf.concat((scaling_both, scaling_both), axis=0)
                zero_condition = tf.less(scaling_both,
                                         1e-10 * tf.reduce_max(scaling_both))
                zero_case = tf.ones_like(scaling_both) / (
                    1e-8 * tf.reduce_max(scaling_both))
                self._preconditioner = tf.where(zero_condition, zero_case,
                                                1 / scaling_both)

        opt_init_args = {
            "input_var": self.fwd_model.joint_v,
            "predictions_fn": self._predictions_fn,
            "loss_fn": self._train_loss_fn,
            "name": "opt",
            "diag_precondition_t": self._preconditioner
        }
        opt_init_args.update(opt_init_extra_kwargs)
        self._attachCustomOptimizerForVariable(
            ConjugateGradientOptimizer, optimizer_init_args=opt_init_args)
        self.addCustomMetricToDataLog(
            title="ls_iters",
            tensor=self.optimizers[0]._optimizer._linesearch_steps,
            log_epoch_frequency=1)
        self.addCustomMetricToDataLog(
            title="alpha",
            tensor=self.optimizers[0]._optimizer._linesearch._alpha,
            log_epoch_frequency=1)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(
                title="obj_error",
                tensor=self.fwd_model.obj_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=obj_array_true)
        if reconstruct_probe and (probe_wavefront_true is not None):
            self.addCustomMetricToDataLog(
                title="probe_error",
                tensor=self.fwd_model.probe_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=probe_wavefront_true)
        self._addRFactorLog(r_factor_log, registration_log_frequency)
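
The tf.where construction above guards the reciprocal of the scaling vector against (near-)zero entries before it is used as a diagonal preconditioner. The same safeguarded reciprocal in numpy terms, with made-up numbers:

import numpy as np

scaling_both = np.array([4.0, 1.0, 1e-14, 1e-20])
# Entries far below the maximum are treated as zero and clamped instead.
zero_condition = scaling_both < 1e-10 * scaling_both.max()
zero_case = np.ones_like(scaling_both) / (1e-8 * scaling_both.max())
preconditioner = np.where(zero_condition, zero_case, 1.0 / scaling_both)
print(preconditioner)   # [2.5e-01 1.0e+00 2.5e+07 2.5e+07]
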
Example #20
    def __init__(self,
                 *args: int,
                 loss_type: str = "least_squared",
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 max_cg_iter: int = 100,
                 min_cg_tol: float = 0.1,
                 registration_log_frequency: int = 10,
                 both_registration_nlse: bool = True,
                 opt_init_extra_kwargs: dict = None,
                 obj_abs_proj: bool = True,
                 obj_abs_max: float = 1.0,
                 probe_abs_max: float = None,
                 loss_init_extra_kwargs: dict = None,
                 r_factor_log: bool = True,
                 update_delay_probe: int = 0,
                 reconstruct_probe: bool = True,
                 apply_diag_mu_scaling: bool = True,
                 apply_precond: bool = False,
                 apply_precond_and_scaling: bool = False,
                 stochastic_diag_estimator_type: str = None,
                 stochastic_diag_estimator_iters: int = 1,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        if apply_precond_and_scaling:
            logger.info(
                "Overriding any set values for apply_precond and apply_diag_mu_scaling"
            )
            apply_precond = True
            apply_diag_mu_scaling = True

        if stochastic_diag_estimator_type is not None:
            if apply_diag_mu_scaling or apply_precond:
                e = ValueError(
                    "Cannot use analytical precond and diag ggn elems if stochastic calculation is enabled."
                )
                logger.error(e)
                raise e

        if self.training_batch_size != self.n_train:
            e = ValueError(
                "LMA reconstruction does not support minibatch reconstruction."
            )
            logger.error(e)
            raise e

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                       obj_abs_proj=obj_abs_proj,
                                       obj_abs_max=obj_abs_max,
                                       probe_abs_max=probe_abs_max)

        logger.info('creating loss fn...')

        logger.info('Loss init args: %s', loss_init_extra_kwargs)
        self.attachLossFunctionSecondOrder(
            loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

        logger.info("creating optimizer...")
        if opt_init_extra_kwargs is None:
            opt_init_extra_kwargs = {}

        if apply_diag_mu_scaling or apply_precond:

            with self.graph.as_default():

                loss_data_type = self._loss_method.data_type
                hessian_t = (
                    tf.ones_like(self._batch_train_predictions_t) *
                    self._train_hessian_fn(self._batch_train_predictions_t))
                hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])
                logger.debug(hessian_t)
                self._obj_scaling = self._getObjScaling(
                    loss_data_type, hessian_t)
                self._probe_scaling = self._getProbeScaling(
                    loss_data_type, hessian_t)
                scaling_both = tf.concat(
                    (self._obj_scaling, self._probe_scaling), axis=0)

                if apply_diag_mu_scaling:
                    self._joint_scaling_t = tf.concat(
                        (scaling_both, scaling_both), axis=0)
                    optimizer = ScaledLMAOptimizer
                    opt_init_extra_kwargs[
                        'diag_mu_scaling_t'] = self._joint_scaling_t
                if apply_precond:
                    self._joint_precond_t = tf.concat(
                        (scaling_both, scaling_both), axis=0)
                    optimizer = PCGLMAOptimizer
                    opt_init_extra_kwargs[
                        'diag_precond_t'] = self._joint_precond_t
                    if apply_diag_mu_scaling:
                        optimizer = ScaledPCGLMAOptimizer
                #
        else:
            optimizer = LMAOptimizer

        opt_init_args = {
            "input_var": self.fwd_model.joint_v,
            "predictions_fn": self._predictions_fn,
            "loss_fn": self._train_loss_fn,
            "diag_hessian_fn": self._train_hessian_fn,
            "name": "opt",
            "max_cg_iter": max_cg_iter,
            "min_cg_tol": min_cg_tol,
            "stochastic_diag_estimator_type": stochastic_diag_estimator_type,
            "stochastic_diag_estimator_iters": stochastic_diag_estimator_iters,
            "assert_tolerances": False
        }
        opt_init_args.update(opt_init_extra_kwargs)

        self._attachCustomOptimizerForVariable(
            optimizer, optimizer_init_args=opt_init_args)
        self.addCustomMetricToDataLog(
            title="cg_iters",
            tensor=self.optimizers[0]._optimizer._total_cg_iterations,
            log_epoch_frequency=1)
        self.addCustomMetricToDataLog(
            title="ls_iters",
            tensor=self.optimizers[0]._optimizer._total_proj_ls_iterations,
            log_epoch_frequency=1)
        self.addCustomMetricToDataLog(title="proj_iters",
                                      tensor=self.optimizers[0]._optimizer.
                                      _projected_gradient_iterations,
                                      log_epoch_frequency=1)
        self.addCustomMetricToDataLog(title="mu",
                                      tensor=self.optimizers[0]._optimizer._mu,
                                      log_epoch_frequency=1)

        self._registration_log_frequency = registration_log_frequency
        if obj_array_true is not None:
            self._setObjRegistration(
                obj_array_true, both_registration_nlse=both_registration_nlse)
        if probe_wavefront_true is not None:
            self._setProbeRegistration(
                probe_wavefront_true,
                both_registration_nlse=both_registration_nlse)
        self._addRFactorLog(r_factor_log, registration_log_frequency)
Example #21
    def __init__(self,
                 wavelength: float = 1.5e-10,
                 obj: Obj = None,
                 probe_3d: Probe3D = None,
                 scan_grid: BraggPtychoGrid = None,
                 detector: Detector = None,
                 poisson_noise: bool = True,
                 upsampling_factor: int = 1) -> None:
        self.wavelength = wavelength
        self.upsampling_factor = upsampling_factor
        self.poisson_noise = poisson_noise

        two_theta = scan_grid.two_theta if scan_grid is not None else AngleGridParams(
        ).two_theta

        if detector is not None:
            logger.info("Using supplied detector info.")
            self.detector = detector
        else:
            logger.info("Creating new detector.")
            self.detector = Detector(**dt.asdict(DetectorParams()))
        det_3d_shape = (1, *self.detector.shape)

        obj_xz_nyquist_support = self.wavelength * self.detector.obj_dist / np.array(
            self.detector.pixel_size)
        obj_xz_pixel_size = obj_xz_nyquist_support / (
            np.array(self.detector.shape) * self.upsampling_factor)

        # The y pixel size is very ad-hoc
        obj_y_pixel_size = obj_xz_nyquist_support[1] / self.detector.shape[1]
        obj_pixel_size = (obj_y_pixel_size, *obj_xz_pixel_size)

        if obj is not None:
            logger.info("Using supplied object.")
            self.obj = obj
        else:
            self.obj = self.createObj(dt.asdict(ObjParams()), det_3d_shape,
                                      obj_pixel_size, upsampling_factor)

        probe_xz_shape = np.array(self.detector.shape) * self.upsampling_factor

        if probe_3d is not None:
            logger.info("Using supplied 3d probe values.")
            self.probe_3d = probe_3d
            # Note that the probe y shape cannot be determined from the other supplied parameters (I think).
            if (np.any(probe_3d.shape[1:] != probe_xz_shape)
                    or (probe_3d.wavelength != self.wavelength)
                    or np.any(probe_3d.pixel_size != obj_pixel_size)):
                e = ValueError(
                    f"Mismatch between the supplied probe and the supplied scan parameters."
                )
                logger.error(e)
                raise e
        else:
            self.probe_3d = self.createProbe3D(dt.asdict(Probe2DParams()),
                                               wavelength,
                                               np.pi / 2 - two_theta,
                                               probe_xz_shape, obj_pixel_size)

        rotate_angle = np.pi / 2 - two_theta
        probe_y_pix_before_rotation = self.probe_3d.shape[0] * np.cos(
            rotate_angle) // 1
        pady, bordered_obj_ypix = self.calculateObjBorderAfterRotation(
            rotate_angle, probe_y_pix_before_rotation, self.obj.shape[0],
            self.obj.shape[1])

        if self.obj.bordered_array.shape[0] < bordered_obj_ypix:
            logger.warning(
                "Adding zero padding to the object in the y-direction so that the overall object y-width "
                + "covers the entirety of the feasible probe positions.")
            self.obj.border_shape = ((pady, pady), self.obj.border_shape[1],
                                     self.obj.border_shape[2])
        if scan_grid is not None:
            logger.info("Using supplied scan grid.")
            self.scan_grid = scan_grid
        else:
            elems_yz = lambda tup: np.array(tup)[[0, 2]]
            scan_grid_params_dict = dt.asdict(
                Scan2DGridParams(obj_pixel_size=elems_yz(self.obj.pixel_size)))
            self.scan_grid = self.createScanGrid(
                scan_grid_params_dict, dt.asdict(AngleGridParams()),
                elems_yz(self.obj.bordered_array.shape),
                elems_yz(self.probe_3d.shape))
        self._calculateDiffractionPatterns()
Example #22
    def __init__(self,
                 wavelength: float = 1.5e-10,
                 obj: Obj = None,
                 obj_params: dict = {},
                 probe: Probe = None,
                 probe_params: dict = {},
                 scan_grid: ScanGrid = None,
                 scan_grid_params: dict = {},
                 detector: Detector = None,
                 detector_params: dict = {},
                 poisson_noise: bool = True,
                 upsampling_factor: int = 1) -> None:

        self.wavelength = wavelength
        self.upsampling_factor = upsampling_factor
        self.poisson_noise = poisson_noise

        if obj or probe or scan_grid or detector:
            logger.warning(
                "If one (or all) of obj, probe, scan_grid, or detector is supplied, "
                + "then the corresponding _params parameter is ignored.")

        if detector is not None:
            self.detector = detector
            self._detector_params = {}
        else:
            self._detector_params = DetectorParams(**detector_params)
            self.detector = Detector(**dt.asdict(self._detector_params))

        obj_pixel_size = np.array(
            self.detector.pixel_size) * self.upsampling_factor
        probe_shape = np.array(self.detector.shape) * self.upsampling_factor

        if obj is not None:
            if obj.pixel_size is not None and np.any(
                    obj.pixel_size != obj_pixel_size):
                e = ValueError(
                    "Mismatch between the provided pixel size and the pixel size calculated from scan "
                    + "parameters.")
                logger.error(e)
                raise e
            obj.pixel_size = obj_pixel_size
            self.obj = obj
        else:
            self._obj_params = ObjParams(**obj_params)
            self.obj = Simulated2DObj(**dt.asdict(self._obj_params))

        if probe is not None:
            check = (np.any(probe.shape != probe_shape)
                     or (probe.wavelength != self.wavelength)
                     or np.any(probe.pixel_size != obj_pixel_size))
            if check:
                e = ValueError(
                    "Supplied probe parameters do not match with supplied scan and detector parameters."
                )
                logger.error(e)
                raise e
            self.probe = probe
        else:
            self._probe_params = ProbeParams(**probe_params)
            self.probe = GaussianSpeckledProbe(wavelength=wavelength,
                                               pixel_size=obj_pixel_size,
                                               shape=probe_shape,
                                               **dt.asdict(self._probe_params))

        if scan_grid is not None:
            self.scan_grid = scan_grid
            self._scan_grid_params = None
        else:
            self._scan_grid_params = GridParams(**scan_grid_params)
            self.scan_grid = RectangleGrid(
                obj_w_border_shape=self.obj.bordered_array.shape,
                probe_shape=self.probe.shape,
                obj_pixel_size=obj_pixel_size,
                **dt.asdict(self._scan_grid_params))

        self.scan_grid.checkOverlap()
        self._calculateDiffractionPatterns()
Example #23
    def __init__(self,
                 wavelength: float = 1.5e-10,
                 obj: Obj = None,
                 obj_params: dict = None,
                 probe: Probe = None,
                 probe_params: dict = None,
                 scan_grid: ScanGrid = None,
                 scan_grid_params: dict = None,
                 detector: Detector = None,
                 detector_params: dict = None,
                 poisson_noise: bool = True,
                 random_shift_pix: Tuple[int, int] = None,
                 upsampling_factor: int = 1,
                 background_scaling: float=1e-8,
                 background_constant: float=0) -> None:

        self.wavelength = wavelength
        self.upsampling_factor = upsampling_factor
        self.poisson_noise = poisson_noise

        self.background_scaling = background_scaling
        self.background_constant = background_constant

        if obj or probe or scan_grid or detector:
            logger.warning("If one (or all) of obj, probe, scan_grid, or detector is supplied, "
                           + "then the corresponding _params parameter is ignored.")

        if detector is not None:
            self.detector = detector
            self._detector_params = {}
        else:
            detector_params = {} if detector_params is None else detector_params
            self._detector_params = DetectorParams(**detector_params)
            self.detector = Detector(**dt.asdict(self._detector_params))

        detector_support_size = np.asarray(self.detector.pixel_size) * self.detector.shape
        obj_pixel_size = self.wavelength * self.detector.obj_dist / (detector_support_size * self.upsampling_factor)

        probe_shape = np.array(self.detector.shape) * self.upsampling_factor

        if obj is not None:
            if obj.pixel_size is not None and np.any(obj.pixel_size != obj_pixel_size):
                e = ValueError("Mismatch between the provided pixel size and the pixel size calculated from scan "
                               + "parameters.")
                logger.error(e)
                raise e
            obj.pixel_size = obj_pixel_size
            self.obj = obj
        else:
            obj_params = {} if obj_params is None else obj_params
            self._obj_params = ObjParams(**obj_params)
            self.obj = Simulated2DObj(**dt.asdict(self._obj_params))

        if probe is not None:
            check = (np.any(probe.shape != probe_shape)
                     or (probe.wavelength != self.wavelength)
                     or np.any(probe.pixel_size != obj_pixel_size))
            if check:
                e = ValueError("Supplied probe parameters do not match with supplied scan and detector parameters.")
                logger.error(e)
                raise e
            self.probe = probe
        else:
            probe_params = {} if probe_params is None else probe_params
            self._probe_params = ProbeParams(**probe_params)
            self.probe = FocusCircularProbe(wavelength=wavelength,
                                        pixel_size=obj_pixel_size,
                                        shape=probe_shape,
                                        **dt.asdict(self._probe_params))

        if scan_grid is not None:
            self.scan_grid = scan_grid
            self._scan_grid_params = None
        else:
            scan_grid_params = {} if scan_grid_params is None else scan_grid_params
            self._scan_grid_params = GridParams(**scan_grid_params)
            self.scan_grid = RectangleGrid(obj_w_border_shape=self.obj.bordered_array.shape,
                                           probe_shape=self.probe.shape,
                                           obj_pixel_size=obj_pixel_size,
                                           **dt.asdict(self._scan_grid_params))
            #if self._scan_grid_params.random_shift_dist is not None:
            #    raise
            if random_shift_pix is not None:
                self.scan_grid = self._updateGridWithRandomPixShifts(self.scan_grid, random_shift_pix)
            #elif self._scan_grid_params.random_shift_ratio is not None:
            #    raise
        self.scan_grid.checkOverlap()
        self._calculateDiffractionPatterns()
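
The object pixel size computed here follows the usual far-field sampling relation, obj_pixel = wavelength * obj_dist / (detector support * upsampling). A worked example with assumed detector values (55 um pixels, a 256x256 detector, a 4 m sample-to-detector distance); these are illustrative numbers, not the DetectorParams defaults:

import numpy as np

wavelength = 1.5e-10                       # m
obj_dist = 4.0                             # m, assumed sample-detector distance
det_pixel_size = np.array([55e-6, 55e-6])  # m, assumed detector pixel pitch
det_shape = np.array([256, 256])
upsampling_factor = 1

detector_support = det_pixel_size * det_shape          # ~14.1 mm per side
obj_pixel_size = wavelength * obj_dist / (detector_support
                                          * upsampling_factor)
print(obj_pixel_size)                      # ~4.26e-8 m, i.e. ~43 nm object pixels
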
Example #24
    def __init__(self, *args: int,
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 obj_abs_proj: bool = True,
                 update_delay_probe: int = 0,
                 update_delay_obj: int = 0,
                 update_delay_aux: int = 0,
                 update_frequency_probe: int = 1,
                 update_frequency_obj: int = 1,
                 update_frequency_aux: int = 1,
                 reconstruct_probe: bool = True,
                 registration_log_frequency: int = 10,
                 r_factor_log: bool = True,
                 loss_type: str = "gaussian",
                 learning_rate_scaling_probe: float = 1.0,
                 learning_rate_scaling_obj: float = 1.0,
                 loss_init_extra_kwargs: dict = None,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)
        if self.training_batch_size != self.n_train:
            e = ValueError("PALM reconstruction does not support minibatch reconstruction.")
            logger.error(e)
            raise e

        if loss_type != "gaussian":
            e = ValueError("PALM reconstruction does not support other loss functions.")
            logger.error(e)
            raise e

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(FarfieldPALMForwardModel, obj_abs_proj=obj_abs_proj)

        logger.info('creating loss fn...')
        self._attachCustomLossFunction(AuxLossFunctionT, loss_init_extra_kwargs)

        logger.info("create learning rates")
        self._learning_rate_scaling_probe = learning_rate_scaling_probe
        self._learning_rate_scaling_obj = learning_rate_scaling_obj
        self._lr_obj = self._getObjLearningRate() * learning_rate_scaling_obj
        self._lr_probe = self._getProbeLearningRate() * learning_rate_scaling_probe

        logger.info('creating optimizers...')
        aux_new_t = self._getAuxUpdate()
        self._attachCustomOptimizerForVariable(AuxAssignOptimizer,
                                               optimizer_init_args={"aux_v": self.fwd_model.aux_v,
                                                                    "aux_new_t": aux_new_t},
                                               initial_update_delay=update_delay_aux,
                                               update_frequency=update_frequency_aux)

        update_delay_obj = update_delay_obj if update_delay_obj is not None else 0
        self.attachTensorflowOptimizerForVariable("obj",
                                                  optimizer_type="gradient",
                                                  optimizer_init_args={"learning_rate": self._lr_obj},
                                                  initial_update_delay=update_delay_obj,
                                                  update_frequency=update_frequency_obj)

        update_delay_probe = update_delay_probe if update_delay_probe is not None else 0
        if reconstruct_probe:
            self.attachTensorflowOptimizerForVariable("probe",
                                                      optimizer_type="gradient",
                                                      optimizer_init_args={"learning_rate": self._lr_probe},
                                                      initial_update_delay=update_delay_probe,
                                                      update_frequency=update_frequency_probe)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(title="obj_error",
                                          tensor=self.fwd_model.obj_cmplx_t,
                                          log_epoch_frequency=registration_log_frequency,
                                          registration_ground_truth=obj_array_true)
        if reconstruct_probe and (probe_wavefront_true is not None):
            self.addCustomMetricToDataLog(title="probe_error",
                                          tensor=self.fwd_model.probe_cmplx_t,
                                          log_epoch_frequency=registration_log_frequency,
                                          registration_ground_truth=probe_wavefront_true)
        self._addRFactorLog(r_factor_log, 1)