Code example #1
File: simulation.py Project: ni-chen/ptychoSampling
    def _calculatePhaseModulationsForRCAngles(self):
        logger.info("Calculating the phase modulations for the rc angles.")

        ttheta = self.scan_grid.two_theta
        domega = self.scan_grid.del_omega

        ki = 2 * np.pi / self.wavelength * np.array(
            [np.cos(ttheta), np.sin(ttheta), 0])
        kf = 2 * np.pi / self.wavelength * np.array([1, 0, 0])
        q = (kf - ki)[:, None]

        ki_new = 2 * np.pi / self.wavelength * np.array([
            np.cos(ttheta + self.scan_grid.rc_angles),
            np.sin(ttheta + self.scan_grid.rc_angles),
            0 * self.scan_grid.rc_angles
        ])
        kf_new = 2 * np.pi / self.wavelength * np.array([
            np.cos(self.scan_grid.rc_angles),
            np.sin(self.scan_grid.rc_angles), 0 * self.scan_grid.rc_angles
        ])
        q_new = kf_new - ki_new
        delta_q = q_new - q
        # Probe dimensions in real space (assumes even shape)
        position_grids = [
            np.arange(-s // 2, s // 2) * ds
            for (s, ds) in zip(self.probe_3d.shape, self.probe_3d.pixel_size)
        ]
        Ry, Rx, Rz = np.meshgrid(*position_grids, indexing='ij')
        phase_modulations_all = np.exp(1j * np.array([
            delta_q[0, i] * Ry + delta_q[1, i] * Rx + delta_q[2, i] * Rz
            for i in range(self.scan_grid.n_rc_angles)
        ]))
        self._phase_modulations_all = phase_modulations_all
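For reference, here is a minimal standalone sketch of the same rocking-curve wave-vector geometry, using hypothetical values for the wavelength and angles (none of these numbers come from the project):

import numpy as np

wavelength = 1.5e-10                                 # hypothetical wavelength (m)
ttheta = np.deg2rad(30.0)                            # hypothetical two-theta
rc_angles = np.deg2rad(np.linspace(-0.05, 0.05, 5))  # hypothetical rocking offsets

k = 2 * np.pi / wavelength
# Incident and exit wave vectors at the nominal Bragg condition.
ki = k * np.array([np.cos(ttheta), np.sin(ttheta), 0.0])
kf = k * np.array([1.0, 0.0, 0.0])
q = (kf - ki)[:, None]

# Wave vectors after offsetting the geometry by each rocking-curve angle.
ki_new = k * np.array([np.cos(ttheta + rc_angles),
                       np.sin(ttheta + rc_angles),
                       np.zeros_like(rc_angles)])
kf_new = k * np.array([np.cos(rc_angles),
                       np.sin(rc_angles),
                       np.zeros_like(rc_angles)])
delta_q = (kf_new - ki_new) - q  # (3, n_rc_angles): momentum-transfer shifts
# Each column of delta_q feeds one phase ramp exp(i * delta_q . r) above.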
Code example #2
    def __init__(self,
                 obj: Obj,
                 probe: Probe3D,
                 scan_grid: BraggPtychoGrid,
                 exit_wave_axis: str = "y",
                 upsampling_factor: int = 1,
                 setup_second_order: bool = False,
                 dtype: str = "float32"):
        if scan_grid.full_field_probe:
            e = ValueError(
                "Full field probe not supported for Bragg ptychography.")
            logger.error(e)
            raise e
        if scan_grid.grid2d_axes != ("y", "z"):
            e = ValueError(
                "Only supports the case where the ptychographic scan is on the yz-plane."
            )
            logger.error(e)
            raise e
        if exit_wave_axis != 'y':
            e = ValueError(
                "Only supports the case where the exit waves are output along the y-direction."
            )
            logger.error(e)
            raise e

        super().__init__(obj, probe, scan_grid, upsampling_factor,
                         setup_second_order, dtype)
        logger.info("Creating the phase modulations for the scan angles.")

        with tf.device("/gpu:0"):
            self._probe_phase_modulations_all_t = tf.constant(
                self._getProbePhaseModulationsStack(), dtype='complex64')
            self._full_rc_positions_indices_t = tf.constant(
                scan_grid.full_rc_positions_indices, dtype='int64')
Code example #3
File: simulation.py Project: ni-chen/ptychoSampling
    def createObj(obj_params_dict: dict,
                  det_3d_shape: Tuple[int, int, int],
                  obj_pixel_size: Tuple[float, float, float],
                  upsampling_factor: float = 1.0) -> Obj:
        logger.info("Creating new crystal cell.")
        while True:
            obj_crystal = Simulated3DCrystalCell(**obj_params_dict)
            if (det_3d_shape[1] * upsampling_factor
                    >= obj_crystal.bordered_array.shape[1]):
                break
            else:
                logger.warning(
                    "Generated object width is larger than detector x-width. Trying again."
                )

        logger.info(
            'Adding x and z borders to crystal cell based on detector parameters.'
        )

        # Ensuring that the number of pixels in the x dimension matches that in the detector.
        padx = (det_3d_shape[1] - obj_crystal.bordered_array.shape[1]) // 2

        # Calculating the padding needed to accommodate all feasible probe translations (with probe-object interaction)
        padz, _ = Simulation.calculateObjBorderAfterRotation(
            0, det_3d_shape[2], obj_crystal.bordered_array.shape[1], 0)

        pad_shape = (0, padx, padz)
        obj = CustomObjFromArray(array=obj_crystal.bordered_array,
                                 border_shape=np.vstack(
                                     (pad_shape, pad_shape)).T,
                                 border_const=0,
                                 pixel_size=obj_pixel_size)
        return obj
Code example #4
    def getFlopsPerIter(self) -> int:

        if self.n_validation > 0:
            e = RuntimeWarning("n_validation > 0 can give misleading flops.")
            logger.warning(e)

        if hasattr(self, "session"):
            e = RuntimeWarning(
                "Calculating computational cost after previously initializing the static graph can "
                +
                "include inessential calculations (like training and/or validation loss value) "
                + "and thus give misleading results.")
            logger.warning(e)

        else:
            with self.graph.as_default():
                config = tf.ConfigProto()
                config.gpu_options.allow_growth = True
                logger.info("Initializing the session.")
                self.session = tf.Session(config=config)
                self.session.run(tf.global_variables_initializer())

        for opt in self.optimizers:
            self.session.run(opt.minimize_op)

        total_flops = getComputationalCostInFlops(self.graph)
        return total_flops
Code example #5
File: datalogs_t.py Project: ni-chen/ptychoSampling
    def finalize(self):
        columns = []
        for item in self._datalog_items:
            columns.append(item.title)
        logger.info("Initializing the log outputs...")
        self.dataframe = DataFrame(columns=columns, dtype='float32')
        self.dataframe.loc[0] = np.nan
Code example #6
File: recons.py Project: saugatkandel/ptychoSampling
    def _initGraph(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device("/gpu:0"):
                self._amplitudes_t = tf.constant(self.amplitudes,
                                                 dtype=self.dtype)
                if self.intensities_mask is not None:
                    self._data_mask_t = tf.constant(self.intensities_mask,
                                                    dtype=tf.bool)

            logger.info('creating batches...')
            self._createDataBatches()
Code example #7
File: simulation.py Project: ni-chen/ptychoSampling
    def createProbe3D(probe_params_dict: dict, wavelength: float,
                      rotate_angle: float, probe_xz_shape: Tuple[int, int],
                      obj_pixel_size: Tuple[int, int, int]) -> Probe:
        logger.info("Creating new guassian, speckled, 2d probe.")
        probe_yz = GaussianSpeckledProbe(wavelength=wavelength,
                                         **probe_params_dict)
        ny = probe_yz.shape[0]

        nx, nz = probe_xz_shape
        # overdoing the repeat and interpolation just for safety
        logger.info(
            "Rotating and interpolating the 2d probe to generate the 3d probe."
        )
        probe_yz_stack = np.repeat(probe_yz.wavefront.ifftshift[:, None, :],
                                   nx * 2,
                                   axis=1)

        rdeg = rotate_angle * 180 / np.pi
        rotated_real = ndimage.rotate(np.real(probe_yz_stack),
                                      rdeg,
                                      axes=(0, 1),
                                      mode='constant',
                                      order=1)
        rotated_imag = ndimage.rotate(np.imag(probe_yz_stack),
                                      rdeg,
                                      axes=(0, 1),
                                      mode='constant',
                                      order=1)
        rotated = rotated_real + 1j * rotated_imag

        # Calculating the number of pixels required to capture the y-structure of the rotated probe
        rny = ny / np.cos(rotate_angle) // 1
        rshape = np.array([rny, nx, nz]).astype('int')

        # Calculating the extent of the probe (relative to the center of rotation) in the x and y dimensions and
        # also ensuring that the probe array has an even number of pixels.
        # Adding any required padding so that the z dimension of the probe matches up with the detector.
        a = (np.array(rotated.shape) - rshape) // 2
        b = ((np.array(rotated.shape) - rshape) // 2 +
             (np.array(rotated.shape) - rshape) % 2)
        z_pad = (rshape[2] - rotated.shape[2]) // 2

        rotated_centered = rotated[a[0]:-b[0], a[1]:-b[1]]
        rotated_centered = np.pad(rotated_centered,
                                  [[0, 0], [0, 0], [z_pad, z_pad]],
                                  mode='constant',
                                  constant_values=0)
        probe_3d = CustomProbe3DFromArray(array=rotated_centered,
                                          wavelength=wavelength,
                                          pixel_size=obj_pixel_size)
        return probe_3d
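The rotate-real-and-imaginary-parts-separately step above works around scipy versions whose ndimage.rotate does not accept complex input. A minimal sketch of just that trick on a toy complex array:

import numpy as np
from scipy import ndimage

rng = np.random.default_rng(0)
field = rng.standard_normal((32, 32)) + 1j * rng.standard_normal((32, 32))

angle_deg = 30.0
# Rotate real and imaginary parts separately, then recombine.
rotated = (ndimage.rotate(field.real, angle_deg, mode='constant', order=1)
           + 1j * ndimage.rotate(field.imag, angle_deg, mode='constant', order=1))
print(rotated.shape, rotated.dtype)  # output array is larger; dtype complex128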
Code example #8
    def __init__(self,
                 *args: int,
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        logger.info('attaching fwd model...')
        self.attachForwardModel("farfield")
        logger.info('creating loss fn...')
        self.attachLossFunction("least_squared")
        logger.info('creating optimizers...')
        self.attachOptimizerPerVariable(
            "obj",
            optimizer_type="adam",
            optimizer_init_args={"learning_rate": 1e-2})
        self.attachOptimizerPerVariable(
            "probe",
            optimizer_type="adam",
            optimizer_init_args={"learning_rate": 1e-1},
            initial_update_delay=0)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(
                title="obj_error",
                tensor=self.fwd_model.obj_cmplx_t,
                log_epoch_frequency=10,
                registration_ground_truth=obj_array_true)
        if probe_wavefront_true is not None:
            self.addCustomMetricToDataLog(
                title="probe_error",
                tensor=self.fwd_model.probe_cmplx_t,
                log_epoch_frequency=10,
                registration_ground_truth=probe_wavefront_true)
Code example #9
File: recons.py Project: saugatkandel/ptychoSampling
    def __init__(self,
                 *args: int,
                 loss_type: str = "least_squared",
                 obj_array_true: np.ndarray = None,
                 learning_rate_obj: float = 1e-2,
                 registration_log_frequency=10,
                 **kwargs):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        logger.info('attaching fwd model...')
        self.attachForwardModel("bragg")
        logger.info('creating loss fn...')
        self.attachLossFunction(loss_type)
        logger.info('creating optimizers...')
        self.attachTensorflowOptimizerForVariable(
            "obj",
            optimizer_type="adam",
            optimizer_init_args={"learning_rate": learning_rate_obj})

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(
                title="obj_error",
                tensor=self.fwd_model.obj_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=obj_array_true)
Code example #10
    def getFlopsPerIter(self) -> dict:

        if self.n_validation > 0:
            e = RuntimeWarning("n_validation > 0 can give misleading flops.")
            logger.warning(e)

        if hasattr(self, "session"):
            e = RuntimeWarning(
                "Calculating computational cost after previously initializing the static graph can "
                +
                "include inessential calculations (like training and/or validation loss value) "
                + "and thus give misleading results.")
            logger.warning(e)

        else:
            with self.graph.as_default():
                config = tf.ConfigProto()
                config.gpu_options.allow_growth = True
                logger.info("Initializing the session.")
                self.session = tf.Session(config=config)
                self.session.run(tf.global_variables_initializer())

        for opt in self.optimizers:
            self.session.run(opt.minimize_op)

        total_flops = getComputationalCostInFlops(self.graph)

        ls_flops = getComputationalCostInFlops(self.graph,
                                               keywords=[
                                                   ("opt_minimize_step",
                                                    "backtracking_linesearch")
                                               ],
                                               exclude_keywords=False)

        opt_only_flops = getComputationalCostInFlops(self.graph,
                                                     keywords=["opt"],
                                                     exclude_keywords=False)

        flops_without_ls = total_flops - ls_flops
        d = {
            "total_flops": total_flops,
            "obj_ls_flops": ls_flops,
            "probe_ls_flops": 0,
            "obj_only_flops": opt_only_flops,
            "probe_only_flops": 0,
            "flops_without_ls": flops_without_ls
        }
        return d
Code example #11
File: recons.py Project: saugatkandel/ptychoSampling
    def __init__(self,
                 obj: ptychoSampling.obj.Obj,
                 probe: ptychoSampling.probe.Probe,
                 grid: ptychoSampling.grid.ScanGrid,
                 intensities: np.ndarray,
                 intensities_mask: np.ndarray = None,
                 n_validation: int = 0,
                 training_batch_size: int = 0,
                 validation_batch_size: int = 0,
                 background_level: float = 0.,
                 dtype: str = 'float32'):
        self.obj = copy.deepcopy(obj)
        self.probe = copy.deepcopy(probe)
        self.grid = copy.deepcopy(grid)
        self.amplitudes = intensities**0.5
        self.background_level = background_level
        if intensities_mask is not None:
            if intensities_mask.dtype != bool:
                raise ValueError("Mask supplied should be a boolean array.")
        self.intensities_mask = intensities_mask
        self.dtype = dtype
        if self.dtype == "float32":
            self._eps = 1e-8
        elif self.dtype == "float64":
            self._eps = 1e-8
        else:
            raise ValueError("dtype must be 'float32' or 'float64'.")

        self._splitTrainingValidationData(n_validation, training_batch_size,
                                          validation_batch_size)
        self._initGraph()

        logger.info('creating pandas log...')
        self.iteration = 0
        self.datalog = DataLogs()
        self._default_log_items = {"epoch": None, "train_loss": None}

        for key in self._default_log_items:
            self.datalog.addSimpleMetric(key)

        if self.n_validation > 0:
            self._validation_log_items = {
                "validation_loss": None,
                "validation_min": None,
                "patience": None
            }
            for key in self._validation_log_items:
                self.datalog.addSimpleMetric(key)
Code example #12
File: simulation.py Project: ni-chen/ptychoSampling
    def createScanGrid(scan_grid_params_dict: dict,
                       angle_grid_params_dict: dict,
                       bordered_obj_yz_shape: Tuple[int, int],
                       probe_yz_shape: Tuple[int, int]) -> BraggPtychoGrid:

        logger.info(
            "creating new 2d scan grid based on object and probe shapes.")
        scan_grid_2d = RectangleGrid(obj_w_border_shape=bordered_obj_yz_shape,
                                     probe_shape=probe_yz_shape,
                                     **scan_grid_params_dict)
        scan_grid_2d.checkOverlap()

        logger.info("Using created 2d scan grid to create full RC scan grid.")
        scan_grid = BraggPtychoGrid.fromPtychoScan2D(scan_grid_2d,
                                                     grid2d_axes=("y", "z"),
                                                     **angle_grid_params_dict)
        return scan_grid
Code example #13
    def __init__(self,
                 obj: Obj,
                 probe: Probe,
                 scan_grid: ScanGrid,
                 upsampling_factor: int = 1):

        self.model_vars = {}

        with tf.device("/gpu:0"):
            self.obj_cmplx_t = self._addComplexVariable(obj.array, name="obj")
            self.obj_w_border_t = tf.pad(self.obj_cmplx_t, obj.border_shape, constant_values=obj.border_const)

            self.probe_cmplx_t = self._addComplexVariable(probe.wavefront, "probe")

            self.upsampling_factor = upsampling_factor
            logger.info("Creating obj views for the scan positions.")
            self._obj_views_all_t = self._getPtychoObjViewStack(obj, probe, scan_grid)
Code example #14
    def _calculateGaussianFit(self) -> None:
        r"""Fit a 2d gaussian to the probe intensities.

        Performs a least-squares fit (using ``scipy.optimize.curve_fit``) to fit a 2d gaussian to the probe
        intensities. Uses the calculated gaussian spread to calculate the FWHM as well.

        See also
        --------
        gaussian_fit
        utils.generalized2dGaussian
        """
        logger.info(
            'Fitting a generalized 2d gaussian to the probe intensity.')
        from scipy.optimize import curve_fit

        nx = self.shape[-1]
        ny = self.shape[-2]
        #intensities = np.fft.ifftshift(np.abs(self.wavefront)**2)
        intensities = self.wavefront.fftshift.intensities
        y = np.arange(-ny // 2, ny // 2) * self.pixel_size[0]
        x = np.arange(-nx // 2, nx // 2) * self.pixel_size[1]
        yy, xx = np.meshgrid(y, x, indexing='ij')  # match the (ny, nx) intensity layout
        xdata = np.stack((xx.flatten(), yy.flatten()), axis=1)
        bounds_min = [0, x[0], y[0], 0, 0, -np.pi / 4, 0]
        bounds_max = [
            intensities.sum(), x[-1], y[-1], x[-1] * 2, y[-1] * 2, np.pi / 4,
            intensities.max()
        ]
        popt, _ = curve_fit(utils.generalized2dGaussian,
                            xdata,
                            intensities.flatten(),
                            bounds=[bounds_min, bounds_max])
        amplitude, center_x, center_y, sigma_x, sigma_y, theta, offset = popt
        self._gaussian_fit = {
            "amplitude": amplitude,
            "center_x": center_x,
            "center_y": center_y,
            "sigma_x": sigma_x,
            "sigma_y": sigma_y,
            "theta": theta,
            "offset": offset
        }
        self._gaussian_fwhm = 2.355 * np.array((sigma_y, sigma_x))
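The parameterization of utils.generalized2dGaussian is not shown here; below is a plausible standalone version, assuming the standard rotated elliptical Gaussian with the same parameter order as the popt unpacking above (a sketch, not the project's actual implementation). The 2.355 factor is the usual FWHM conversion, 2*sqrt(2*ln 2) ≈ 2.3548:

import numpy as np

def generalized_2d_gaussian(xdata, amplitude, center_x, center_y,
                            sigma_x, sigma_y, theta, offset):
    # xdata: (n, 2) array of (x, y) coordinates, matching the curve_fit call.
    x, y = xdata[:, 0], xdata[:, 1]
    a = np.cos(theta)**2 / (2 * sigma_x**2) + np.sin(theta)**2 / (2 * sigma_y**2)
    b = -np.sin(2 * theta) / (4 * sigma_x**2) + np.sin(2 * theta) / (4 * sigma_y**2)
    c = np.sin(theta)**2 / (2 * sigma_x**2) + np.cos(theta)**2 / (2 * sigma_y**2)
    return offset + amplitude * np.exp(
        -(a * (x - center_x)**2
          + 2 * b * (x - center_x) * (y - center_y)
          + c * (y - center_y)**2))

fwhm_factor = 2 * np.sqrt(2 * np.log(2))  # ≈ 2.3548, the 2.355 used above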
Code example #15
    def _getPtychoObjViewStack(
            self,
            obj_w_border_t: tf.Tensor,
            probe_cmplx_t: tf.Tensor,
            position_indices_t: tf.Tensor = None) -> tf.Tensor:
        """Precalculate the object positioning for each scan position.

        Assumes a small object that is translated within the dimensions of a full-field probe. For each scan
        position, we translate the object, then pad the object array to the size of the probe beam. For the padding,
        we assume free-space (transparent) propagation and use 1.0.

        In Tensorflow, performing the pad-and-stack procedure in the GPU for complex-valued arrays seems to be
        buggy. As a workaround, we separately pad-and-stack the real and imaginary parts of the object with 1.0 and
        0 respectively.

        Returns
        -------
        obj_views : tensor(complex)
            Stack of tensors that correspond to the padded object at each object translation.
        """
        if position_indices_t is None:
            position_indices_t = tf.range(
                self._scan_grid.positions_pix.shape[0])

        logger.info("Creating obj views for the scan positions.")
        if not hasattr(self, "_obj_view_indices_t"):
            obj_view_indices = self._genViewIndices(
                self._scan_grid.positions_pix)
            self._obj_view_indices_t = tf.constant(obj_view_indices)

        batch_obj_view_indices_t = tf.gather(self._obj_view_indices_t,
                                             position_indices_t)
        batch_obj_views_t = tf.gather(tf.reshape(obj_w_border_t, [-1]),
                                      batch_obj_view_indices_t)

        #obj_view_indices = self._genViewIndices(scan_grid.positions_pix)
        #obj_view_indices_t = tf.constant(obj_view_indices, dtype='int64')
        #obj_views_t = tf.gather(tf.reshape(self.obj_w_border_t, [-1]), obj_view_indices_t)
        #obj_views_t = tf.reshape(obj_views_t,
        #                         (obj_view_indices.shape[0],
        #                          *(self.probe_cmplx_t.get_shape().as_list())))
        return batch_obj_views_t
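A NumPy sketch of the flatten-and-gather pattern this method relies on: precompute flat indices for every scan position once, then extract all object views with a single gather (toy shapes and values, chosen here for illustration):

import numpy as np

obj = np.arange(36.0).reshape(6, 6)             # toy object with border
probe_shape = (3, 3)
positions = np.array([[0, 0], [2, 1], [3, 3]])  # toy scan positions (pixels)

rows, cols = np.meshgrid(np.arange(probe_shape[0]),
                         np.arange(probe_shape[1]), indexing='ij')
# Flat index of every pixel of every view: shape (n_positions, 3, 3).
view_indices = ((positions[:, 0, None, None] + rows) * obj.shape[1]
                + (positions[:, 1, None, None] + cols))
obj_views = obj.ravel()[view_indices]           # one gather for all views
assert np.array_equal(obj_views[1], obj[2:5, 1:4])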
Code example #16
    def __init__(self,
                 obj: ptychoSampling.obj.Obj,
                 probe: ptychoSampling.probe.Probe,
                 grid: ptychoSampling.grid.ScanGrid,
                 intensities: np.ndarray,
                 n_validation: int = 0,
                 batch_size: int = 0):
        self.obj = copy.deepcopy(obj)
        self.probe = copy.deepcopy(probe)
        self.grid = copy.deepcopy(grid)
        self.amplitudes = intensities**0.5

        self.n_all = self.amplitudes.shape[0]
        self.n_validation = n_validation
        self.n_train = self.n_all - self.n_validation
        self.batch_size = batch_size

        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device("/gpu:0"):
                self._amplitudes_t = tf.constant(self.amplitudes,
                                                 dtype=tf.float32)
            logger.info('creating batches...')
            self._createDataBatches()

        logger.info('creating log...')
        self.iteration = 0
        self.datalog = DataLogs()
        self._default_log_items = {"epoch": None, "train_loss": None}

        for key in self._default_log_items:
            self.datalog.addSimpleMetric(key)

        if self.n_validation > 0:
            self._validation_log_items = {
                "validation_loss": None,
                "validation_min": None,
                "patience": None
            }
            for key in self._validation_log_items:
                self.datalog.addSimpleMetric(key)
Code example #17
File: simulation.py Project: ni-chen/ptychoSampling
    def _calculateDiffractionPatterns(self):

        # Calculating the wave vectors
        self._calculatePhaseModulationsForRCAngles()

        intensities_all = []

        logger.info("Calculating the generated diffraction patterns.")
        for ia in range(self.scan_grid.n_rc_angles):
            for ib, (py, pz) in enumerate(self.scan_grid.positions_pix):
                obj_slice = self.obj.bordered_array[
                    py: py + self.probe_3d.shape[0], :,
                    pz: pz + self.probe_3d.shape[2]]
                exit_wave = (obj_slice * self.probe_3d.wavefront
                             * self._phase_modulations_all[ia])
                exit_wave_proj = np.sum(exit_wave, axis=0).fftshift
                intensities_all.append(exit_wave_proj.propFF().intensities)

        self.intensities = (np.random.poisson(intensities_all)
                            if self.poisson_noise
                            else np.array(intensities_all))
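A self-contained sketch of one simulated measurement in the loop above, with plain NumPy standing in for the project's Wavefront helpers (fftshift/propFF): project the 3d exit wave along y, propagate to the far field with an FFT, then apply Poisson counting noise:

import numpy as np

rng = np.random.default_rng(0)
shape = (8, 16, 16)  # toy (y, x, z) exit-wave shape
exit_wave = (rng.standard_normal(shape)
             + 1j * rng.standard_normal(shape))

exit_wave_proj = np.fft.ifftshift(exit_wave.sum(axis=0))  # project along y
farfield = np.fft.fft2(exit_wave_proj, norm='ortho')      # far-field propagation
intensities = np.abs(farfield)**2
measured = rng.poisson(intensities)                       # counting noise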
Code example #18
File: recons.py Project: saugatkandel/ptychoSampling
    def finalizeSetup(self):
        self._checkAttr("optimizers", "finalize")
        logger.info("finalizing the data logger.")
        self.datalog.finalize()
        with self.graph.as_default():
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            logger.info("Initializing the session.")
            self.session = tf.Session(config=config)
            self.session.run(tf.global_variables_initializer())
        logger.info("Finalized setup.")
Code example #19
File: helper_new.py Project: ni-chen/ptychoSampling
    def __init__(self,
                 *args: int,
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 shuffle_order: list = None,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        logger.info('attaching fwd model...')
        self.attachForwardModel("farfield")
        logger.info('creating loss fn...')
        self.attachLossFunction("least_squared")
Code example #20
    def __init__(self,
                 *args: int,
                 loss_type: str = "least_squared",
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 update_delay_probe: int = 0,
                 update_delay_obj: int = 0,
                 reconstruct_probe: bool = True,
                 registration_log_frequency: int = 10,
                 opt_init_extra_kwargs: dict = None,
                 obj_abs_proj: bool = True,
                 loss_init_extra_kwargs: dict = None,
                 r_factor_log: bool = True,
                 apply_precond: bool = False,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        if self.training_batch_size != self.n_train:
            e = ValueError(
                "Conjugate gradient reconstruction does not support minibatch reconstruction."
            )
            logger.error(e)
            raise e

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                       obj_abs_proj=obj_abs_proj)

        logger.info('creating loss fn...')
        self.attachLossFunctionSecondOrder(
            loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

        logger.info("creating optimizers...")
        if opt_init_extra_kwargs is None:
            opt_init_extra_kwargs = {}

        self._preconditioner = None
        if apply_precond:
            with self.graph.as_default():
                loss_data_type = self._loss_method.data_type
                hessian_t = (
                    tf.ones_like(self._batch_train_predictions_t) *
                    self._train_hessian_fn(self._batch_train_predictions_t))
                hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])

                obj_scaling = self._getObjScaling(loss_data_type, hessian_t)
                probe_scaling = self._getProbeScaling(loss_data_type,
                                                      hessian_t)
                scaling_both = tf.concat((obj_scaling, probe_scaling), axis=0)
                scaling_both = tf.concat((scaling_both, scaling_both), axis=0)
                zero_condition = tf.less(scaling_both,
                                         1e-10 * tf.reduce_max(scaling_both))
                zero_case = tf.ones_like(scaling_both) / (
                    1e-8 * tf.reduce_max(scaling_both))
                self._preconditioner = tf.where(zero_condition, zero_case,
                                                1 / scaling_both)

        opt_init_args = {
            "input_var": self.fwd_model.joint_v,
            "predictions_fn": self._predictions_fn,
            "loss_fn": self._train_loss_fn,
            "name": "opt",
            "diag_precondition_t": self._preconditioner
        }
        opt_init_args.update(opt_init_extra_kwargs)
        self._attachCustomOptimizerForVariable(
            ConjugateGradientOptimizer, optimizer_init_args=opt_init_args)
        self.addCustomMetricToDataLog(
            title="ls_iters",
            tensor=self.optimizers[0]._optimizer._linesearch_steps,
            log_epoch_frequency=1)
        self.addCustomMetricToDataLog(
            title="alpha",
            tensor=self.optimizers[0]._optimizer._linesearch._alpha,
            log_epoch_frequency=1)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(
                title="obj_error",
                tensor=self.fwd_model.obj_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=obj_array_true)
        if reconstruct_probe and (probe_wavefront_true is not None):
            self.addCustomMetricToDataLog(
                title="probe_error",
                tensor=self.fwd_model.probe_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=probe_wavefront_true)
        self._addRFactorLog(r_factor_log, registration_log_frequency)
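The tf.where guard at the end of the preconditioner block above inverts the scaling safely: entries negligible relative to the maximum get a large-but-finite value instead of blowing up. The same pattern in NumPy, on toy values:

import numpy as np

scaling = np.array([4.0, 2.0, 1e-14, 1e-16])
# Entries below 1e-10 of the max are treated as zero and capped.
zero_condition = scaling < 1e-10 * scaling.max()
zero_case = np.ones_like(scaling) / (1e-8 * scaling.max())
precond = np.where(zero_condition, zero_case, 1.0 / scaling)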
Code example #21
    def __init__(
        self,
        *args: int,
        loss_type: str = "least_squared",
        obj_array_true: np.ndarray = None,
        probe_wavefront_true: np.ndarray = None,
        update_delay_probe: int = 0,
        update_frequency_probe: int = 1,
        update_frequency_obj: int = 1,
        reconstruct_probe: bool = True,
        registration_log_frequency: int = 1,
        opt_init_extra_kwargs: dict = None,
        obj_abs_proj: bool = True,
        loss_init_extra_kwargs: dict = None,
        r_factor_log: bool = True,
        #### These two parameters are experimental only. Not to be used in production simulations.
        # apply_diag_mu_scaling: bool = False, # Does not help
        apply_precond: bool = False,
        ###########################################################################################
        **kwargs: int):
        print("opt_init_extra_kwargs", opt_init_extra_kwargs)
        print("Loss init args", loss_init_extra_kwargs)
        print("update_delay_probe", update_delay_probe, "update_frequency",
              update_frequency_probe)
        print("update_frequency_obj", update_frequency_obj)

        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                       obj_abs_proj=obj_abs_proj)

        logger.info('creating loss fn...')
        self.attachLossFunctionSecondOrder(
            loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

        logger.info("creating optimizers...")

        if opt_init_extra_kwargs is None:
            opt_init_extra_kwargs = {}

        opt_init_args = {
            "input_var": self.fwd_model.joint_v,
            "predictions_fn": self._predictions_fn,
            "loss_fn": self._train_loss_fn,
            "diag_hessian_fn": self._train_hessian_fn,
            "damping_factor": 10.0,
            "damping_update_frequency": 5,
            "damping_update_factor": 0.99,
            "name": "opt"
        }

        # Experimental only.########################################################################################
        if apply_precond:
            loss_data_type = self._loss_method.data_type

            with self.graph.as_default():
                hessian_t = (
                    tf.ones_like(self._batch_train_predictions_t) *
                    self._train_hessian_fn(self._batch_train_predictions_t))
                hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])

                obj_scaling_t = self._getObjScaling(loss_data_type, hessian_t)
                probe_scaling_t = self._getProbeScaling(
                    loss_data_type, hessian_t)
                scaling_both = tf.concat((obj_scaling_t, probe_scaling_t),
                                         axis=0)
                self._joint_scaling_t = tf.concat((scaling_both, scaling_both),
                                                  axis=0)
            opt_init_args['diag_precond_t'] = self._joint_scaling_t
        #############################################################################################################

        opt_init_args.update(opt_init_extra_kwargs)
        self._attachCustomOptimizerForVariable(
            CurveballOptimizer, optimizer_init_args=opt_init_args)

        self.addCustomMetricToDataLog(
            title="mu",
            tensor=self.optimizers[0]._optimizer._damping_factor,
            log_epoch_frequency=1)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(
                title="obj_error",
                tensor=self.fwd_model.obj_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=obj_array_true)
        if (probe_wavefront_true is not None):
            self.addCustomMetricToDataLog(
                title="probe_error",
                tensor=self.fwd_model.probe_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=probe_wavefront_true)
        self._addRFactorLog(r_factor_log, registration_log_frequency)
Code example #22
    def __init__(self,
                 *args: int,
                 loss_type: str = "least_squared",
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 max_cg_iter: int = 100,
                 min_cg_tol: float = 0.1,
                 registration_log_frequency: int = 10,
                 both_registration_nlse: bool = True,
                 opt_init_extra_kwargs: dict = None,
                 obj_abs_proj: bool = True,
                 obj_abs_max: float = 1.0,
                 probe_abs_max: float = None,
                 loss_init_extra_kwargs: dict = None,
                 r_factor_log: bool = True,
                 update_delay_probe: int = 0,
                 reconstruct_probe: bool = True,
                 apply_diag_mu_scaling: bool = True,
                 apply_precond: bool = False,
                 apply_precond_and_scaling: bool = False,
                 stochastic_diag_estimator_type: str = None,
                 stochastic_diag_estimator_iters: int = 1,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        if apply_precond_and_scaling:
            logger.info(
                "Overriding any set values for apply_precond and apply_diag_mu_scaling"
            )
            apply_precond = True
            apply_diag_mu_scaling = True

        if stochastic_diag_estimator_type is not None:
            if apply_diag_mu_scaling or apply_precond:
                e = ValueError(
                    "Cannot use analytical precond and diag ggn elems if stochastic calculation is enabled."
                )
                logger.error(e)
                raise e

        if self.training_batch_size != self.n_train:
            e = ValueError(
                "LMA reconstruction does not support minibatch reconstruction."
            )
            logger.error(e)
            raise e

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(JointFarfieldForwardModelT,
                                       obj_abs_proj=obj_abs_proj,
                                       obj_abs_max=obj_abs_max,
                                       probe_abs_max=probe_abs_max)

        logger.info('creating loss fn...')

        print('Loss init args', loss_init_extra_kwargs)
        self.attachLossFunctionSecondOrder(
            loss_type, loss_init_extra_kwargs=loss_init_extra_kwargs)

        logger.info("creating optimizer...")
        if opt_init_extra_kwargs is None:
            opt_init_extra_kwargs = {}

        if apply_diag_mu_scaling or apply_precond:

            with self.graph.as_default():

                loss_data_type = self._loss_method.data_type
                hessian_t = (
                    tf.ones_like(self._batch_train_predictions_t) *
                    self._train_hessian_fn(self._batch_train_predictions_t))
                hessian_t = tf.reshape(hessian_t, [-1, *self.probe.shape])
                print(hessian_t)
                self._obj_scaling = self._getObjScaling(
                    loss_data_type, hessian_t)
                self._probe_scaling = self._getProbeScaling(
                    loss_data_type, hessian_t)
                scaling_both = tf.concat(
                    (self._obj_scaling, self._probe_scaling), axis=0)

                if apply_diag_mu_scaling:
                    self._joint_scaling_t = tf.concat(
                        (scaling_both, scaling_both), axis=0)
                    optimizer = ScaledLMAOptimizer
                    opt_init_extra_kwargs[
                        'diag_mu_scaling_t'] = self._joint_scaling_t
                if apply_precond:
                    self._joint_precond_t = tf.concat(
                        (scaling_both, scaling_both), axis=0)
                    optimizer = PCGLMAOptimizer
                    opt_init_extra_kwargs[
                        'diag_precond_t'] = self._joint_precond_t
                    if apply_diag_mu_scaling:
                        optimizer = ScaledPCGLMAOptimizer
                #
        else:
            optimizer = LMAOptimizer

        opt_init_args = {
            "input_var": self.fwd_model.joint_v,
            "predictions_fn": self._predictions_fn,
            "loss_fn": self._train_loss_fn,
            "diag_hessian_fn": self._train_hessian_fn,
            "name": "opt",
            "max_cg_iter": max_cg_iter,
            "min_cg_tol": min_cg_tol,
            "stochastic_diag_estimator_type": stochastic_diag_estimator_type,
            "stochastic_diag_estimator_iters": stochastic_diag_estimator_iters,
            "assert_tolerances": False
        }
        opt_init_args.update(opt_init_extra_kwargs)

        self._attachCustomOptimizerForVariable(
            optimizer, optimizer_init_args=opt_init_args)
        self.addCustomMetricToDataLog(
            title="cg_iters",
            tensor=self.optimizers[0]._optimizer._total_cg_iterations,
            log_epoch_frequency=1)
        self.addCustomMetricToDataLog(
            title="ls_iters",
            tensor=self.optimizers[0]._optimizer._total_proj_ls_iterations,
            log_epoch_frequency=1)
        self.addCustomMetricToDataLog(title="proj_iters",
                                      tensor=self.optimizers[0]._optimizer.
                                      _projected_gradient_iterations,
                                      log_epoch_frequency=1)
        self.addCustomMetricToDataLog(title="mu",
                                      tensor=self.optimizers[0]._optimizer._mu,
                                      log_epoch_frequency=1)

        self._registration_log_frequency = registration_log_frequency
        if obj_array_true is not None:
            self._setObjRegistration(
                obj_array_true, both_registration_nlse=both_registration_nlse)
        if probe_wavefront_true is not None:
            self._setProbeRegistration(
                probe_wavefront_true,
                both_registration_nlse=both_registration_nlse)
        self._addRFactorLog(r_factor_log, registration_log_frequency)
Code example #23
    def __init__(self, *args: int,
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 obj_abs_proj: bool = True,
                 update_delay_probe: int = 0,
                 update_delay_obj: int = 0,
                 update_delay_aux: int = 0,
                 update_frequency_probe: int = 1,
                 update_frequency_obj: int = 1,
                 update_frequency_aux: int = 1,
                 reconstruct_probe: bool = True,
                 registration_log_frequency: int = 10,
                 r_factor_log: bool = True,
                 loss_type: str = "gaussian",
                 learning_rate_scaling_probe: float = 1.0,
                 learning_rate_scaling_obj: float = 1.0,
                 loss_init_extra_kwargs: dict = None,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)
        if self.training_batch_size != self.n_train:
            e = ValueError("PALM reconstruction does not support minibatch reconstruction.")
            logger.error(e)
            raise e

        if loss_type != "gaussian":
            e = ValueError("PALM reconstruction does not support other loss functions.")
            logger.error(e)
            raise e

        logger.info('attaching fwd model...')
        self._attachCustomForwardModel(FarfieldPALMForwardModel, obj_abs_proj=obj_abs_proj)

        logger.info('creating loss fn...')
        self._attachCustomLossFunction(AuxLossFunctionT, loss_init_extra_kwargs)

        logger.info("create learning rates")
        self._learning_rate_scaling_probe = learning_rate_scaling_probe
        self._learning_rate_scaling_obj = learning_rate_scaling_obj
        self._lr_obj = self._getObjLearningRate() * learning_rate_scaling_obj
        self._lr_probe = self._getProbeLearningRate() * learning_rate_scaling_probe

        logger.info('creating optimizers...')
        aux_new_t = self._getAuxUpdate()
        self._attachCustomOptimizerForVariable(AuxAssignOptimizer,
                                               optimizer_init_args={"aux_v": self.fwd_model.aux_v,
                                                                    "aux_new_t": aux_new_t},
                                               initial_update_delay=update_delay_aux,
                                               update_frequency=update_frequency_aux)

        update_delay_obj = update_delay_obj if update_delay_obj is not None else 0
        self.attachTensorflowOptimizerForVariable("obj",
                                                  optimizer_type="gradient",
                                                  optimizer_init_args={"learning_rate": self._lr_obj},
                                                  initial_update_delay=update_delay_obj,
                                                  update_frequency=update_frequency_obj)

        update_delay_probe = update_delay_probe if update_delay_probe is not None else 0
        if reconstruct_probe:
            self.attachTensorflowOptimizerForVariable("probe",
                                                      optimizer_type="gradient",
                                                      optimizer_init_args={"learning_rate": self._lr_probe},
                                                      initial_update_delay=update_delay_probe,
                                                      update_frequency=update_frequency_probe)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(title="obj_error",
                                          tensor=self.fwd_model.obj_cmplx_t,
                                          log_epoch_frequency=registration_log_frequency,
                                          registration_ground_truth=obj_array_true)
        if reconstruct_probe and (probe_wavefront_true is not None):
            self.addCustomMetricToDataLog(title="probe_error",
                                          tensor=self.fwd_model.probe_cmplx_t,
                                          log_epoch_frequency=registration_log_frequency,
                                          registration_ground_truth=probe_wavefront_true)
        self._addRFactorLog(r_factor_log, 1)
Code example #24
File: recons.py Project: saugatkandel/ptychoSampling
    def __init__(self,
                 *args: int,
                 loss_type: str = "least_squared",
                 obj_array_true: np.ndarray = None,
                 probe_wavefront_true: np.ndarray = None,
                 r_factor_log: bool = False,
                 learning_rate_obj: float = 1e-2,
                 update_delay_obj: int = 0,
                 update_delay_probe: int = 0,
                 learning_rate_probe: float = 1e-1,
                 reconstruct_probe: bool = True,
                 registration_log_frequency: int = 10,
                 both_registration_nlse: bool = True,
                 obj_abs_proj: bool = True,
                 loss_init_extra_kwargs: dict = None,
                 **kwargs: int):
        logger.info('initializing...')
        super().__init__(*args, **kwargs)

        logger.info('attaching fwd model...')
        self.attachForwardModel("farfield", obj_abs_proj=obj_abs_proj)

        logger.info('creating loss fn...')
        self.attachLossFunction(loss_type,
                                loss_init_extra_kwargs=loss_init_extra_kwargs)
        logger.info('creating optimizers...')
        self.attachTensorflowOptimizerForVariable(
            "obj",
            optimizer_type="adam",
            optimizer_init_args={"learning_rate": learning_rate_obj},
            initial_update_delay=update_delay_obj)

        if reconstruct_probe:
            self.attachTensorflowOptimizerForVariable(
                "probe",
                optimizer_type="adam",
                optimizer_init_args={"learning_rate": learning_rate_probe},
                initial_update_delay=update_delay_probe)

        if obj_array_true is not None:
            self.addCustomMetricToDataLog(
                title="obj_error",
                tensor=self.fwd_model.obj_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration=True,
                registration_ground_truth=obj_array_true)
            if both_registration_nlse:
                self.addCustomMetricToDataLog(
                    title="obj_nlse",
                    tensor=self.fwd_model.obj_cmplx_t,
                    log_epoch_frequency=registration_log_frequency,
                    registration=False,
                    normalized_lse=True,
                    registration_ground_truth=obj_array_true)
        if reconstruct_probe and (probe_wavefront_true is not None):
            self.addCustomMetricToDataLog(
                title="probe_error",
                tensor=self.fwd_model.probe_cmplx_t,
                log_epoch_frequency=registration_log_frequency,
                registration_ground_truth=probe_wavefront_true)
            if both_registration_nlse:
                self.addCustomMetricToDataLog(
                    title="probe_nlse",
                    tensor=self.fwd_model.probe_cmplx_t,
                    log_epoch_frequency=registration_log_frequency,
                    registration=False,
                    normalized_lse=True,
                    registration_ground_truth=probe_wavefront_true)
        self._addRFactorLog(r_factor_log, registration_log_frequency)
Code example #25
File: simulation.py Project: ni-chen/ptychoSampling
    def __init__(self,
                 wavelength: float = 1.5e-10,
                 obj: Obj = None,
                 probe_3d: Probe3D = None,
                 scan_grid: BraggPtychoGrid = None,
                 detector: Detector = None,
                 poisson_noise: bool = True,
                 upsampling_factor: int = 1) -> None:
        self.wavelength = wavelength
        self.upsampling_factor = upsampling_factor
        self.poisson_noise = poisson_noise

        two_theta = (scan_grid.two_theta if scan_grid is not None
                     else AngleGridParams().two_theta)

        if detector is not None:
            logger.info("Using supplied detector info.")
            self.detector = detector
        else:
            logger.info("Creating new detector.")
            self.detector = Detector(**dt.asdict(DetectorParams()))
        det_3d_shape = (1, *self.detector.shape)

        obj_xz_nyquist_support = self.wavelength * self.detector.obj_dist / np.array(
            self.detector.pixel_size)
        obj_xz_pixel_size = obj_xz_nyquist_support / (
            np.array(self.detector.shape) * self.upsampling_factor)

        # The y pixel size is very ad-hoc
        obj_y_pixel_size = obj_xz_nyquist_support[1] / self.detector.shape[1]
        obj_pixel_size = (obj_y_pixel_size, *obj_xz_pixel_size)

        if obj is not None:
            logger.info("Using supplied object.")
            self.obj = obj
        else:
            self.obj = self.createObj(dt.asdict(ObjParams()), det_3d_shape,
                                      obj_pixel_size, upsampling_factor)

        probe_xz_shape = np.array(self.detector.shape) * self.upsampling_factor

        if probe_3d is not None:
            logger.info("Using supplied 3d probe values.")
            self.probe_3d = probe_3d
            # Note that the probe y shape cannot be determined from the other supplied parameters (I think).
            if (np.any(probe_3d.shape[1:] != probe_xz_shape)
                    or (probe_3d.wavelength != self.wavelength)
                    or np.any(probe_3d.pixel_size != obj_pixel_size)):
                e = ValueError(
                    "Mismatch between the supplied probe and the supplied scan parameters."
                )
                logger.error(e)
                raise e
        else:
            self.probe_3d = self.createProbe3D(dt.asdict(Probe2DParams()),
                                               wavelength,
                                               np.pi / 2 - two_theta,
                                               probe_xz_shape, obj_pixel_size)

        rotate_angle = np.pi / 2 - two_theta
        probe_y_pix_before_rotation = self.probe_3d.shape[0] * np.cos(
            rotate_angle) // 1
        pady, bordered_obj_ypix = self.calculateObjBorderAfterRotation(
            rotate_angle, probe_y_pix_before_rotation, self.obj.shape[0],
            self.obj.shape[1])

        if self.obj.bordered_array.shape[0] < bordered_obj_ypix:
            logger.warning(
                "Adding zero padding to the object in the y-direction so that the overall object y-width "
                + "covers the entirety of the feasible probe positions.")
            self.obj.border_shape = ((pady, pady), self.obj.border_shape[1],
                                     self.obj.border_shape[2])
        if scan_grid is not None:
            logger.info("Using supplied scan grid.")
            self.scan_grid = scan_grid
        else:
            elems_yz = lambda tup: np.array(tup)[[0, 2]]
            scan_grid_params_dict = dt.asdict(
                Scan2DGridParams(obj_pixel_size=elems_yz(self.obj.pixel_size)))
            self.scan_grid = self.createScanGrid(
                scan_grid_params_dict, dt.asdict(AngleGridParams()),
                elems_yz(self.obj.bordered_array.shape),
                elems_yz(self.probe_3d.shape))
        self._calculateDiffractionPatterns()
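A numeric sketch of the Nyquist-limited pixel-size computation in this constructor, with hypothetical detector parameters (the project's DetectorParams defaults are not shown here):

import numpy as np

wavelength = 1.5e-10                       # m
obj_dist = 1.0                             # detector distance (m), hypothetical
det_pixel_size = np.array([55e-6, 55e-6])  # m, hypothetical
det_shape = np.array([256, 256])
upsampling_factor = 1

obj_xz_nyquist_support = wavelength * obj_dist / det_pixel_size
obj_xz_pixel_size = obj_xz_nyquist_support / (det_shape * upsampling_factor)
print(obj_xz_pixel_size)  # ~1.07e-8 m per pixel for these values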