# Correc the fwd model - not good here! tf_glob_real = tf.Variable(0., 'tf_glob_real') tf_glob_imag = tf.Variable(0., 'tf_glob_imag') #tf_norm = tf.complex(tf_glob_real, tf_glob_imag) tf_norm = tf.complex(0., 0.) '''Define Loss-function''' if (0): print('-------> Losstype is L1') tf_fidelity = tf.reduce_mean( (tf.abs((muscat.tf_meas) - tf_fwd) )) # allow a global phase parameter to avoid unwrapping effects elif (0): print('-------> Losstype mixed L2 ') tf_fidelity = tf.reduce_mean( tf_helper.tf_abssqr(tf.real(muscat.tf_meas) - tf.real(tf_fwd)) + tf_helper.tf_abssqr(tf.imag(muscat.tf_meas) - tf.imag(tf_fwd)) ) # allow a global phase parameter to avoid unwrapping effects else: print('-------> Losstype is L2') tf_fidelity = tf.reduce_mean( tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd) ) # allow a global phase parameter to avoid unwrapping effects #tf_fidelity = tf.losses.mean_squared_error(tf.real(muscat.tf_meas), tf.real(tf_fwd)) #tf_fidelity += tf.losses.mean_squared_error(tf.imag(muscat.tf_meas), tf.imag(tf_fwd)) tf_loss = tf_fidelity + tf_negsqrloss + tf_regloss '''Define Optimizer''' if (0): print('Using ADAM optimizer') tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate) elif (0):
muscat.TF_obj_absorption, BetaVals=[muscat.dx, muscat.dy, muscat.dz], epsR=muscat.tf_eps, is_circ=True ) #Alernatively tf_total_variation_regularization # total_variation tf_negsqrloss = lambda_neg * reg.Reg_NegSqr(muscat.TF_obj) tf_negsqrloss += lambda_neg * reg.Reg_NegSqr(muscat.TF_obj_absorption) tf_globalphase = tf.Variable(0., tf.float32, name='var_phase') tf_globalabs = tf.Variable(1., tf.float32, name='var_abs') # #tf_fidelity = tf.reduce_sum((tf_helper.tf_abssqr(tf_fwd - (tf_meas/tf.cast(tf.abs(tf_globalabs), tf.complex64)*tf.exp(1j*tf.cast(tf_globalphase, tf.complex64)))))) # allow a global phase parameter to avoid unwrapping effects tf_fwd_corrected = tf_fwd / tf.cast( tf.abs(tf_globalabs), tf.complex64) * tf.exp( 1j * tf.cast(tf_globalphase, tf.complex64)) tf_fidelity = tf.reduce_mean( (tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd_corrected) )) # allow a global phase parameter to avoid unwrapping effects tf_grads = tf.gradients(tf_fidelity, [muscat.TF_obj])[0] tf_loss = tf_fidelity + tf_negsqrloss + tf_tvloss #tf_negloss + tf_posloss + tf_tvloss '''Define Optimizer''' tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate) #tf_optimizer = tf.train.MomentumOptimizer(tf_learningrate, momentum = .9, use_nesterov=True) #tf_optimizer = tf.train.ProximalGradientDescentOptimizer(tf_learningrate) #tf_optimizer = tf.train.GradientDescentOptimizer(muscat.tf_learningrate) tf_lossop = tf_optimizer.minimize( tf_loss ) #, var_list = [muscat.TF_obj, muscat.TF_obj_absorption, tf_globalabs, tf_globalphase]) ''' Evaluate the model ''' sess = tf.Session()
'''Numpy to Tensorflow'''
np_meas = matlab_val
# NOTE(review): if matlab_val is complex, np_mean is complex and the division
# of the (presumably real-valued) tf_fidelity by it below will fail when TF
# converts it to the tensor's dtype; confirm the intended normalizer.
np_mean = np.mean(np_meas)

'''Define Cost-function'''
# Total-variation regularization of both muscat.TF_obj and
# muscat.TF_obj_absorption (alternatively tf_total_variation_regularization /
# total_variation).
tf_tvloss = muscat.tf_lambda_tv*reg.Reg_TV(muscat.TF_obj, BetaVals=[muscat.dx, muscat.dy, muscat.dz], epsR=muscat.tf_eps, is_circ=True)
tf_tvloss += muscat.tf_lambda_tv*reg.Reg_TV(muscat.TF_obj_absorption, BetaVals=[muscat.dx, muscat.dy, muscat.dz], epsR=muscat.tf_eps, is_circ=True)

# The negativity penalty is disabled. Use a float zero (not int32) so the
# placeholder can be combined with the floating-point loss terms.
tf_negsqrloss = tf.constant(0.)
# tf_negsqrloss = lambda_neg*reg.Reg_NegSqr(muscat.TF_obj)
# tf_negsqrloss += lambda_neg*reg.Reg_NegSqr(muscat.TF_obj_absorption)

# Global amplitude/phase correction of the forward model. A global phase
# parameter avoids unwrapping effects.
# dtype is passed by keyword: the second positional parameter of tf.Variable
# is `trainable`, not `dtype`.
tf_globalphase = tf.Variable(0., dtype=tf.float32, name='var_phase')
tf_globalabs = tf.Variable(.6, dtype=tf.float32, name='var_abs')
# tf_fidelity = tf.reduce_sum((tf_helper.tf_abssqr(tf_fwd - (tf_meas/tf.cast(tf.abs(tf_globalabs), tf.complex64)*tf.exp(1j*tf.cast(tf_globalphase, tf.complex64))))))
# NOTE(review): tf_fwd is divided by |tf_globalabs|; nothing prevents this
# from approaching 0 during optimization.
tf_fwd_corrected = tf_fwd/tf.cast(tf.abs(tf_globalabs), tf.complex64)*tf.exp(1j*tf.cast(tf_globalphase, tf.complex64))
tf_fidelity = tf.reduce_mean(tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd_corrected))
tf_grads = tf.gradients(tf_fidelity, [muscat.TF_obj])[0]
# Data term scaled by 1/np_mean plus TV regularization
# (previously: tf_negloss + tf_posloss + tf_tvloss).
tf_loss = tf_fidelity/np_mean + tf_tvloss

'''Define Optimizer'''
tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate)
# Alternatives:
# tf_optimizer = tf.train.MomentumOptimizer(tf_learningrate, momentum=.9, use_nesterov=True)
# tf_optimizer = tf.train.ProximalGradientDescentOptimizer(tf_learningrate)
# tf_optimizer = tf.train.GradientDescentOptimizer(muscat.tf_learningrate)
# Object update: muscat.TF_obj plus the global amplitude/phase correction
# (muscat.TF_obj_absorption is currently left out of var_list).
tf_lossop_obj = tf_optimizer.minimize(tf_loss, var_list=[muscat.TF_obj, tf_globalabs, tf_globalphase])
if is_optimization_psf:
    # In case we want to do blind deconvolution: a separate update op for the
    # Zernike coefficients.
    # NOTE(review): both minimize() calls share one AdamOptimizer instance;
    # confirm that is intended.
    tf_lossop_aberr = tf_optimizer.minimize(tf_loss, var_list=[muscat.TF_zernikefactors])
tf_negsqrloss = tf.constant(0.)

# Correct the fwd model - not good here!
# name is passed by keyword: the second positional parameter of tf.Variable
# is `trainable`, so the original call never actually named these variables.
tf_glob_real = tf.Variable(0., name='tf_glob_real')
tf_glob_imag = tf.Variable(0., name='tf_glob_imag')
# tf_norm = tf.complex(tf_glob_real, tf_glob_imag)
tf_norm = tf.complex(0., 0.)

'''Define Loss-function'''
# Loss selection: the first two variants are disabled, so the plain L2 loss
# on the complex field is used.
if False:
    print('-------> Losstype is L1')
    tf_fidelity = tf.reduce_mean(tf.abs(muscat.tf_meas - tf_fwd))
elif False:
    print('-------> Losstype mixed L2 ')
    tf_fidelity = tf.reduce_mean(tf_helper.tf_abssqr(tf.real(muscat.tf_meas) - tf.real(tf_fwd)) + tf_helper.tf_abssqr(tf.imag(muscat.tf_meas) - tf.imag(tf_fwd)))
else:
    print('-------> Losstype is L2')
    tf_fidelity = tf.reduce_mean(tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd))
    # tf_fidelity = tf.losses.mean_squared_error(tf.real(muscat.tf_meas), tf.real(tf_fwd))
    # tf_fidelity += tf.losses.mean_squared_error(tf.imag(muscat.tf_meas), tf.imag(tf_fwd))

tf_loss = tf_fidelity + tf_negsqrloss + tf_regloss

'''Define Optimizer'''
if False:
    print('Using ADAM optimizer')
    tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate)
elif False:
    print('Using ADADELTA optimizer')
    tf_optimizer = tf.train.AdadeltaOptimizer(muscat.tf_learningrate)
# NOTE(review): both branches above are disabled, so tf_optimizer is not
# assigned by this section; presumably a fallback is defined elsewhere - confirm.
def computemodel(self):
    """Build the TensorFlow graph of the multiple-scattering Q-PHASE forward model.

    1. Multiple scattering: the E-field of every illumination angle is
       propagated slice by slice through the sample. Each Z-slice applies a
       refraction factor and then the propagation step ``self.tf_convft_prop``.
    2. The exit field is filtered in Fourier space by the detection pupil
       ``self.Po``, including a Zernike aberration phase. This limits it to the
       detection NA.
    3. Each illumination's field is back-propagated to every Z-position. The
       resulting stacks are summed over all illumination angles.

    All illumination angles are processed simultaneously. The illumination
    index is moved to the leading ("batch") dimension.

    Side effects: (re)creates the ``self.TF_*`` graph attributes and sets
    ``self.mid3D``, ``self.kzcoord``, ``self.U_z_list``, ``self.allInt``,
    ``self.allSumAmp`` and ``self.tf_iterator``.

    Returns:
        ``self.TF_allSumAmp``: the summed complex amplitude stack divided by
        ``self.Nc``. When ``self.is_padding`` is set, it is cropped to the
        central half in the two lateral dimensions.
    """
    print("Buildup Q-PHASE Model ")

    # Move the illumination index (last axis of the NumPy inputs) to the front
    # so it acts as the batch dimension: (alpha_illu, x, y, z).
    A_prop = np.transpose(self.A_input, [3, 0, 1, 2])

    # Per-slice propagator; the (self.dphi > 0) mask excludes the near-field
    # components in each step.
    myprop = np.exp(1j * self.dphi) * (self.dphi > 0)
    myprop = tf_helper.repmat4d(np.fft.fftshift(myprop), self.Nc)
    myprop = np.transpose(myprop, [3, 0, 1, 2])

    # Precalculate the oblique-illumination effect on the OPD to speed it up.
    RefrEffect = 1j * self.dz * self.k0 * self.RefrCos
    RefrEffect = np.transpose(RefrEffect, [3, 0, 1, 2])

    with tf.name_scope('Variable_assignment'):
        self.TF_A_input = tf.constant(A_prop, dtype=tf.complex64)
        self.TF_RefrEffect = tf.reshape(tf.constant(RefrEffect, dtype=tf.complex64), [self.Nc, 1, 1])
        self.TF_myprop = tf.squeeze(tf.constant(myprop, dtype=tf.complex64))
        self.TF_Po = tf.cast(tf.constant(self.Po), tf.complex64)
        self.TF_Zernikes = tf.constant(self.myzernikes, dtype=tf.float32)
        if self.is_optimization_psf:
            # Zernike coefficients become trainable when the PSF is optimized.
            self.TF_zernikefactors = tf.Variable(self.zernikefactors, dtype=tf.float32, name='var_zernikes')
        else:
            self.TF_zernikefactors = tf.constant(self.zernikefactors, dtype=tf.float32, name='const_zernikes')

    # TODO: Introduce the averaged RI along Z - MWeigert
    self.TF_A_prop = tf.squeeze(self.TF_A_input)
    self.U_z_list = []

    # Initialize accumulators
    self.allInt = 0
    self.allSumAmp = 0
    self.TF_allSumAmp = tf.zeros([self.mysize[0], self.Nx, self.Ny], dtype=tf.complex64)
    self.TF_allSumInt = tf.zeros([self.mysize[0], self.Nx, self.Ny], dtype=tf.float32)
    # Debug counter (doubled every slice); only consumed by the commented-out
    # tf.Print below.
    self.tf_iterator = tf.Variable(1)

    # Simulate multiple scattering through the object, one Z-slice at a time.
    with tf.name_scope('Fwd_Propagate'):
        for pz in range(self.mysize[0]):
            self.tf_iterator += self.tf_iterator
            # self.TF_A_prop = tf.Print(self.TF_A_prop, [self.tf_iterator], 'Propagation step: ')
            with tf.name_scope('Refract'):
                TF_f_phase = tf.cast(self.TF_obj_phase_do[pz, :, :], tf.complex64)
                # NOTE(review): RefrEffect already contains a factor 1j, so
                # this exponent is -dz*k0*RefrCos*phase. For real-valued
                # RefrCos and phase that is a real (amplitude) factor rather
                # than a phase shift. Confirm against the intended model.
                self.TF_f = tf.exp(1j * self.TF_RefrEffect * TF_f_phase)
                self.TF_A_prop = self.TF_A_prop * self.TF_f  # refraction step
            with tf.name_scope('Propagate'):
                self.TF_A_prop = self.tf_convft_prop(self.TF_A_prop, self.TF_myprop)

    # In a final step limit this to the detection NA. The pupil carries the
    # Zernike aberration phase sum(zernikefactors * Zernikes) over axis 2.
    self.TF_Po_aberr = tf.exp(1j * tf.cast(tf.reduce_sum(self.TF_zernikefactors * self.TF_Zernikes, axis=2), tf.complex64)) * self.TF_Po
    self.TF_A_prop = tf.ifft2d(tf.fft2d(self.TF_A_prop) * self.TF_Po_aberr)

    self.TF_myAllSlicePropagator = tf.constant(self.myAllSlicePropagator, dtype=tf.complex64)
    self.kzcoord = np.reshape(self.kz[self.Ic], [1, 1, 1, self.Nc])

    # Centre voxel (z, x, y) of the stack, used as the global-phase reference.
    # Plain int(): the np.int alias was removed in NumPy 1.24.
    self.mid3D = [int(self.mysize[0] // 2), int(self.A_input.shape[0] // 2), int(self.A_input.shape[1] // 2)]

    # Create the Z-stack by back-propagating the information in the BFP to
    # every Z-position.
    with tf.name_scope('Back_Propagate'):
        for pillu in range(self.Nc):
            with tf.name_scope('Back_Propagate_Step'):
                with tf.name_scope('Adjust'):
                    OneAmp = tf.expand_dims(self.TF_A_prop[pillu, :, :], 0)
                    # Backpropagation assuming what would be measured if the
                    # sample were moved under oblique illumination. First
                    # propagate conceptually the normal way, then apply the XYZ
                    # shift via the Fourier shift theorem. This corresponds to
                    # physically shifting the object volume; the scattered field
                    # stays the same.
                    self.TF_AdjustKXY = tf.squeeze(tf.conj(self.TF_A_input[pillu, :, :]))
                    # Maybe a bit of a dirty hack, but we first need to shift
                    # the zero coordinate to the centre.
                    self.TF_AdjustKZ = tf.transpose(tf.constant(np.exp(2 * np.pi * 1j * self.dz * np.reshape(np.arange(0, self.mysize[0], 1), [1, 1, self.mysize[0]]) * self.kzcoord[:, :, :, pillu]), dtype=tf.complex64), [2, 1, 0])
                    # Propagate a single amplitude pattern back to the whole stack.
                    self.TF_allAmp = tf.ifft2d(tf.fft2d(OneAmp) * self.TF_myAllSlicePropagator) * self.TF_AdjustKZ * self.TF_AdjustKXY
                    # Global phase adjustment referenced to the centre voxel.
                    # NOTE(review): multiplying by exp(+i*angle(centre)) doubles
                    # the reference phase instead of cancelling it. Cancelling
                    # would need exp(-i*angle(centre)). Confirm the intended sign.
                    self.TF_allAmp = self.TF_allAmp * tf.exp(1j * tf.cast(tf.angle(self.TF_allAmp[self.mid3D[0], self.mid3D[1], self.mid3D[2]]), tf.complex64))
                with tf.name_scope('Sum_Amps'):
                    self.TF_allSumAmp = self.TF_allSumAmp + self.TF_allAmp  # superpose the amplitudes
                    self.TF_allSumInt = self.TF_allSumInt + tf_helper.tf_abssqr(self.TF_allAmp)

    # Normalize the amplitude by the number of illumination angles.
    self.TF_allSumAmp = self.TF_allSumAmp / self.Nc

    # Undo the padding: keep only the central half of both lateral axes.
    if self.is_padding:
        x_lo, x_hi = self.Nx // 2 - self.Nx // 4, self.Nx // 2 + self.Nx // 4
        y_lo, y_hi = self.Ny // 2 - self.Ny // 4, self.Ny // 2 + self.Ny // 4
        self.TF_allSumAmp = self.TF_allSumAmp[:, x_lo:x_hi, y_lo:y_hi]

    return self.TF_allSumAmp