예제 #1
0
    def predict_second_order(self,
                             obj_v: tf.Variable,
                             probe_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict far-field waves for a batch of positions from explicit variables.

        The forward model is built directly from ``obj_v`` and ``probe_v``
        rather than from object/probe tensors stored on ``self``.

        Args:
            obj_v: Object variable; reshaped to ``self._obj.shape`` with
                ``self._reshapeComplexVariable`` and then padded with
                ``self._obj.border_shape`` / ``self._obj.border_const``.
            probe_v: Probe variable; reshaped to ``self._probe.shape`` with
                ``self._reshapeComplexVariable``.
            position_indices_t: Index tensor selecting the scan positions of
                this batch. The empty-batch shortcut only fires when its static
                leading dimension is 0.
            scope_name: Optional prefix; ops are created under
                ``"<scope_name>_predict"`` (``"predict"`` when empty).

        Returns:
            A 1-D tensor holding the real parts of the predicted far-field
            waves followed by their imaginary parts, or a float32 scalar zero
            when the batch is empty.
        """
        scope_name = scope_name + "_predict" if scope_name else "predict"
        ndiffs = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                # Name the op after the captured scope string (not the literal
                # "scope"), so the placeholder takes the enclosing scope's name.
                return tf.zeros(shape=[], dtype='float32', name=scope)

            obj_cmplx_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
            obj_w_border_t = tf.pad(obj_cmplx_t,
                                    self._obj.border_shape,
                                    constant_values=self._obj.border_const)

            probe_cmplx_t = self._reshapeComplexVariable(
                probe_v, self._probe.shape)

            batch_obj_views_t = self._getPtychoObjViewStack(
                obj_w_border_t, probe_cmplx_t, position_indices_t)
            # NOTE(review): the shift is applied to the object views before
            # multiplying by the probe; presumably the probe is stored in the
            # matching (shifted) layout — confirm.
            batch_obj_views_t = fftshift_t(batch_obj_views_t)
            exit_waves_t = batch_obj_views_t * probe_cmplx_t
            farfield_waves_t = propFF_t(exit_waves_t)
            # Stack [real, imag] on a new leading axis and flatten, giving all
            # real parts first, then all imaginary parts.
            return tf.reshape(
                tf.stack(
                    [tf.real(farfield_waves_t),
                     tf.imag(farfield_waves_t)]), [-1])
예제 #2
0
    def predict_second_order(self,
                             joint_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict diffraction amplitudes from a joint object+probe variable.

        ``joint_v`` is reshaped with ``self._reshapeComplexVariable`` to
        ``self._joint_array_shape``. Its first ``self._obj.array.size``
        elements form the object and the remainder forms the probe.

        Args:
            joint_v: Joint variable holding the object followed by the probe.
            position_indices_t: Index tensor selecting the scan positions of
                this batch. The empty-batch shortcut only fires when its static
                leading dimension is 0.
            scope_name: Optional prefix; ops are created under
                ``"<scope_name>_predict"`` (``"predict"`` when empty).

        Returns:
            The far-field amplitudes flattened to 1-D (downsampled first when
            ``self.upsampling_factor > 1``), or a scalar zero of
            ``self._dtype`` when the batch is empty.
        """
        scope_name = scope_name + "_predict" if scope_name else "predict"
        ndiffs = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                # Name the op after the captured scope string (not the literal
                # "scope"), so the placeholder takes the enclosing scope's name.
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            joint_t = self._reshapeComplexVariable(joint_v,
                                                   self._joint_array_shape)

            # Split the joint vector: object first, probe after it.
            obj_t_flat = joint_t[:self._obj.array.size]
            probe_t_flat = joint_t[self._obj.array.size:]
            obj_cmplx_t = tf.reshape(obj_t_flat, self._obj.shape)
            probe_cmplx_t = tf.reshape(probe_t_flat, self._probe.shape)

            obj_w_border_t = tf.pad(obj_cmplx_t,
                                    self._obj.border_shape,
                                    constant_values=self._obj.border_const)

            batch_obj_views_t = self._getPtychoObjViewStack(
                obj_w_border_t, probe_cmplx_t, position_indices_t)
            batch_obj_views_t = fftshift_t(batch_obj_views_t)
            exit_waves_t = batch_obj_views_t * probe_cmplx_t
            out_wavefronts_t = propFF_t(exit_waves_t)
            amplitudes_t = tf.abs(out_wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1])
예제 #3
0
    def predict(self,
                position_indices_t: tf.Tensor,
                scope_name: str = "") -> tf.Tensor:
        """Predict diffraction amplitudes for a batch of scan positions.

        Each entry of ``position_indices_t`` selects a row of
        ``self._full_rc_positions_indices_t``. Column 1 of that row indexes
        ``self._obj_views_all_t`` and column 0 indexes
        ``self._probe_phase_modulations_all_t``.

        Args:
            position_indices_t: Index tensor for the batch. The empty-batch
                shortcut only fires when its static leading dimension is 0.
            scope_name: Optional prefix; ops are created under
                ``"<scope_name>_predict"`` (``"predict"`` when empty).

        Returns:
            The far-field amplitudes flattened to 1-D (downsampled first when
            ``self.upsampling_factor > 1``), or a scalar zero of
            ``self._dtype`` when the batch is empty.
        """
        ndiffs = position_indices_t.get_shape().as_list()[0]
        scope_name = scope_name + "_predict" if scope_name else "predict"

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            batch_rc_positions_indices = tf.gather(
                self._full_rc_positions_indices_t, position_indices_t)
            batch_obj_views_t = tf.gather(self._obj_views_all_t,
                                          batch_rc_positions_indices[:, 1])
            batch_phase_modulations_t = tf.gather(
                self._probe_phase_modulations_all_t,
                batch_rc_positions_indices[:, 0])

            exit_waves_t = (batch_obj_views_t * self.probe_cmplx_t
                            * batch_phase_modulations_t)
            # Sum over axis -3 before shifting and propagating.
            # NOTE(review): presumably this axis holds the per-modulation
            # contributions — confirm against the gathered tensor layout.
            exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t,
                                                         axis=-3))

            out_wavefronts_t = propFF_t(exit_waves_proj_t)
            amplitudes_t = tf.abs(out_wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1], name=scope)
예제 #4
0
    def predict(self, position_indices_t: tf.Tensor):
        """Return predicted far-field amplitudes for the selected positions.

        Object views are gathered from ``self._obj_views_all_t`` by
        ``position_indices_t``, passed through ``fftshift_t``, multiplied by
        ``self.probe_cmplx_t`` and propagated with ``propFF_t``. The magnitude
        is returned, downsampled when ``self.upsampling_factor > 1``.
        The result is not flattened.

        Returns a float32 scalar zero when the static batch size is 0.
        """
        if position_indices_t.get_shape().as_list()[0] == 0:
            return tf.zeros(shape=[], dtype='float32')

        views_t = fftshift_t(tf.gather(self._obj_views_all_t,
                                       position_indices_t))
        farfield_t = propFF_t(views_t * self.probe_cmplx_t)
        magnitudes_t = tf.abs(farfield_t)
        return (self._downsample(magnitudes_t)
                if self.upsampling_factor > 1 else magnitudes_t)
예제 #5
0
    def predict(self,
                position_indices_t: tf.Tensor,
                scope_name: str = "") -> tf.Tensor:
        """Return the complex far-field waves for the selected positions.

        Object views are gathered from ``self._obj_views_all_t``, passed
        through ``fftshift_t``, multiplied by ``self.probe_cmplx_t`` and
        propagated with ``propFF_t``. The complex result is returned as-is
        (not flattened, not split into real/imaginary parts).

        Ops are created under ``"<scope_name>_predict"`` (``"predict"`` when
        ``scope_name`` is empty). A complex64 scalar zero is returned when the
        static batch size is 0.
        """
        prefix = scope_name + "_predict" if scope_name else "predict"
        batch_size = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(prefix) as scope:
            if batch_size == 0:
                return tf.zeros(shape=[], dtype='complex64', name=scope)

            views_t = fftshift_t(tf.gather(self._obj_views_all_t,
                                           position_indices_t))
            return propFF_t(views_t * self.probe_cmplx_t)
예제 #6
0
    def predict(self, position_indices_t: tf.Tensor):
        """Predict diffraction amplitudes for a batch of scan positions.

        Each entry of ``position_indices_t`` selects a row of
        ``self._full_rc_positions_indices_t``. Column 1 of that row indexes
        ``self._obj_views_all_t`` and column 0 indexes
        ``self._probe_phase_modulations_all_t``.

        Args:
            position_indices_t: Index tensor for the batch. The empty-batch
                shortcut only fires when its static leading dimension is 0.

        Returns:
            The far-field amplitudes (downsampled when
            ``self.upsampling_factor > 1``; not flattened), or a float32
            scalar zero when the batch is empty.
        """
        ndiffs = position_indices_t.get_shape().as_list()[0]

        if ndiffs == 0:
            return tf.zeros(shape=[], dtype='float32')

        batch_rc_positions_indices = tf.gather(
            self._full_rc_positions_indices_t, position_indices_t)
        batch_obj_views_t = tf.gather(self._obj_views_all_t,
                                      batch_rc_positions_indices[:, 1])
        batch_phase_modulations_t = tf.gather(
            self._probe_phase_modulations_all_t,
            batch_rc_positions_indices[:, 0])

        exit_waves_t = (batch_obj_views_t * self.probe_cmplx_t
                        * batch_phase_modulations_t)
        # Sum over axis -3 before shifting and propagating.
        # NOTE(review): presumably this axis holds the per-modulation
        # contributions — confirm against the gathered tensor layout.
        exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t, axis=-3))

        out_wavefronts_t = propFF_t(exit_waves_proj_t)
        amplitudes_t = tf.abs(out_wavefronts_t)
        if self.upsampling_factor > 1:
            amplitudes_t = self._downsample(amplitudes_t)
        return amplitudes_t
예제 #7
0
    def predict(self,
                position_indices_t: tf.Tensor,
                scope_name: str = "") -> tf.Tensor:
        """Return flattened far-field amplitudes for the selected positions.

        Object views are gathered from ``self._obj_views_all_t``, passed
        through ``fftshift_t``, multiplied by ``self.probe_cmplx_t`` and
        propagated with ``propFF_t``. The magnitudes are downsampled when
        ``self.upsampling_factor > 1`` and reshaped to 1-D.

        Ops are created under ``"<scope_name>_predict"`` (``"predict"`` when
        ``scope_name`` is empty). A scalar zero of ``self._dtype`` is returned
        when the static batch size is 0.
        """
        prefix = scope_name + "_predict" if scope_name else "predict"
        batch_size = position_indices_t.get_shape().as_list()[0]

        with tf.name_scope(prefix) as scope:
            if batch_size == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)

            # NOTE(review): "test_gather" looks like a leftover debug op name;
            # kept as-is in case anything looks the op up by name.
            views_t = tf.gather(self._obj_views_all_t,
                                position_indices_t,
                                name="test_gather")
            wavefront_t = propFF_t(fftshift_t(views_t) * self.probe_cmplx_t)
            magnitudes_t = tf.abs(wavefront_t)
            if self.upsampling_factor > 1:
                magnitudes_t = self._downsample(magnitudes_t)
            return tf.reshape(magnitudes_t, [-1], name=scope)
예제 #8
0
    def predict_second_order(self,
                             obj_v: tf.Variable,
                             probe_v: tf.Variable,
                             position_indices_t: tf.Tensor,
                             scope_name: str = "") -> tf.Tensor:
        """Predict diffraction amplitudes from explicit object/probe variables.

        Each entry of ``position_indices_t`` selects a row of
        ``self._full_rc_positions_indices_t``. Column 1 of that row selects
        the object view (built from ``obj_v``) and column 0 indexes
        ``self._probe_phase_modulations_all_t``.

        Args:
            obj_v: Object variable; reshaped to ``self._obj.shape`` with
                ``self._reshapeComplexVariable`` and padded with
                ``self._obj.border_shape`` / ``self._obj.border_const``.
            probe_v: Probe variable; reshaped to ``self._probe.shape`` with
                ``self._reshapeComplexVariable``.
            position_indices_t: Index tensor for the batch. The empty-batch
                shortcut only fires when its static leading dimension is 0.
            scope_name: Optional prefix; ops are created under
                ``"<scope_name>_predict"`` (``"predict"`` when empty).

        Returns:
            The far-field amplitudes flattened to 1-D (downsampled first when
            ``self.upsampling_factor > 1``), or a scalar zero of
            ``self._dtype`` when the batch is empty.
        """
        ndiffs = position_indices_t.get_shape().as_list()[0]
        scope_name = scope_name + "_predict" if scope_name else "predict"

        with tf.name_scope(scope_name) as scope:
            if ndiffs == 0:
                return tf.zeros(shape=[], dtype=self._dtype, name=scope)
            obj_cmplx_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
            obj_w_border_t = tf.pad(obj_cmplx_t,
                                    self._obj.border_shape,
                                    constant_values=self._obj.border_const)

            probe_cmplx_t = self._reshapeComplexVariable(
                probe_v, self._probe.shape)

            batch_rc_positions_indices = tf.gather(
                self._full_rc_positions_indices_t, position_indices_t)
            batch_obj_views_t = self._getPtychoObjViewStack(
                obj_w_border_t, probe_cmplx_t,
                batch_rc_positions_indices[:, 1])

            batch_phase_modulations_t = tf.gather(
                self._probe_phase_modulations_all_t,
                batch_rc_positions_indices[:, 0])

            # Use the probe built from ``probe_v`` (not ``self.probe_cmplx_t``)
            # so the prediction actually depends on the probe variable passed
            # in, matching how ``obj_v`` is used.
            exit_waves_t = (batch_obj_views_t * probe_cmplx_t
                            * batch_phase_modulations_t)
            # Sum over axis -3 before shifting and propagating.
            # NOTE(review): presumably this axis holds the per-modulation
            # contributions — confirm against the gathered tensor layout.
            exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t,
                                                         axis=-3))

            out_wavefronts_t = propFF_t(exit_waves_proj_t)
            amplitudes_t = tf.abs(out_wavefronts_t)
            if self.upsampling_factor > 1:
                amplitudes_t = self._downsample(amplitudes_t)
            return tf.reshape(amplitudes_t, [-1], name=scope)