Ejemplo n.º 1
0
    # Correct the forward model - not good here!
    # NOTE(review): in TF1 the second positional argument of tf.Variable is
    # `trainable`, not `name`, so these strings are probably not applied as
    # the variable names - pass name=... if that was the intent. Neither
    # variable is used in this fragment (the line that would combine them
    # into tf_norm is commented out).
    tf_glob_real = tf.Variable(0., 'tf_glob_real')
    tf_glob_imag = tf.Variable(0., 'tf_glob_imag')
    #tf_norm = tf.complex(tf_glob_real, tf_glob_imag)
    # tf_norm is fixed to the complex constant 0 in this variant.
    tf_norm = tf.complex(0., 0.)
    '''Define Loss-function'''
    # The loss type is selected by hand: with `if (0)` / `elif (0)` only the
    # final `else` branch (L2 on the complex residual) is built.
    # NOTE(review): no global phase/amplitude parameter is applied to tf_fwd
    # in any branch here, despite the trailing comments below.
    if (0):
        # L1: mean absolute value of the complex residual.
        print('-------> Losstype is L1')
        tf_fidelity = tf.reduce_mean(
            (tf.abs((muscat.tf_meas) - tf_fwd)
             ))  # allow a global phase parameter to avoid unwrapping effects
    elif (0):
        # "Mixed" L2: real and imaginary residuals are penalised separately
        # (tf_helper.tf_abssqr presumably returns |x|^2 - confirm).
        print('-------> Losstype mixed L2 ')
        tf_fidelity = tf.reduce_mean(
            tf_helper.tf_abssqr(tf.real(muscat.tf_meas) - tf.real(tf_fwd)) +
            tf_helper.tf_abssqr(tf.imag(muscat.tf_meas) - tf.imag(tf_fwd))
        )  # allow a global phase parameter to avoid unwrapping effects
    else:
        # L2: mean of tf_helper.tf_abssqr of the complex residual.
        print('-------> Losstype is L2')
        tf_fidelity = tf.reduce_mean(
            tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd)
        )  # allow a global phase parameter to avoid unwrapping effects
        #tf_fidelity =  tf.losses.mean_squared_error(tf.real(muscat.tf_meas), tf.real(tf_fwd))
        #tf_fidelity += tf.losses.mean_squared_error(tf.imag(muscat.tf_meas), tf.imag(tf_fwd))
    # Total loss = data fidelity + tf_negsqrloss + tf_regloss; the last two
    # are defined outside this fragment.
    tf_loss = tf_fidelity + tf_negsqrloss + tf_regloss
    '''Define Optimizer'''
    if (0):
        print('Using ADAM optimizer')
        tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate)
    elif (0):
    # NOTE(review): the source is truncated/garbled here - the `elif (0):`
    # above has no body, and the lines below are the argument tail of a call
    # whose head is missing (judging by BetaVals/epsR/is_circ, presumably a
    # reg.Reg_TV(...) regulariser on muscat.TF_obj_absorption). This fragment
    # does not parse as written.
    muscat.TF_obj_absorption,
    BetaVals=[muscat.dx, muscat.dy, muscat.dz],
    epsR=muscat.tf_eps,
    is_circ=True
)  # Alternatively tf_total_variation_regularization / total_variation

# Negativity penalty on the refractive-index object and on its absorption
# part, weighted by lambda_neg (reg.Reg_NegSqr presumably penalises negative
# values quadratically - confirm in the reg module).
tf_negsqrloss = lambda_neg * reg.Reg_NegSqr(muscat.TF_obj)
tf_negsqrloss += lambda_neg * reg.Reg_NegSqr(muscat.TF_obj_absorption)

# Trainable global phase and amplitude fitted together with the object.
# The second positional argument of tf.Variable is `trainable`, not `dtype`,
# so dtype is passed by keyword (passing tf.float32 positionally only worked
# because a DType object is truthy and so became trainable=True).
tf_globalphase = tf.Variable(0., dtype=tf.float32, name='var_phase')
tf_globalabs = tf.Variable(1., dtype=tf.float32, name='var_abs')

# Forward model divided by |globalabs| and rotated by exp(1j * globalphase).
tf_fwd_corrected = tf_fwd / tf.cast(tf.abs(tf_globalabs), tf.complex64) * tf.exp(
    1j * tf.cast(tf_globalphase, tf.complex64))

# Data-fidelity term on the complex residual (tf_helper.tf_abssqr presumably
# returns |x|^2). The global phase parameter lets the fit absorb a constant
# phase offset, avoiding unwrapping effects.
tf_fidelity = tf.reduce_mean(
    tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd_corrected))
tf_grads = tf.gradients(tf_fidelity, [muscat.TF_obj])[0]

# NOTE(review): tf_tvloss must already be defined at this point; verify the
# statement order against the rest of the script.
tf_loss = tf_fidelity + tf_negsqrloss + tf_tvloss

# Define optimizer. Alternatives tried previously: MomentumOptimizer
# (Nesterov, momentum=.9), ProximalGradientDescentOptimizer and
# GradientDescentOptimizer.
tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate)

# Without var_list, minimize() updates every trainable variable in the graph
# (presumably muscat.TF_obj, muscat.TF_obj_absorption, tf_globalabs and
# tf_globalphase).
tf_lossop = tf_optimizer.minimize(tf_loss)

# Evaluate the model.
sess = tf.Session()

# Numpy to Tensorflow: measured data and its mean.
np_meas = matlab_val
np_mean = np.mean(np_meas)

# Define cost function.
# Total-variation regularisation of the object and of its absorption part,
# weighted by muscat.tf_lambda_tv (alternatively
# tf_total_variation_regularization / total_variation).
tf_tvloss = muscat.tf_lambda_tv * reg.Reg_TV(
    muscat.TF_obj, BetaVals=[muscat.dx, muscat.dy, muscat.dz],
    epsR=muscat.tf_eps, is_circ=True)
tf_tvloss += muscat.tf_lambda_tv * reg.Reg_TV(
    muscat.TF_obj_absorption, BetaVals=[muscat.dx, muscat.dy, muscat.dz],
    epsR=muscat.tf_eps, is_circ=True)

# The negativity penalty is switched off in this variant. Use a float zero
# (the original int32 tf.constant(0) would raise a dtype mismatch when added
# to the float32 loss terms). To re-enable it, use
# lambda_neg * reg.Reg_NegSqr(...) on muscat.TF_obj and muscat.TF_obj_absorption.
tf_negsqrloss = tf.constant(0.)

# Trainable global phase and amplitude. dtype is passed by keyword because
# the second positional argument of tf.Variable is `trainable`, not `dtype`.
tf_globalphase = tf.Variable(0., dtype=tf.float32, name='var_phase')
tf_globalabs = tf.Variable(.6, dtype=tf.float32, name='var_abs')

# Forward model divided by |globalabs| and rotated by exp(1j * globalphase);
# the global phase parameter avoids unwrapping effects in the fidelity term.
tf_fwd_corrected = tf_fwd / tf.cast(tf.abs(tf_globalabs), tf.complex64) * tf.exp(
    1j * tf.cast(tf_globalphase, tf.complex64))
tf_fidelity = tf.reduce_mean(
    tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd_corrected))
tf_grads = tf.gradients(tf_fidelity, [muscat.TF_obj])[0]

# Fidelity is normalised by the measurement mean (np_mean); tf_negsqrloss is
# not part of this loss.
tf_loss = tf_fidelity / np_mean + tf_tvloss

# Define optimizer. Alternatives tried previously: MomentumOptimizer
# (Nesterov), ProximalGradientDescentOptimizer, GradientDescentOptimizer.
tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate)

# Object step: updates the object and the global amplitude/phase only
# (muscat.TF_obj_absorption is not in the list).
tf_lossop_obj = tf_optimizer.minimize(
    tf_loss, var_list=[muscat.TF_obj, tf_globalabs, tf_globalphase])
if is_optimization_psf:  # in case we want to do blind deconvolution
    # NOTE(review): this reuses the same AdamOptimizer instance as
    # tf_lossop_obj. In TF1 Adam's beta-power accumulators are shared by both
    # minimize ops, so running both per iteration advances the bias
    # correction twice - consider a separate optimizer.
    tf_lossop_aberr = tf_optimizer.minimize(
        tf_loss, var_list=[muscat.TF_zernikefactors])

      tf_negsqrloss =  tf.constant(0.)
  # NOTE(review): this fragment is garbled - the line above is indented
  # deeper than the lines that follow, so it does not parse as written.
  # tf_negsqrloss is a float zero, i.e. that loss term is switched off here.
  
  
  # Correct the forward model - not good here!
  # NOTE(review): in TF1 the second positional argument of tf.Variable is
  # `trainable`, not `name`, so these strings are probably not applied as
  # the variable names. Neither variable is used in this fragment.
  tf_glob_real = tf.Variable(0.,'tf_glob_real')
  tf_glob_imag = tf.Variable(0.,'tf_glob_imag')
  #tf_norm = tf.complex(tf_glob_real, tf_glob_imag)
  # tf_norm is fixed to the complex constant 0 in this variant.
  tf_norm = tf.complex(0., 0.)
  
  '''Define Loss-function'''
  # The loss type is selected by hand: with `if(0)` / `elif(0)` only the
  # final `else` branch (L2 on the complex residual) is built. No global
  # phase parameter is actually applied here, despite the trailing comments.
  if(0):
      print('-------> Losstype is L1')
      tf_fidelity = tf.reduce_mean((tf.abs((muscat.tf_meas) - tf_fwd))) # allow a global phase parameter to avoid unwrapping effects
  elif(0):
      print('-------> Losstype mixed L2 ')
      tf_fidelity = tf.reduce_mean(tf_helper.tf_abssqr(tf.real(muscat.tf_meas) - tf.real(tf_fwd))+tf_helper.tf_abssqr(tf.imag(muscat.tf_meas) - tf.imag(tf_fwd))) # allow a global phase parameter to avoid unwrapping effects
  else:
      print('-------> Losstype is L2')
      tf_fidelity = tf.reduce_mean(tf_helper.tf_abssqr(muscat.tf_meas - tf_fwd)) # allow a global phase parameter to avoid unwrapping effects
      #tf_fidelity =  tf.losses.mean_squared_error(tf.real(muscat.tf_meas), tf.real(tf_fwd))
      #tf_fidelity += tf.losses.mean_squared_error(tf.imag(muscat.tf_meas), tf.imag(tf_fwd))
  # tf_regloss is defined outside this fragment.
  tf_loss = tf_fidelity + tf_negsqrloss + tf_regloss
 
  
  '''Define Optimizer'''     
  # The optimizer is selected by hand; with both conditions 0 neither branch
  # runs, so tf_optimizer is not assigned in the visible part of this fragment.
  if(0):
      print('Using ADAM optimizer')            
      tf_optimizer = tf.train.AdamOptimizer(muscat.tf_learningrate)
  elif(0):
      # NOTE(review): the message says ADAM but this branch builds an
      # AdadeltaOptimizer.
      print('Using ADAM optimizer')            
      tf_optimizer = tf.train.AdadeltaOptimizer(muscat.tf_learningrate)
Ejemplo n.º 5
0
    def computemodel(self):
        ''' Perform the multiple-scattering (Q-PHASE) forward model.

        1.) Multiple scattering is performed by slice-wise propagation of the
            E-field through the sample.
        2.) Each field is then backprojected to the back focal plane (BFP).
        3.) Finally a focus stack is created and summed over all angles.

        This is done for all illumination angles (coming from the
        illumination NA) simultaneously.

        Side effects:
            Builds the TensorFlow graph and stores intermediate tensors as
            attributes (TF_A_prop, TF_Po_aberr, TF_allAmp, TF_allSumInt,
            mid3D, ...). TF_zernikefactors is a trainable variable when
            self.is_optimization_psf is set and a constant otherwise.

        Returns:
            self.TF_allSumAmp: the complex amplitude stack (accumulated in a
            tensor initialised with shape [mysize[0], Nx, Ny]) averaged over
            the self.Nc illumination angles, and cropped to the central half
            in x and y when self.is_padding is set.
        '''
        print("Buildup Q-PHASE Model ")

        # Move axis 3 (presumably the illumination index) to the front, so
        # TensorFlow treats it as the batch dimension: (alpha_illu, x, y, z).
        A_prop = np.transpose(self.A_input, [3, 0, 1, 2])

        # Per-slice propagator; components with dphi <= 0 are zeroed, which
        # excludes the near-field components in each step.
        myprop = np.exp(1j * self.dphi) * (self.dphi > 0)
        myprop = tf_helper.repmat4d(np.fft.fftshift(myprop), self.Nc)
        myprop = np.transpose(myprop, [3, 0, 1, 2])

        # Precalculate the oblique-illumination effect on the OPD to speed it up.
        RefrEffect = 1j * self.dz * self.k0 * self.RefrCos
        RefrEffect = np.transpose(RefrEffect, [3, 0, 1, 2])

        with tf.name_scope('Variable_assignment'):
            self.TF_A_input = tf.constant(A_prop, dtype=tf.complex64)
            self.TF_RefrEffect = tf.reshape(tf.constant(RefrEffect, dtype=tf.complex64), [self.Nc, 1, 1])
            self.TF_myprop = tf.squeeze(tf.constant(myprop, dtype=tf.complex64))
            self.TF_Po = tf.cast(tf.constant(self.Po), tf.complex64)
            self.TF_Zernikes = tf.constant(self.myzernikes, dtype=tf.float32)

            # Zernike factors are only trainable for PSF (aberration) fitting.
            if self.is_optimization_psf:
                self.TF_zernikefactors = tf.Variable(self.zernikefactors, dtype=tf.float32, name='var_zernikes')
            else:
                self.TF_zernikefactors = tf.constant(self.zernikefactors, dtype=tf.float32, name='const_zernikes')

        # TODO: Introduce the averaged RI along Z - MWeigert

        self.TF_A_prop = tf.squeeze(self.TF_A_input)
        self.U_z_list = []

        # Initialize accumulators.
        self.allInt = 0
        self.allSumAmp = 0
        self.TF_allSumAmp = tf.zeros([self.mysize[0], self.Nx, self.Ny], dtype=tf.complex64)
        self.TF_allSumInt = tf.zeros([self.mysize[0], self.Nx, self.Ny], dtype=tf.float32)

        # Propagation-step counter (debugging aid only). It used to be doubled
        # each step (`x += x`), which overflows int32 after 31 slices and
        # raises for resource variables; count the steps instead.
        self.tf_iterator = tf.Variable(1)

        # Simulate multiple scattering through the object, slice by slice.
        with tf.name_scope('Fwd_Propagate'):
            for pz in range(self.mysize[0]):
                self.tf_iterator = self.tf_iterator + 1
                with tf.name_scope('Refract'):
                    TF_f_phase = tf.cast(self.TF_obj_phase_do[pz, :, :], tf.complex64)
                    self.TF_f = tf.exp(1j * self.TF_RefrEffect * TF_f_phase)
                    self.TF_A_prop = self.TF_A_prop * self.TF_f  # refraction step

                with tf.name_scope('Propagate'):
                    self.TF_A_prop = self.tf_convft_prop(self.TF_A_prop, self.TF_myprop)

        # In a final step limit the field to the detection NA; the pupil also
        # carries the Zernike aberration phase sum(zernikefactors * Zernikes).
        zernike_phase = tf.reduce_sum(self.TF_zernikefactors * self.TF_Zernikes, axis=2)
        self.TF_Po_aberr = tf.exp(1j * tf.cast(zernike_phase, tf.complex64)) * self.TF_Po
        self.TF_A_prop = tf.ifft2d(tf.fft2d(self.TF_A_prop) * self.TF_Po_aberr)

        self.TF_myAllSlicePropagator = tf.constant(self.myAllSlicePropagator, dtype=tf.complex64)
        self.kzcoord = np.reshape(self.kz[self.Ic], [1, 1, 1, self.Nc])

        # Centre index of the stack along each axis, used as the global-phase
        # reference below. Builtin int(): np.int was deprecated in NumPy 1.20
        # and removed in NumPy 1.24.
        self.mid3D = [int(self.mysize[0] // 2), int(self.A_input.shape[0] // 2), int(self.A_input.shape[1] // 2)]

        # z-slice indices for the axial Fourier-shift factor (loop invariant).
        z_index = np.reshape(np.arange(0, self.mysize[0], 1), [1, 1, self.mysize[0]])

        # Create the Z-stack by backpropagating the BFP information to each z.
        with tf.name_scope('Back_Propagate'):
            for pillu in range(self.Nc):
                with tf.name_scope('Back_Propagate_Step'):
                    with tf.name_scope('Adjust'):
                        OneAmp = tf.expand_dims(self.TF_A_prop[pillu, :, :], 0)

                        # Fancy backpropagation assuming what would be measured if the
                        # sample was moved under oblique illumination: first propagate the
                        # normal way, then apply the XYZ shift via the Fourier shift
                        # theorem (physically shifting the object volume; the scattered
                        # field stays the same).
                        self.TF_AdjustKXY = tf.squeeze(tf.conj(self.TF_A_input[pillu, :, :]))
                        kz_phase = np.exp(2 * np.pi * 1j * self.dz * z_index * self.kzcoord[:, :, :, pillu])
                        self.TF_AdjustKZ = tf.transpose(tf.constant(kz_phase, dtype=tf.complex64), [2, 1, 0])

                        # Propagate a single amplitude pattern back to the whole stack.
                        self.TF_allAmp = tf.ifft2d(tf.fft2d(OneAmp) * self.TF_myAllSlicePropagator) * self.TF_AdjustKZ * self.TF_AdjustKXY

                        # Remove the global phase: multiply by exp(-1j * phi_ref) so that
                        # the centre voxel ends up with zero phase before the angles are
                        # superposed. The previous exp(+1j * phi_ref) doubled the phase
                        # instead of cancelling it (the disabled 3D-FT variant already
                        # used the minus sign).
                        ref_phase = tf.angle(self.TF_allAmp[self.mid3D[0], self.mid3D[1], self.mid3D[2]])
                        self.TF_allAmp = self.TF_allAmp * tf.exp(-1j * tf.cast(ref_phase, tf.complex64))

                    with tf.name_scope('Sum_Amps'):
                        self.TF_allSumAmp = self.TF_allSumAmp + self.TF_allAmp  # superpose the amplitudes
                        self.TF_allSumInt = self.TF_allSumInt + tf_helper.tf_abssqr(self.TF_allAmp)

        # Normalize the amplitude by the number of illumination angles.
        self.TF_allSumAmp = self.TF_allSumAmp / self.Nc

        # Undo the padding: keep only the central half in x and y.
        if self.is_padding:
            x_lo, x_hi = self.Nx // 2 - self.Nx // 4, self.Nx // 2 + self.Nx // 4
            y_lo, y_hi = self.Ny // 2 - self.Ny // 4, self.Ny // 2 + self.Ny // 4
            self.TF_allSumAmp = self.TF_allSumAmp[:, x_lo:x_hi, y_lo:y_hi]

        return self.TF_allSumAmp