Beispiel #1
0
    def predict_second_order(self,
                             joint_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict flattened amplitudes from a joint object+probe variable.

        The complex object and probe are rebuilt from ``joint_v`` rather than
        taken from cached tensors. After ``_reshapeComplexVariable``, the first
        ``self._obj.array.size`` entries are the object and the rest are the
        probe.

        Args:
            joint_v: Variable holding the object followed by the probe.
            position_indices_t: Scan-position indices for the batch; its static
                leading dimension is used as the batch size.
            scope_name: Optional prefix for the TF name scope.

        Returns:
            A scalar zero (``self._dtype``) when the batch is empty. Otherwise
            the flattened amplitudes ``|propFF_t(fftshift_t(views) * probe)|``,
            downsampled when ``self.upsampling_factor > 1``.
        """
        scope_name = scope_name + "_predict" if scope_name else "predict"
        ndiffs = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                # Name the op after the scope itself (as the other predict
                # methods do), not after the literal string "scope".
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            joint_t = self._reshapeComplexVariable(joint_v,
                                                   self._joint_array_shape)

            # Joint layout: object entries first, then probe entries.
            obj_size = self._obj.array.size
            obj_t_flat = joint_t[:obj_size]
            probe_t_flat = joint_t[obj_size:]
            obj_cmplx_t = tf.reshape(obj_t_flat, self._obj.shape)
            probe_cmplx_t = tf.reshape(probe_t_flat, self._probe.shape)

            obj_w_border_t = tf.pad(obj_cmplx_t,
                                    self._obj.border_shape,
                                    constant_values=self._obj.border_const)

            batch_obj_views_t = self._getPtychoObjViewStack(
                obj_w_border_t, probe_cmplx_t, position_indices_t)
            batch_obj_views_t = fftshift_t(batch_obj_views_t)
            exit_waves_t = batch_obj_views_t * probe_cmplx_t
            out_wavefronts_t = propFF_t(exit_waves_t)
            amplitudes_t = tf.abs(out_wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1])
Beispiel #2
0
    def predict_second_order(self,
                             obj_v: tf.Variable,
                             probe_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict the far-field wave as stacked real/imaginary parts.

        The object and probe are rebuilt from ``obj_v`` and ``probe_v`` via
        ``_reshapeComplexVariable``. The object is padded with its border
        before the views are extracted.

        Args:
            obj_v: Variable holding the object.
            probe_v: Variable holding the probe.
            position_indices_t: Scan-position indices for the batch; its static
                leading dimension is used as the batch size.
            scope_name: Optional prefix for the TF name scope.

        Returns:
            A scalar float32 zero when the batch is empty. Otherwise a 1-D tensor
            made by stacking ``tf.real`` and ``tf.imag`` of the far-field waves
            (real part first) and flattening the result.
        """
        scope_name = scope_name + "_predict" if scope_name else "predict"
        ndiffs = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                # Name the op after the scope itself, not the literal "scope".
                return tf.zeros(shape=[], dtype='float32', name=scope)

            obj_cmplx_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
            obj_w_border_t = tf.pad(obj_cmplx_t,
                                    self._obj.border_shape,
                                    constant_values=self._obj.border_const)

            probe_cmplx_t = self._reshapeComplexVariable(
                probe_v, self._probe.shape)

            batch_obj_views_t = self._getPtychoObjViewStack(
                obj_w_border_t, probe_cmplx_t, position_indices_t)
            batch_obj_views_t = fftshift_t(batch_obj_views_t)
            exit_waves_t = batch_obj_views_t * probe_cmplx_t
            farfield_waves_t = propFF_t(exit_waves_t)
            return tf.reshape(
                tf.stack(
                    [tf.real(farfield_waves_t),
                     tf.imag(farfield_waves_t)]), [-1])
Beispiel #3
0
    def _getProbeScaling(self, loss_data_type: str,
                         hessian_t: tf.Tensor) -> tf.Tensor:
        """Build the flattened per-pixel scaling term for the probe update.

        First, each diffraction pattern gets one weight: the mean over axes
        (1, 2) of ``weights * hessian_t``. ``weights`` is all ones for an
        amplitude loss and the batch predictions for an intensity loss. The
        result is then the batch sum of
        ``0.5 * |fftshift_t(obj_view)|**2 * weight``.

        Args:
            loss_data_type: Either "amplitude" or "intensity".
            hessian_t: Per-pixel Hessian values. Presumably shaped
                (batch, *probe.shape) to match the intensity-branch reshape;
                TODO confirm.

        Returns:
            The scaling term reshaped to 1-D.

        Raises:
            ValueError: If ``loss_data_type`` is not "amplitude" or "intensity".
        """
        with self.graph.as_default():
            with tf.name_scope('probe_scaling'):
                if loss_data_type == "amplitude":
                    weights = tf.ones_like(hessian_t)
                elif loss_data_type == "intensity":
                    weights = tf.reshape(self._batch_train_predictions_t,
                                         [-1, *self.probe.shape])
                else:
                    # Without this, ``weights`` would be unbound below and the
                    # failure would surface as an opaque UnboundLocalError.
                    raise ValueError(
                        "Unsupported loss_data_type {!r}; expected "
                        "'amplitude' or 'intensity'".format(loss_data_type))

                # One scalar weight per diffraction pattern in the batch.
                weights = tf.reduce_mean(weights * hessian_t, axis=(1, 2))
                batch_obj_views = tf.gather(self.fwd_model._obj_views_all_t,
                                            self._batch_train_input_v)

                weights = 0.5 * tf.abs(
                    fftshift_t(batch_obj_views))**2 * weights[:, None, None]
                t = tf.reduce_sum(weights, axis=0)
                # NOTE: earlier unweighted / zero-guarded variants, kept for reference:
                #t = tf.reduce_sum(tf.abs(fftshift_t(batch_obj_views)) ** 2, axis=0) * 0.5
                # zero_condition = tf.less(t, 1e-10 * tf.reduce_max(t))
                # zero_case = tf.ones_like(t) / (1e-6 * tf.reduce_max(t))
                # H = tf.where(zero_condition, zero_case, 1 / t)

                H_reshaped = tf.reshape(t, [-1])

            return H_reshaped
Beispiel #4
0
    def predict(self,
                position_indices_t: tf.Tensor,
                scope_name: str = "") -> tf.Tensor:
        """Predict flattened far-field amplitudes for a batch of positions.

        Each entry of ``position_indices_t`` selects a row of
        ``self._full_rc_positions_indices_t``. Column 1 of that row picks the
        object view and column 0 picks the probe phase modulation. The
        modulated exit waves are summed over axis -3, fftshifted and
        propagated with ``propFF_t``.

        Returns a scalar zero of ``self._dtype`` when the batch is empty.
        """
        ndiffs = position_indices_t.get_shape().as_list()[0]
        if scope_name:
            scope_name = scope_name + "_predict"
        else:
            scope_name = "predict"

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            rc_indices_t = tf.gather(self._full_rc_positions_indices_t,
                                     position_indices_t)
            obj_views_t = tf.gather(self._obj_views_all_t, rc_indices_t[:, 1])
            phase_mods_t = tf.gather(self._probe_phase_modulations_all_t,
                                     rc_indices_t[:, 0])

            modulated_waves_t = obj_views_t * self.probe_cmplx_t * phase_mods_t
            projected_t = fftshift_t(tf.reduce_sum(modulated_waves_t, axis=-3))

            amplitudes_t = tf.abs(propFF_t(projected_t))
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1], name=scope)
Beispiel #5
0
    def predict_second_order(self,
                             obj_v: tf.Variable,
                             probe_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict flattened amplitudes using transfer-function propagation.

        The object and probe are rebuilt from ``obj_v`` / ``probe_v``. The
        padded object's views are fftshifted and multiplied by the probe. The
        result is propagated with ``propTF_t`` using the precomputed
        ``self._transfer_function``.

        Returns a scalar zero of ``self._dtype`` when the batch is empty.
        """
        ndiffs = position_indices_t.get_shape().as_list()[0]
        if scope_name:
            scope_name = scope_name + "_predict"
        else:
            scope_name = "predict"

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            obj_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
            padded_obj_t = tf.pad(obj_t,
                                  self._obj.border_shape,
                                  constant_values=self._obj.border_const)
            probe_t = self._reshapeComplexVariable(probe_v, self._probe.shape)

            views_t = fftshift_t(self._getPtychoObjViewStack(
                padded_obj_t, probe_t, position_indices_t))
            wavefronts_t = propTF_t(views_t * probe_t,
                                    reuse_transfer_function=True,
                                    transfer_function=self._transfer_function)

            amplitudes_t = tf.abs(wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1], name=scope)
    def _getObjLearningRate(self) -> tf.Tensor:
        """Return ``1 / max`` of the probe intensity accumulated on the object.

        For every object view in the current batch, ``|fftshift_t(probe)|**2``
        is scattered into a flat buffer of size
        ``self.obj.bordered_array.size``. The buffer is indexed by that view's
        entries in ``self.fwd_model._obj_view_indices_t``, and contributions
        from all views are summed.
        """
        with self.graph.as_default():
            probe_sq = tf.abs(fftshift_t(self.fwd_model.probe_cmplx_t)) ** 2
            # Loop-invariant: flatten the probe intensity once.
            probe_sq_flat = tf.reshape(probe_sq, [-1])
            batch_obj_view_indices = tf.gather(
                self.fwd_model._obj_view_indices_t, self._batch_train_input_v)
            size = self.obj.bordered_array.size

            # Accumulate in the probe intensity's own dtype. A hard-coded
            # float32 buffer cannot be added to the float64 updates produced
            # by a complex128 probe.
            tf_mat = tf.zeros(size, dtype=probe_sq.dtype)
            for b in tf.unstack(batch_obj_view_indices):
                mat_this = tf.scatter_nd(indices=tf.reshape(b, [-1, 1]),
                                         shape=[size],
                                         updates=probe_sq_flat)
                tf_mat = tf_mat + mat_this
            return 1.0 / tf.reduce_max(tf_mat)
Beispiel #7
0
    def predict(self, position_indices_t: tf.Tensor):
        """Predict far-field amplitudes for a batch of scan positions.

        Returns a scalar float32 zero when the static batch size is 0.
        Otherwise returns ``|propFF_t(fftshift_t(views) * probe)|``, downsampled
        when ``self.upsampling_factor > 1``. The result is not flattened.
        """
        if position_indices_t.get_shape().as_list()[0] == 0:
            return tf.zeros(shape=[], dtype='float32')

        views_t = fftshift_t(tf.gather(self._obj_views_all_t,
                                       position_indices_t))
        amplitudes_t = tf.abs(propFF_t(views_t * self.probe_cmplx_t))
        if self.upsampling_factor > 1:
            amplitudes_t = self._downsample(amplitudes_t)
        return amplitudes_t
Beispiel #8
0
    def predict(self,
                position_indices_t: tf.Tensor,
                scope_name: str = "") -> tf.Tensor:
        """Predict the complex far-field waves for a batch of positions.

        The waves come from ``propFF_t(fftshift_t(views) * probe)`` and are
        returned unflattened. When the batch is empty, a scalar complex64
        zero is returned instead.
        """
        if scope_name:
            scope_name = scope_name + "_predict"
        else:
            scope_name = "predict"
        ndiffs = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype='complex64', name=scope)

            views_t = fftshift_t(tf.gather(self._obj_views_all_t,
                                           position_indices_t))
            return propFF_t(views_t * self.probe_cmplx_t)
Beispiel #9
0
    def predict(self, position_indices_t: tf.Tensor):
        """Predict far-field amplitudes with per-position probe modulation.

        Each position index selects a row of
        ``self._full_rc_positions_indices_t``. Column 1 picks the object view
        and column 0 picks the probe phase modulation. Exit waves are summed
        over axis -3 before propagation.

        Returns a scalar float32 zero for an empty batch. Otherwise returns the
        amplitudes, unflattened.
        """
        if position_indices_t.get_shape().as_list()[0] == 0:
            return tf.zeros(shape=[], dtype='float32')

        rc_indices_t = tf.gather(self._full_rc_positions_indices_t,
                                 position_indices_t)
        obj_views_t = tf.gather(self._obj_views_all_t, rc_indices_t[:, 1])
        phase_mods_t = tf.gather(self._probe_phase_modulations_all_t,
                                 rc_indices_t[:, 0])

        modulated_waves_t = obj_views_t * self.probe_cmplx_t * phase_mods_t
        projected_t = fftshift_t(tf.reduce_sum(modulated_waves_t, axis=-3))

        amplitudes_t = tf.abs(propFF_t(projected_t))
        if self.upsampling_factor > 1:
            amplitudes_t = self._downsample(amplitudes_t)
        return amplitudes_t
Beispiel #10
0
    def predict(self,
                position_indices_t: tf.Tensor,
                scope_name: str = "") -> tf.Tensor:
        """Predict flattened far-field amplitudes for a batch of positions.

        Returns a scalar zero of ``self._dtype`` when the batch is empty.
        Otherwise returns ``|propFF_t(fftshift_t(views) * probe)|``,
        downsampled when ``self.upsampling_factor > 1`` and flattened to 1-D.
        """
        if scope_name:
            scope_name = scope_name + "_predict"
        else:
            scope_name = "predict"
        ndiffs = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            gathered_t = tf.gather(self._obj_views_all_t,
                                   position_indices_t,
                                   name="test_gather")
            wavefronts_t = propFF_t(fftshift_t(gathered_t) * self.probe_cmplx_t)
            amplitudes_t = tf.abs(wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1], name=scope)
Beispiel #11
0
    def predict_second_order(self,
                             obj_v: tf.Variable,
                             probe_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict flattened amplitudes from object/probe variables.

        This variant applies a per-position probe phase modulation. Each
        position index selects a row of ``self._full_rc_positions_indices_t``:
        column 1 picks the object view and column 0 picks the phase
        modulation. Exit waves are summed over axis -3, fftshifted and
        propagated with ``propFF_t``.

        Args:
            obj_v: Variable holding the object.
            probe_v: Variable holding the probe.
            position_indices_t: Scan-position indices for the batch; its static
                leading dimension is used as the batch size.
            scope_name: Optional prefix for the TF name scope.

        Returns:
            A scalar zero of ``self._dtype`` for an empty batch. Otherwise the
            flattened amplitudes, downsampled when
            ``self.upsampling_factor > 1``.
        """
        ndiffs = position_indices_t.get_shape().as_list()[0]
        scope_name = scope_name + "_predict" if scope_name else "predict"

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)
            obj_cmplx_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
            obj_w_border_t = tf.pad(obj_cmplx_t,
                                    self._obj.border_shape,
                                    constant_values=self._obj.border_const)

            probe_cmplx_t = self._reshapeComplexVariable(
                probe_v, self._probe.shape)

            batch_rc_positions_indices = tf.gather(
                self._full_rc_positions_indices_t, position_indices_t)
            batch_obj_views_t = self._getPtychoObjViewStack(
                obj_w_border_t, probe_cmplx_t,
                batch_rc_positions_indices[:, 1])

            batch_phase_modulations_t = tf.gather(
                self._probe_phase_modulations_all_t,
                batch_rc_positions_indices[:, 0])

            # Use the probe rebuilt from ``probe_v`` (as the view extraction
            # above does), not the cached ``self.probe_cmplx_t``. Otherwise
            # the prediction would not depend on ``probe_v`` at all.
            exit_waves_t = (batch_obj_views_t * probe_cmplx_t
                            * batch_phase_modulations_t)
            exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t,
                                                         axis=-3))

            out_wavefronts_t = propFF_t(exit_waves_proj_t)
            amplitudes_t = tf.abs(out_wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1], name=scope)
Beispiel #12
0
    def _getObjScaling(self, loss_data_type: str,
                       hessian_t: tf.Tensor) -> tf.Tensor:
        """Build the flattened per-pixel scaling term for the object update.

        Each diffraction pattern gets one weight: the mean over axes (1, 2) of
        ``weights * hessian_t``. ``weights`` is all ones for an amplitude loss
        and the batch predictions for an intensity loss. For every view in the
        batch, ``0.5 * |fftshift(probe)|**2 * weight`` is scattered into the
        bordered object at that view's indices. The result is accumulated,
        cropped to the interior (border removed) and flattened.

        Args:
            loss_data_type: Either "amplitude" or "intensity".
            hessian_t: Per-pixel Hessian values. Presumably shaped
                (batch, *probe.shape) to match the intensity-branch reshape;
                TODO confirm.

        Returns:
            The 1-D scaling term over the un-bordered object.

        Raises:
            ValueError: If ``loss_data_type`` is not "amplitude" or "intensity".
        """
        with self.graph.as_default():
            with tf.name_scope('obj_scaling'):
                if loss_data_type == "amplitude":
                    weights = tf.ones_like(hessian_t)
                elif loss_data_type == "intensity":
                    weights = tf.reshape(self._batch_train_predictions_t,
                                         [-1, *self.probe.shape])
                else:
                    # Without this, ``weights`` would be unbound below.
                    raise ValueError(
                        "Unsupported loss_data_type {!r}; expected "
                        "'amplitude' or 'intensity'".format(loss_data_type))

                # One scalar weight per diffraction pattern in the batch.
                weights = tf.reduce_mean(weights * hessian_t, axis=(1, 2))
                probe_abs_sq = fftshift_t(tf.abs(
                    self.fwd_model.probe_cmplx_t))**2

                batch_obj_view_indices = tf.gather(
                    self.fwd_model._obj_view_indices_t,
                    self._batch_train_input_v)
                batch_obj_view_indices = tf.unstack(batch_obj_view_indices)
                size = self.obj.bordered_array.size

                tf_mat = tf.zeros(size, dtype=self.dtype)
                for i, b in enumerate(batch_obj_view_indices):
                    weights_this = 0.5 * probe_abs_sq * weights[i]
                    mat_this = tf.scatter_nd(indices=tf.reshape(b, [-1, 1]),
                                             shape=[size],
                                             updates=tf.reshape(
                                                 weights_this, [-1]))
                    tf_mat = tf_mat + mat_this
                # NOTE: an earlier zero-guarded variant, kept for reference:
                # zero_condition = tf.less(tf_mat, 1e-10 * tf.reduce_max(tf_mat))
                # zero_case = tf.ones_like(tf_mat) * (1e-6 * tf.reduce_max(tf_mat))
                # H = tf.where(zero_condition, zero_case, tf_mat)
                H = tf_mat
                bordered_shape = self.obj.bordered_array.shape
                H_reshaped = tf.reshape(H, bordered_shape)
                (s1, s2), (s3, s4) = self.obj.border_shape
                # Use explicit stop indices: with a zero-width trailing border,
                # a stop of ``-0`` equals 0 and would select an empty slice.
                row_stop = bordered_shape[0] - s2
                col_stop = bordered_shape[1] - s4
                H_trunc = tf.reshape(H_reshaped[s1:row_stop, s3:col_stop], [-1])
            return H_trunc
 def _getProbeLearningRate(self) -> tf.Tensor:
     """Return ``1 / max`` of the batch-summed object-view intensity.

     The object views for the current training batch are gathered and
     fftshifted. Their ``|.|**2`` is summed over the batch axis (axis 0), and
     the reciprocal of the largest pixel value is returned.
     """
     with self.graph.as_default():
         views_t = tf.gather(self.fwd_model._obj_views_all_t,
                             self._batch_train_input_v)
         summed_intensity_t = tf.reduce_sum(tf.abs(fftshift_t(views_t)) ** 2,
                                            axis=0)
         return 1.0 / tf.reduce_max(summed_intensity_t)