def predict_second_order(self, obj_v: tf.Variable, probe_v: tf.Variable,
                         position_indices_t: tf.Tensor,
                         scope_name: str = "") -> tf.Tensor:
    """Predict far-field waves from explicit object and probe variables.

    The object and probe are rebuilt from ``obj_v`` and ``probe_v`` inside
    this call (rather than read from stored tensors), so the returned tensor
    is a direct function of those variables.

    Args:
        obj_v: Object variable; converted by ``self._reshapeComplexVariable``
            to ``self._obj.shape`` and then padded using
            ``self._obj.border_shape`` / ``self._obj.border_const``.
        probe_v: Probe variable; converted by
            ``self._reshapeComplexVariable`` to ``self._probe.shape``.
        position_indices_t: Scan-position indices for this batch. When its
            static leading dimension is 0, a scalar zero is returned.
        scope_name: Optional prefix; the name scope is
            ``scope_name + "_predict"``, or ``"predict"`` when empty.

    Returns:
        A 1-D real tensor holding the real parts of all far-field values
        followed by all imaginary parts, or a scalar ``float32`` zero for an
        empty batch.
    """
    scope_name = scope_name + "_predict" if scope_name else "predict"
    ndiffs = position_indices_t.get_shape().as_list()[0]
    with tf.name_scope(scope_name) as scope:
        if ndiffs == 0:
            # Name the op with the scope string yielded by tf.name_scope (the
            # variable, not the literal "scope"), as the other predict
            # methods do.
            return tf.zeros(shape=[], dtype='float32', name=scope)

        obj_cmplx_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
        obj_w_border_t = tf.pad(obj_cmplx_t, self._obj.border_shape,
                                constant_values=self._obj.border_const)
        probe_cmplx_t = self._reshapeComplexVariable(probe_v,
                                                     self._probe.shape)

        batch_obj_views_t = self._getPtychoObjViewStack(
            obj_w_border_t, probe_cmplx_t, position_indices_t)
        batch_obj_views_t = fftshift_t(batch_obj_views_t)
        exit_waves_t = batch_obj_views_t * probe_cmplx_t
        farfield_waves_t = propFF_t(exit_waves_t)

        # Stack along a new leading axis and flatten row-major: every real
        # part comes first, then every imaginary part.
        return tf.reshape(
            tf.stack([tf.real(farfield_waves_t), tf.imag(farfield_waves_t)]),
            [-1], name=scope)
def predict_second_order(self, joint_v: tf.Variable,
                         position_indices_t: tf.Tensor,
                         scope_name: str = "") -> tf.Tensor:
    """Predict far-field amplitudes from a joint object+probe variable.

    ``joint_v`` is converted by ``self._reshapeComplexVariable`` to
    ``self._joint_array_shape``. Its first ``self._obj.array.size`` entries
    are the object and the remainder is the probe.

    Args:
        joint_v: Joint variable containing the object followed by the probe.
        position_indices_t: Scan-position indices for this batch. When its
            static leading dimension is 0, a scalar zero is returned.
        scope_name: Optional prefix; the name scope is
            ``scope_name + "_predict"``, or ``"predict"`` when empty.

    Returns:
        Flattened (1-D) far-field amplitudes, downsampled with
        ``self._downsample`` when ``self.upsampling_factor > 1``, or a scalar
        zero of ``self._dtype`` for an empty batch.
    """
    scope_name = scope_name + "_predict" if scope_name else "predict"
    ndiffs = position_indices_t.get_shape().as_list()[0]
    with tf.name_scope(scope_name) as scope:
        if ndiffs == 0:
            # Name the op with the scope string yielded by tf.name_scope (the
            # variable, not the literal "scope"), as the other predict
            # methods do.
            return tf.zeros(shape=[], dtype=self._dtype, name=scope)

        joint_t = self._reshapeComplexVariable(joint_v,
                                               self._joint_array_shape)
        # Split the joint tensor: object entries first, probe entries after.
        n_obj = self._obj.array.size
        obj_cmplx_t = tf.reshape(joint_t[:n_obj], self._obj.shape)
        probe_cmplx_t = tf.reshape(joint_t[n_obj:], self._probe.shape)

        obj_w_border_t = tf.pad(obj_cmplx_t, self._obj.border_shape,
                                constant_values=self._obj.border_const)
        batch_obj_views_t = self._getPtychoObjViewStack(
            obj_w_border_t, probe_cmplx_t, position_indices_t)
        batch_obj_views_t = fftshift_t(batch_obj_views_t)
        exit_waves_t = batch_obj_views_t * probe_cmplx_t
        out_wavefronts_t = propFF_t(exit_waves_t)

        amplitudes_t = tf.abs(out_wavefronts_t)
        if self.upsampling_factor > 1:
            amplitudes_t = self._downsample(amplitudes_t)
        return tf.reshape(amplitudes_t, [-1], name=scope)
def predict(self, position_indices_t: tf.Tensor,
            scope_name: str = "") -> tf.Tensor:
    """Predict flattened far-field amplitudes for a batch of scan positions.

    Each batch entry is mapped through ``self._full_rc_positions_indices_t``.
    Column 1 of the gathered row selects the object view, and column 0
    selects the probe phase modulation. The per-view exit waves are summed
    over axis -3 before propagation.

    Args:
        position_indices_t: Indices into ``self._full_rc_positions_indices_t``.
            When its static leading dimension is 0, a scalar zero is returned.
        scope_name: Optional prefix; the name scope is
            ``scope_name + "_predict"``, or ``"predict"`` when empty.

    Returns:
        Flattened (1-D) amplitudes, downsampled with ``self._downsample`` when
        ``self.upsampling_factor > 1``, or a scalar zero of ``self._dtype``
        for an empty batch.
    """
    ndiffs = position_indices_t.get_shape().as_list()[0]
    scope_name = scope_name + "_predict" if scope_name else "predict"
    with tf.name_scope(scope_name) as scope:
        if ndiffs == 0:
            return tf.zeros(shape=[], dtype=self._dtype, name=scope)

        batch_rc_positions_indices = tf.gather(
            self._full_rc_positions_indices_t, position_indices_t)
        batch_obj_views_t = tf.gather(self._obj_views_all_t,
                                      batch_rc_positions_indices[:, 1])
        batch_phase_modulations_t = tf.gather(
            self._probe_phase_modulations_all_t,
            batch_rc_positions_indices[:, 0])

        exit_waves_t = (batch_obj_views_t * self.probe_cmplx_t
                        * batch_phase_modulations_t)
        # Sum over axis -3 (presumably the mode/slice axis — TODO confirm),
        # then shift before propagating.
        exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t, axis=-3))
        out_wavefronts_t = propFF_t(exit_waves_proj_t)

        amplitudes_t = tf.abs(out_wavefronts_t)
        if self.upsampling_factor > 1:
            amplitudes_t = self._downsample(amplitudes_t)
        return tf.reshape(amplitudes_t, [-1], name=scope)
def predict(self, position_indices_t: tf.Tensor):
    """Return far-field amplitudes for the scan positions in a batch.

    The object views selected by ``position_indices_t`` are shifted with
    ``fftshift_t``, multiplied by ``self.probe_cmplx_t``, propagated with
    ``propFF_t``, and converted to magnitudes. The result is not flattened.

    Args:
        position_indices_t: Indices into ``self._obj_views_all_t``. When its
            static leading dimension is 0, a scalar ``float32`` zero is
            returned.

    Returns:
        The amplitude tensor, passed through ``self._downsample`` when
        ``self.upsampling_factor > 1``.
    """
    if position_indices_t.get_shape().as_list()[0] == 0:
        return tf.zeros(shape=[], dtype='float32')

    shifted_views_t = fftshift_t(
        tf.gather(self._obj_views_all_t, position_indices_t))
    farfield_t = propFF_t(shifted_views_t * self.probe_cmplx_t)
    amps_t = tf.abs(farfield_t)
    return self._downsample(amps_t) if self.upsampling_factor > 1 else amps_t
def predict(self, position_indices_t: tf.Tensor,
            scope_name: str = "") -> tf.Tensor:
    """Return the complex far-field waves for the scan positions in a batch.

    Args:
        position_indices_t: Indices into ``self._obj_views_all_t``. When its
            static leading dimension is 0, a scalar ``complex64`` zero is
            returned.
        scope_name: Optional prefix; the name scope is
            ``scope_name + "_predict"``, or ``"predict"`` when empty.

    Returns:
        The complex output of ``propFF_t`` applied to the shifted object
        views times ``self.probe_cmplx_t``, unflattened.
    """
    scope_label = "predict" if not scope_name else scope_name + "_predict"
    batch_size = position_indices_t.get_shape().as_list()[0]
    with tf.name_scope(scope_label) as scope:
        if batch_size == 0:
            return tf.zeros(shape=[], dtype='complex64', name=scope)
        shifted_views_t = fftshift_t(
            tf.gather(self._obj_views_all_t, position_indices_t))
        # NOTE: an earlier variant flattened this result into real and
        # imaginary parts via
        # tf.reshape(tf.stack([tf.real(w), tf.imag(w)]), [-1]);
        # this version returns the complex tensor as-is.
        return propFF_t(shifted_views_t * self.probe_cmplx_t)
def predict(self, position_indices_t: tf.Tensor):
    """Predict far-field amplitudes for a batch of scan positions.

    Each batch entry is mapped through ``self._full_rc_positions_indices_t``.
    Column 1 of the gathered row selects the object view, and column 0
    selects the probe phase modulation. The per-view exit waves are summed
    over axis -3 before propagation. The result is not flattened.

    Args:
        position_indices_t: Indices into ``self._full_rc_positions_indices_t``.
            When its static leading dimension is 0, a scalar ``float32`` zero
            is returned.

    Returns:
        The amplitude tensor, passed through ``self._downsample`` when
        ``self.upsampling_factor > 1``.
    """
    ndiffs = position_indices_t.get_shape().as_list()[0]
    if ndiffs == 0:
        return tf.zeros(shape=[], dtype='float32')

    batch_rc_positions_indices = tf.gather(self._full_rc_positions_indices_t,
                                           position_indices_t)
    batch_obj_views_t = tf.gather(self._obj_views_all_t,
                                  batch_rc_positions_indices[:, 1])
    batch_phase_modulations_t = tf.gather(self._probe_phase_modulations_all_t,
                                          batch_rc_positions_indices[:, 0])

    exit_waves_t = (batch_obj_views_t * self.probe_cmplx_t
                    * batch_phase_modulations_t)
    # Sum over axis -3 (presumably the mode/slice axis — TODO confirm),
    # then shift before propagating.
    exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t, axis=-3))
    out_wavefronts_t = propFF_t(exit_waves_proj_t)

    amplitudes_t = tf.abs(out_wavefronts_t)
    if self.upsampling_factor > 1:
        amplitudes_t = self._downsample(amplitudes_t)
    return amplitudes_t
def predict(self, position_indices_t: tf.Tensor,
            scope_name: str = "") -> tf.Tensor:
    """Predict flattened far-field amplitudes for a batch of scan positions.

    Args:
        position_indices_t: Indices into ``self._obj_views_all_t``. When its
            static leading dimension is 0, a scalar zero of ``self._dtype``
            is returned.
        scope_name: Optional prefix; the name scope is
            ``scope_name + "_predict"``, or ``"predict"`` when empty.

    Returns:
        Flattened (1-D) amplitudes, downsampled with ``self._downsample`` when
        ``self.upsampling_factor > 1``.
    """
    scope_label = "predict" if not scope_name else scope_name + "_predict"
    batch_size = position_indices_t.get_shape().as_list()[0]
    with tf.name_scope(scope_label) as scope:
        if batch_size == 0:
            return tf.zeros(shape=[], dtype=self._dtype, name=scope)

        # NOTE(review): "test_gather" looks like a debugging leftover op name;
        # kept unchanged in case anything looks the op up by name.
        views_t = tf.gather(self._obj_views_all_t, position_indices_t,
                            name="test_gather")
        farfield_t = propFF_t(fftshift_t(views_t) * self.probe_cmplx_t)

        amps_t = tf.abs(farfield_t)
        if self.upsampling_factor > 1:
            amps_t = self._downsample(amps_t)
        return tf.reshape(amps_t, [-1], name=scope)
def predict_second_order(self, obj_v: tf.Variable, probe_v: tf.Variable,
                         position_indices_t: tf.Tensor,
                         scope_name: str = "") -> tf.Tensor:
    """Predict flattened amplitudes from explicit object and probe variables.

    The object and probe are rebuilt from ``obj_v`` and ``probe_v`` inside
    this call. Each batch entry is mapped through
    ``self._full_rc_positions_indices_t``: column 1 of the gathered row
    selects the object view, and column 0 selects the probe phase
    modulation. The per-view exit waves are summed over axis -3 before
    propagation.

    Args:
        obj_v: Object variable; converted by ``self._reshapeComplexVariable``
            to ``self._obj.shape`` and then padded using
            ``self._obj.border_shape`` / ``self._obj.border_const``.
        probe_v: Probe variable; converted by
            ``self._reshapeComplexVariable`` to ``self._probe.shape``.
        position_indices_t: Indices into ``self._full_rc_positions_indices_t``.
            When its static leading dimension is 0, a scalar zero is
            returned.
        scope_name: Optional prefix; the name scope is
            ``scope_name + "_predict"``, or ``"predict"`` when empty.

    Returns:
        Flattened (1-D) amplitudes, downsampled with ``self._downsample`` when
        ``self.upsampling_factor > 1``, or a scalar zero of ``self._dtype``
        for an empty batch.
    """
    ndiffs = position_indices_t.get_shape().as_list()[0]
    scope_name = scope_name + "_predict" if scope_name else "predict"
    with tf.name_scope(scope_name) as scope:
        if ndiffs == 0:
            return tf.zeros(shape=[], dtype=self._dtype, name=scope)

        obj_cmplx_t = self._reshapeComplexVariable(obj_v, self._obj.shape)
        obj_w_border_t = tf.pad(obj_cmplx_t, self._obj.border_shape,
                                constant_values=self._obj.border_const)
        probe_cmplx_t = self._reshapeComplexVariable(probe_v,
                                                     self._probe.shape)

        batch_rc_positions_indices = tf.gather(
            self._full_rc_positions_indices_t, position_indices_t)
        batch_obj_views_t = self._getPtychoObjViewStack(
            obj_w_border_t, probe_cmplx_t, batch_rc_positions_indices[:, 1])
        batch_phase_modulations_t = tf.gather(
            self._probe_phase_modulations_all_t,
            batch_rc_positions_indices[:, 0])

        # Use the probe rebuilt from probe_v (not the stored
        # self.probe_cmplx_t) so the prediction actually depends on the
        # probe variable passed in, as in the other predict_second_order
        # variants.
        exit_waves_t = (batch_obj_views_t * probe_cmplx_t
                        * batch_phase_modulations_t)
        # Sum over axis -3 (presumably the mode/slice axis — TODO confirm),
        # then shift before propagating.
        exit_waves_proj_t = fftshift_t(tf.reduce_sum(exit_waves_t, axis=-3))
        out_wavefronts_t = propFF_t(exit_waves_proj_t)

        amplitudes_t = tf.abs(out_wavefronts_t)
        if self.upsampling_factor > 1:
            amplitudes_t = self._downsample(amplitudes_t)
        return tf.reshape(amplitudes_t, [-1], name=scope)