def computesys(self,
                   obj,
                   is_zernike=False,
                   is_padding=False,
                   dropout_prob=1):
        """ This computes the FWD-graph of the Q-PHASE microscope;
        1.) Compute the physical dimensions
        2.) Compute the sampling for the waves
        3.) Create the illumination waves depending on System's properties

        ##### IMPORTANT! #####
        The ordering of the channels is as follows:
            Nillu, Nz, Nx, Ny

        All results are stored as attributes on ``self`` (mysize, obj,
        TF_obj_phase_do, RelFreq, Po, myzernikes, Ic, Ic_map, kz, dphi, Nc,
        kxcoord, kycoord, RefrCos, intensityweights, A_input, Alldphi,
        myAllSlicePropagator, ...); the method itself returns None.

        Args:
            obj: Object distribution. Stored as ``self.obj`` and turned into
                a float32 ``tf.Variable`` (if ``self.is_optimization`` is
                truthy) or a float32 ``tf.constant`` otherwise.
                NOTE(review): presumably shaped (Nz, Nx, Ny) on the unpadded
                grid -- confirm against callers.
            is_zernike: If True, additionally compute ``self.myaberration``
                as the sum over ``self.zernikefactors * self.myzernikes``.
                The aberration is NOT applied to ``self.Po`` in this method.
            is_padding: If True, double ``self.Nx`` and ``self.Ny`` in place
                and embed ``obj`` in the centre of a zero array of the
                doubled size (flagged by the code as not working correctly).
            dropout_prob: ``keep_prob`` for ``tf.nn.dropout`` applied to the
                object variable; only used when ``self.is_optimization`` is
                truthy.

        Reads from ``self``: Nz, Nx, Ny, dx, dy, dz, is_optimization,
        lambdaM, nEmbb, NAo, NAc, nzernikes, zernikefactors (only when
        is_zernike) and, when present and not None, NAci, shiftIcX and
        shiftIcY.
        """
        # define whether we want to pad the experiment
        self.is_padding = is_padding

        if (is_padding):
            print('WARNING: Padding is not yet working correctly!!!!!!!!')
            # add padding in X/Y to avoid wrap-arounds
            # NOTE(review): Nx/Ny are doubled in place on self, so repeated
            # calls with is_padding=True keep enlarging the grid.
            self.Nx = self.Nx * 2
            self.Ny = self.Ny * 2
            self.mysize = np.array((self.Nz, self.Nx, self.Ny))
            self.obj = obj
            # no-ops: the pixel pitch stays the same, only the grid grows
            self.dx = self.dx
            self.dy = self.dy

        else:
            self.mysize = np.array((self.Nz, self.Nx, self.Ny))
            self.obj = obj

        # Decide whether we want to optimize or simply execute the model
        if (self.is_optimization):
            if is_padding:
                # Pad object with zeros along X/Y: obj is written into the
                # central window of width 2*(N//4), which equals the original
                # size only when the original Nx/Ny are even (TODO confirm).
                obj_tmp = np.zeros(self.mysize)  # + 1j*np.zeros(muscat.mysize)
                obj_tmp[:, self.Nx // 2 - self.Nx // 4:self.Nx // 2 +
                        self.Nx // 4, self.Ny // 2 -
                        self.Ny // 4:self.Ny // 2 + self.Ny // 4] = self.obj
                self.obj = obj_tmp
            # in case one wants to use this as a fwd-model for an inverse problem:
            # the object becomes a trainable float32 variable
            self.TF_obj_phase = tf.Variable(self.obj,
                                            dtype=tf.float32,
                                            name='Object_Variable')
            self.TF_obj_phase_do = tf.nn.dropout(
                self.TF_obj_phase,
                keep_prob=dropout_prob)  # eventually apply dropout (dropout_prob is the keep probability)

        else:
            # Variables of the computational graph: plain forward simulation,
            # the object enters as a float32 constant
            if is_padding:
                # Pad object with zeros along X/Y (same centring as above)
                obj_tmp = np.zeros(self.mysize)  # + 1j*np.zeros(muscat.mysize)
                obj_tmp[:, self.Nx // 2 - self.Nx // 4:self.Nx // 2 +
                        self.Nx // 4, self.Ny // 2 -
                        self.Ny // 4:self.Ny // 2 + self.Ny // 4] = self.obj
                self.obj = obj_tmp
            self.TF_obj_phase_do = tf.constant(self.obj,
                                               dtype=tf.float32,
                                               name='Object_const')

        ## Establish normalized coordinates.
        #-----------------------------------------
        # Spatial frequencies scaled by lambdaM * nEmbb / (d * NAo), so that
        # the objective pupil below is simply RelFreq < 1.
        vxx = tf_helper.xx(
            (self.mysize[1], self.mysize[2]),
            'freq') * self.lambdaM * self.nEmbb / (self.dx * self.NAo)
        # normalized optical coordinates in X
        vyy = tf_helper.yy(
            (self.mysize[1], self.mysize[2]),
            'freq') * self.lambdaM * self.nEmbb / (self.dy * self.NAo)
        # normalized optical coordinates in Y

        # AbbeLimit=lambda0/NAo;  # Rainer's Method
        # RelFreq = rr(mysize,'freq')*AbbeLimit/dx;  # Is not generally right (dx and dy)
        self.RelFreq = np.sqrt(tf_helper.abssqr(vxx) + tf_helper.abssqr(vyy))
        # spans the frequency grid of normalized pupil coordinates
        self.Po = self.RelFreq < 1.0
        # Create the pupil of the objective lens (boolean mask, not fftshifted here)

        # Precompute the Zernike modes 1..self.nzernikes on the pupil grid
        # (complex storage; the mode index convention is that of zern.zernike)
        self.myzernikes = np.zeros(
            (self.Po.shape[0], self.Po.shape[1],
             self.nzernikes)) + 1j * np.zeros(
                 (self.Po.shape[0], self.Po.shape[1], self.nzernikes))
        r, theta = zern.cart2pol(vxx, vyy)
        for i in range(0, self.nzernikes):
            self.myzernikes[:, :, i] = zern.zernike(
                r, theta, i + 1, norm=False)  # or 8 in X-direction

        # eventually introduce a phase factor to approximate the experimental data better
        self.Po = self.Po  # no-op placeholder; the pupil would need to be shifted before use as a low-pass filter, which is not done here
        if is_zernike:
            print(
                '----------> Be aware: We are taking aberrations into account!'
            )
            # Assuming: System has coma along X-direction
            # Aberration map = weighted sum of the Zernike modes (along axis 2)
            self.myaberration = np.sum(self.zernikefactors * self.myzernikes,
                                       axis=2)
            # Only converts the boolean pupil to float; myaberration is not applied to Po
            self.Po = 1. * self.Po

        # Prepare the normalized spatial-frequency grid.
        self.S = self.NAc / self.NAo
        # Coherence factor (condenser NA relative to objective NA)
        self.Ic = self.RelFreq <= self.S
        # Weight the condenser aperture with cos^2(70 * (fx^2 + fy^2))
        myIntensityFactor = 70
        self.Ic_map = np.cos((myIntensityFactor * tf_helper.xx(
            (self.Nx, self.Ny), mode='freq')**2 +
                              myIntensityFactor * tf_helper.yy(
                                  (self.Nx, self.Ny), mode='freq')**2))**2
        self.Ic = self.Ic * self.Ic_map  # weight the intensity in the condenser aperture, unlikely to be uniform
        print('We are weighing the Intensity int the illu-pupil!')

        # Darkfield: annular condenser aperture between NAci and NAc
        if hasattr(self, 'NAci'):
            if self.NAci != None and self.NAci > 0:
                #print('I detected a darkfield illumination aperture!')
                self.S_o = self.NAc / self.NAo
                # Coherence factor (outer radius of the annulus)
                self.S_i = self.NAci / self.NAo
                # Coherence factor (inner radius of the annulus)
                # NOTE(review): this replaces Ic with a plain boolean annulus,
                # discarding the Ic_map weighting applied above -- intended?
                self.Ic = (1. * (self.RelFreq < self.S_o) * 1. *
                           (self.RelFreq > self.S_i)
                           ) > 0  # Create the pupil of the condenser plane

        # Shift the pupil in X-direction (optical misalignment), by whole pixels along axis 1
        if hasattr(self, 'shiftIcX'):
            if self.shiftIcX != None:
                print('Shifting the illumination in X by: ' +
                      str(self.shiftIcX) + ' Pixel')
                self.Ic = np.roll(self.Ic, self.shiftIcX, axis=1)

        # Shift the pupil in Y-direction (optical misalignment), by whole pixels along axis 0
        if hasattr(self, 'shiftIcY'):
            if self.shiftIcY != None:
                print('Shifting the illumination in Y by: ' +
                      str(self.shiftIcY) + ' Pixel')
                self.Ic = np.roll(self.Ic, self.shiftIcY, axis=0)

        ## Forward propagator  (Ewald sphere based) DO NOT USE NORMALIZED COORDINATES HERE
        # kx^2 + ky^2 in physical units; "+ 0j" makes the array complex so the
        # sqrt below does not produce NaNs for kzsqr < 0
        self.kxysqr = (tf_helper.abssqr(
            tf_helper.xx((self.mysize[1], self.mysize[2]), 'freq') /
            self.dx) + tf_helper.abssqr(
                tf_helper.yy(
                    (self.mysize[1], self.mysize[2]), 'freq') / self.dy)) + 0j
        self.k0 = 1 / self.lambdaM
        self.kzsqr = tf_helper.abssqr(self.k0) - self.kxysqr
        self.kz = np.sqrt(self.kzsqr)
        # drop evanescent components (kz forced to 0 where kzsqr < 0)
        self.kz[self.kzsqr < 0] = 0
        self.dphi = 2 * np.pi * self.kz * self.dz
        # exp(1i*kz*dz) would be the propagator for one slice

        ## Get a list of vector coordinates corresponding to the pixels in the mask
        xfreq = tf_helper.xx((self.mysize[1], self.mysize[2]), 'freq')
        yfreq = tf_helper.yy((self.mysize[1], self.mysize[2]), 'freq')
        # every non-zero pixel of Ic is one illumination plane wave
        self.Nc = np.sum(self.Ic > 0)
        print('Number of Illumination Angles / Plane waves: ' + str(self.Nc))

        # Calculate the computational grid/sampling
        self.kxcoord = np.reshape(xfreq[self.Ic > 0], [1, 1, 1, self.Nc])
        # NA-positions in condenser aperture plane in x-direction
        self.kycoord = np.reshape(yfreq[self.Ic > 0], [1, 1, 1, self.Nc])
        # NA-positions in condenser aperture plane in y-direction
        self.RefrCos = np.reshape(self.k0 / self.kz[self.Ic > 0],
                                  [1, 1, 1, self.Nc])
        # 1/cosine used for the application of the refractive index steps to account for longer OPD in medium under an oblique illumination angle
        # NOTE(review): divides by zero if an illuminated pixel has kz == 0

        ## Generate the illumination amplitudes
        # amplitude of each plane wave = value of Ic at its aperture pixel
        self.intensityweights = self.Ic[self.Ic > 0]
        # One tilted plane wave per illumination angle, angle index on the last
        # axis (presumably (Nx, Ny, 1, Nc) -- depends on tf_helper.repmat4d)
        self.A_input = self.intensityweights * np.exp(
            (2 * np.pi * 1j) *
            (self.kxcoord *
             tf_helper.repmat4d(tf_helper.xx(
                 (self.mysize[1], self.mysize[2])), self.Nc) + self.kycoord *
             tf_helper.repmat4d(tf_helper.yy(
                 (self.mysize[1], self.mysize[2])), self.Nc))
        )  # Corresponds to a plane wave under many oblique illumination angles - bfxfun

        ## propagate field to z-stack and sum over all illumination angles
        # Alldphi[..., z] = -z * dphi: the 2D dphi grid with a trailing Nz axis
        self.Alldphi = -np.reshape(np.arange(
            0, self.mysize[0], 1), [1, 1, self.mysize[0]]) * np.repeat(
                self.dphi[:, :, np.newaxis], self.mysize[0], axis=2)

        # Ordinary backpropagation. This is NOT what we are interested in:
        # exp(1j*Alldphi) masked to dphi > 0, transposed so Z becomes the first axis
        self.myAllSlicePropagator = np.transpose(
            np.exp(1j * self.Alldphi) * (np.repeat(
                self.dphi[:, :, np.newaxis], self.mysize[0], axis=2) > 0),
            [2, 0, 1])
    def computesys(self, obj, is_padding=False, is_tomo=False, dropout_prob=1):
        """ This computes the FWD-graph of the Q-PHASE microscope;
        1.) Compute the physical dimensions
        2.) Compute the sampling for the waves
        3.) Create the illumination waves depending on System's properties

        ##### IMPORTANT! #####
        The ordering of the channels is as follows:
            Nillu, Nz, Nx, Ny

        Results are stored as attributes on ``self``; the method itself
        returns None.

        Args:
            obj: Object array. ``np.real(obj)`` becomes ``self.TF_obj`` and
                ``np.imag(obj)`` becomes ``self.TF_obj_absorption`` (float32).
                NOTE(review): obj is not zero-padded here even when the grid
                is doubled -- confirm it already matches the padded grid.
            is_padding: If True, double ``self.Nx``/``self.Ny`` in place (and
                ``self.shiftIcX``/``self.shiftIcY`` when set); flagged by the
                code as not working correctly. ``self.mysize_old`` keeps the
                unpadded size.
            is_tomo: Stored as ``self.is_tomo``; not otherwise used here.
            dropout_prob: Not used in this method.

        ``self.is_optimization`` selects how the object enters the graph:
            1  -> trainable variables (scope "Complex_Object") plus the
                  placeholders tf_lambda_tv, tf_eps, tf_meas, tf_learningrate
            0  -> float32 constants
            -1 -> placeholders for the object (real/imag) plus tf_meas,
                  tf_lambda_tv, tf_eps, tf_learningrate (resnet training)
        Any other value leaves TF_obj / TF_obj_absorption unset.
        """
        # define whether we want to pad the experiment
        self.is_padding = is_padding
        self.is_tomo = is_tomo

        self.obj = obj
        if (is_padding):
            print(
                '--------->WARNING: Padding is not yet working correctly!!!!!!!!'
            )
            # add padding in X/Y to avoid wrap-arounds
            # mysize_old keeps the unpadded (Nz, Nx, Ny), used for tf_meas below
            self.mysize_old = np.array((self.Nz, self.Nx, self.Ny))
            # NOTE(review): doubled in place -- repeated padded calls keep growing the grid
            self.Nx = self.Nx * 2
            self.Ny = self.Ny * 2
            self.mysize = np.array((self.Nz, self.Nx, self.Ny))
            # NOTE(review): obj is not zero-padded to the doubled grid here
            self.obj = obj
            # no-ops: the pixel pitch stays the same
            self.dx = self.dx
            self.dy = self.dy
        else:
            self.mysize = np.array((self.Nz, self.Nx, self.Ny))
            # same array object as mysize (no copy)
            self.mysize_old = self.mysize

        # Decide whether we want to optimize or simply execute the model
        if (self.is_optimization == 1):
            # Trainable object: real part (TF_obj) and imaginary part
            # (TF_obj_absorption) as two separate float32 variables
            #self.TF_obj = tf.Variable(np.real(self.obj), dtype=tf.float32, name='Object_Variable')
            #self.TF_obj_absorption = tf.Variable(np.imag(self.obj), dtype=tf.float32, name='Object_Variable')
            with tf.variable_scope("Complex_Object"):
                self.TF_obj = tf.get_variable('Object_Variable_Real',
                                              dtype=tf.float32,
                                              initializer=np.float32(
                                                  np.real(self.obj)))
                self.TF_obj_absorption = tf.get_variable(
                    'Object_Variable_Imag',
                    dtype=tf.float32,
                    initializer=np.float32(np.imag(self.obj)))
                #set reuse flag to True
                tf.get_variable_scope().reuse_variables()
                #just an assertion!
                assert tf.get_variable_scope().reuse == True

            # assign training variables: scalar float32 placeholders (presumably
            # TV weight, epsilon and learning rate, going by their names) and the
            # measured complex stack on the unpadded grid
            self.tf_lambda_tv = tf.placeholder(tf.float32, [])
            self.tf_eps = tf.placeholder(tf.float32, [])
            self.tf_meas = tf.placeholder(dtype=tf.complex64,
                                          shape=self.mysize_old)
            self.tf_learningrate = tf.placeholder(tf.float32, [])

        elif (self.is_optimization == 0):
            # Variables of the computational graph: plain forward simulation,
            # real and imaginary parts of obj as float32 constants
            self.TF_obj = tf.constant(np.real(self.obj),
                                      dtype=tf.float32,
                                      name='Object_const')
            self.TF_obj_absorption = tf.constant(np.imag(self.obj),
                                                 dtype=tf.float32,
                                                 name='Object_const')

        elif (self.is_optimization == -1):
            # This is for the case that we want to train the resnet:
            # object and measurement are fed in through placeholders
            self.tf_meas = tf.placeholder(dtype=tf.complex64,
                                          shape=self.mysize_old)
            # in case one wants to use this as a fwd-model for an inverse problem

            #self.TF_obj = tf.Variable(np.real(self.obj), dtype=tf.float32, name='Object_Variable')
            #self.TF_obj_absorption = tf.Variable(np.imag(self.obj), dtype=tf.float32, name='Object_Variable')
            self.TF_obj = tf.placeholder(dtype=tf.float32,
                                         shape=self.obj.shape,
                                         name='Object_Variable_Real')
            self.TF_obj_absorption = tf.placeholder(
                dtype=tf.float32,
                shape=self.obj.shape,
                name='Object_Variable_Imag')

            # assign training variables
            self.tf_lambda_tv = tf.placeholder(tf.float32, [])
            self.tf_eps = tf.placeholder(tf.float32, [])

            self.tf_learningrate = tf.placeholder(tf.float32, [])

        ## Establish normalized coordinates.
        #-----------------------------------------
        # Spatial frequencies scaled by lambdaM * nEmbb / (d * NAo), so that
        # the objective pupil below is simply RelFreq < 1.
        vxx = tf_helper.xx(
            (self.mysize[1], self.mysize[2]),
            'freq') * self.lambdaM * self.nEmbb / (self.dx * self.NAo)
        # normalized optical coordinates in X
        vyy = tf_helper.yy(
            (self.mysize[1], self.mysize[2]),
            'freq') * self.lambdaM * self.nEmbb / (self.dy * self.NAo)
        # normalized optical coordinates in Y

        # AbbeLimit=lambda0/NAo;  # Rainer's Method
        # RelFreq = rr(mysize,'freq')*AbbeLimit/dx;  # Is not generally right (dx and dy)
        self.RelFreq = np.sqrt(tf_helper.abssqr(vxx) + tf_helper.abssqr(vyy))
        # spans the frequency grid of normalized pupil coordinates
        self.Po = self.RelFreq < 1.0
        # Create the pupil of the objective lens

        # Precompute the Zernike modes; their number is taken from the length
        # of zernikefactors (assumes a 1-D zernikefactors array)
        self.nzernikes = np.squeeze(self.zernikefactors.shape)
        self.myzernikes = np.zeros(
            (self.Po.shape[0], self.Po.shape[1],
             self.nzernikes)) + 1j * np.zeros(
                 (self.Po.shape[0], self.Po.shape[1], self.nzernikes))
        r, theta = zern.cart2pol(vxx, vyy)
        for i in range(0, self.nzernikes):
            # each mode is fftshifted to match the fftshifted pupil below
            self.myzernikes[:, :, i] = np.fft.fftshift(
                zern.zernike(r, theta, i + 1,
                             norm=False))  # or 8 in X-direction

        # eventually introduce a phase factor to approximate the experimental data better
        self.Po = np.fft.fftshift(
            self.Po
        )  # Need to shift it before using as a low-pass filter    Po=np.ones((np.shape(Po)))
        print('----------> Be aware: We are taking aberrations into account!')
        # Assuming: System has coma along X-direction
        # Aberration map = weighted sum of the Zernike modes (along axis 2)
        self.myaberration = np.sum(self.zernikefactors * self.myzernikes,
                                   axis=2)
        # Only converts the boolean pupil to float; the aberration phase is not applied here
        self.Po = 1. * self.Po  #*np.exp(1j*self.myaberration)

        # Prepare the normalized spatial-frequency grid.
        self.S = self.NAc / self.NAo
        # Coherence factor (condenser NA relative to objective NA)
        self.Ic = self.RelFreq <= self.S

        # Take Darkfield into account: annular aperture between NAci and NAc
        if hasattr(self, 'NAci'):
            if self.NAci != None and self.NAci > 0:
                #print('I detected a darkfield illumination aperture!')
                self.S_o = self.NAc / self.NAo
                # Coherence factor (outer radius of the annulus)
                self.S_i = self.NAci / self.NAo
                # Coherence factor (inner radius of the annulus)
                self.Ic = (1. * (self.RelFreq < self.S_o) * 1. *
                           (self.RelFreq > self.S_i)
                           ) > 0  # Create the pupil of the condenser plane

        # weigh the illumination source with some cos^2 intensity weight?!
        # Ic is multiplied by sqrt(cos^2(70 * (fx^2 + fy^2))); since Ic values
        # later become the plane-wave amplitudes (intensityweights), this is
        # presumably meant as an amplitude weighting -- confirm
        myIntensityFactor = 70
        self.Ic_map = np.cos((myIntensityFactor * tf_helper.xx(
            (self.Nx, self.Ny), mode='freq')**2 +
                              myIntensityFactor * tf_helper.yy(
                                  (self.Nx, self.Ny), mode='freq')**2))**2
        self.Ic = self.Ic * np.sqrt(
            self.Ic_map
        )  # weight the intensity in the condenser aperture, unlikely to be uniform
        # print('--------> ATTENTION! - We are not weighing the Intensity int the illu-pupil!')

        # Shift the pupil in X-direction (optical misalignment)
        if hasattr(self, 'shiftIcX'):
            if self.shiftIcX != None:
                # NOTE(review): mutates self.shiftIcX, so repeated padded calls compound the shift
                if (is_padding): self.shiftIcX = self.shiftIcX * 2
                print('Shifting the illumination in X by: ' +
                      str(self.shiftIcX) + ' Pixel')
                # Hard-coded switch: only the AffineTransform/warp branch is active
                if (0):
                    self.Ic = np.roll(self.Ic, self.shiftIcX, axis=1)
                elif (1):
                    # sub-pixel translation by shiftIcX (presumably skimage.transform,
                    # whose translation is (column, row), i.e. along X here)
                    tform = AffineTransform(scale=(1, 1),
                                            rotation=0,
                                            shear=0,
                                            translation=(self.shiftIcX, 0))
                    self.Ic = warp(self.Ic,
                                   tform.inverse,
                                   output_shape=self.Ic.shape)
                elif (0):
                    # We apply a phase-factor to shift the source in realspace - so make it trainable
                    self.shift_xx = tf_helper.xx(
                        (self.mysize[1], self.mysize[2]), 'freq')
                    self.Ic = np.abs(
                        np.fft.ifft2(
                            np.fft.fft2(self.Ic) *
                            np.exp(1j * 2 * np.pi * self.shift_xx *
                                   self.shiftIcX)))

        # Shift the pupil in Y-direction (optical misalignment)
        if hasattr(self, 'shiftIcY'):
            if self.shiftIcY != None:
                # NOTE(review): mutates self.shiftIcY, so repeated padded calls compound the shift
                if (is_padding): self.shiftIcY = self.shiftIcY * 2
                print('Shifting the illumination in Y by: ' +
                      str(self.shiftIcY) + ' Pixel')
                # Hard-coded switch: only the AffineTransform/warp branch is active
                if (0):
                    self.Ic = np.roll(self.Ic, self.shiftIcY, axis=0)
                elif (1):
                    tform = AffineTransform(scale=(1, 1),
                                            rotation=0,
                                            shear=0,
                                            translation=(0, self.shiftIcY))
                    self.Ic = warp(self.Ic,
                                   tform.inverse,
                                   output_shape=self.Ic.shape)
                elif (0):
                    # We apply a phase-factor to shift the source in realspace - so make it trainable
                    # NOTE(review): unlike the X variant this lacks the 2*pi factor (dead code)
                    self.shift_yy = tf_helper.yy(
                        (self.mysize[1], self.mysize[2]), 'freq')
                    self.Ic = np.abs(
                        np.fft.ifft2(
                            np.fft.fft2(self.Ic) *
                            np.exp(1j * self.shift_yy * self.shiftIcY)))

        ## Forward propagator  (Ewald sphere based) DO NOT USE NORMALIZED COORDINATES HERE
        # kx^2 + ky^2 in physical units; "+ 0j" makes the array complex so the
        # sqrt below does not produce NaNs for kzsqr < 0
        self.kxysqr = (tf_helper.abssqr(
            tf_helper.xx((self.mysize[1], self.mysize[2]), 'freq') /
            self.dx) + tf_helper.abssqr(
                tf_helper.yy(
                    (self.mysize[1], self.mysize[2]), 'freq') / self.dy)) + 0j
        self.k0 = 1 / self.lambdaM
        self.kzsqr = tf_helper.abssqr(self.k0) - self.kxysqr
        self.kz = np.sqrt(self.kzsqr)
        # drop evanescent components (kz forced to 0 where kzsqr < 0)
        self.kz[self.kzsqr < 0] = 0
        self.dphi = 2 * np.pi * self.kz * self.dz
        # exp(1i*kz*dz) would be the propagator for one slice

        ## Get a list of vector coordinates corresponding to the pixels in the mask
        xfreq = tf_helper.xx((self.mysize[1], self.mysize[2]), 'freq')
        yfreq = tf_helper.yy((self.mysize[1], self.mysize[2]), 'freq')
        # every non-zero pixel of Ic is one illumination plane wave
        self.Nc = np.sum(self.Ic > 0)
        print('Number of Illumination Angles / Plane waves: ' + str(self.Nc))

        # Calculate the computational grid/sampling
        self.kxcoord = np.reshape(xfreq[self.Ic > 0], [1, 1, 1, self.Nc])
        # NA-positions in condenser aperture plane in x-direction
        self.kycoord = np.reshape(yfreq[self.Ic > 0], [1, 1, 1, self.Nc])
        # NA-positions in condenser aperture plane in y-direction
        self.RefrCos = np.reshape(self.k0 / self.kz[self.Ic > 0],
                                  [1, 1, 1, self.Nc])
        # 1/cosine used for the application of the refractive index steps to account for longer OPD in medium under an oblique illumination angle
        # NOTE(review): divides by zero if an illuminated pixel has kz == 0

        ## Generate the illumination amplitudes
        # amplitude of each plane wave = value of Ic at its aperture pixel
        self.intensityweights = self.Ic[self.Ic > 0]
        # One tilted plane wave per illumination angle, angle index on the last
        # axis (presumably (Nx, Ny, 1, Nc) -- depends on tf_helper.repmat4d)
        self.A_input = self.intensityweights * np.exp(
            (2 * np.pi * 1j) *
            (self.kxcoord *
             tf_helper.repmat4d(tf_helper.xx(
                 (self.mysize[1], self.mysize[2])), self.Nc) + self.kycoord *
             tf_helper.repmat4d(tf_helper.yy(
                 (self.mysize[1], self.mysize[2])), self.Nc))
        )  # Corresponds to a plane wave under many oblique illumination angles - bfxfun

        ## propagate field to z-stack and sum over all illumination angles
        # Alldphi[..., z] = -z * fftshift(dphi): the 2D grid with a trailing Nz axis
        self.Alldphi = -np.reshape(np.arange(
            0, self.mysize[0], 1), [1, 1, self.mysize[0]]) * np.repeat(
                np.fft.fftshift(self.dphi)[:, :, np.newaxis],
                self.mysize[0],
                axis=2)

        # Ordinary backpropagation. This is NOT what we are interested in:
        # exp(1j*Alldphi) masked to dphi > 0, transposed so Z becomes the first axis
        self.myAllSlicePropagator = np.transpose(
            np.exp(1j * self.Alldphi) *
            (np.repeat(np.fft.fftshift(self.dphi)[:, :, np.newaxis],
                       self.mysize[0],
                       axis=2) > 0), [2, 0, 1])
    def computemodel(self):
        """Build the TF graph of the multiple-scattering forward model.

        Multiple scattering is computed for all illumination angles (coming
        from the illumination NA) simultaneously:

        1.) The E-field is propagated slice by slice through the sample
            (refraction step, then diffraction step).
        2.) The exit field is filtered by the aberrated objective pupil.
        3.) Each angle's field is back-propagated to every Z-plane and the
            amplitudes are summed over all angles to form a focus stack.

        Requires ``computesys`` to have been run first. That call provides
        A_input, dphi, Nc, RefrCos, Po, myzernikes, zernikefactors, Ic, kz,
        myAllSlicePropagator and TF_obj_phase_do. This method also reads
        is_optimization_psf, Nz, Nx, Ny, dz, k0, mysize and is_padding.

        Returns:
            complex64 tensor ``self.TF_allSumAmp`` with Z as first axis,
            divided by the number of illumination angles ``self.Nc``. It is
            cropped to the central X/Y window when ``self.is_padding`` is
            set.
        """
        print("Buildup Q-PHASE Model ")
        ###### make sure, that the first dimension is "batch"-size; in this case it is the illumination number
        # @beniroquai It's common to have to batch dimensions first and not last.!!!!!
        # the following loop propagates the field sequentially to all different Z-planes

        ## propagate the field through the entire object for all angles simultaneously
        # Move the illumination-angle axis (last axis of A_input) to the front
        # so that TF treats it as the batch dimension.
        A_prop = np.transpose(self.A_input, [3, 0, 1, 2])

        # One-slice propagator; excludes the near field components (dphi <= 0)
        myprop = np.exp(1j * self.dphi) * (self.dphi > 0)
        myprop = tf_helper.repmat4d(myprop, self.Nc)
        myprop = np.transpose(myprop, [3, 0, 1, 2])  # angle axis first, as for A_prop

        # Precalculate the oblique effect on OPD to speed it up.
        # NOTE: RefrEffect already contains the imaginary unit.
        RefrEffect = 1j * self.dz * self.k0 * self.RefrCos
        RefrEffect = np.transpose(RefrEffect, [3, 0, 1, 2])

        # for now orientate the dimensions as (alpha_illu, x, y, z) - because tensorflow takes the first dimension as batch size
        with tf.name_scope('Variable_assignment'):
            self.TF_A_input = tf.constant(A_prop, dtype=tf.complex64)
            self.TF_RefrEffect = tf.reshape(
                tf.constant(RefrEffect, dtype=tf.complex64), [self.Nc, 1, 1])
            self.TF_myprop = tf.squeeze(tf.constant(myprop,
                                                    dtype=tf.complex64))
            self.TF_Po = tf.cast(tf.constant(self.Po), tf.complex64)
            self.TF_Zernikes = tf.constant(self.myzernikes, dtype=tf.float32)

            # Zernike coefficients are trainable only when optimising the PSF
            if (self.is_optimization_psf):
                self.TF_zernikefactors = tf.Variable(self.zernikefactors,
                                                     dtype=tf.float32,
                                                     name='var_zernikes')
            else:
                self.TF_zernikefactors = tf.constant(self.zernikefactors,
                                                     dtype=tf.float32,
                                                     name='const_zernikes')

        # TODO: Introduce the averaged RI along Z - MWeigert

        self.TF_A_prop = tf.squeeze(self.TF_A_input)
        self.U_z_list = []

        # Initialize memory
        self.allInt = 0
        self.allSumAmp = 0
        self.TF_allSumAmp = tf.zeros([self.mysize[0], self.Nx, self.Ny],
                                     dtype=tf.complex64)

        # NOTE(review): this counter is rebound to a doubling tensor inside the
        # loop and is not used by the model itself.
        self.tf_iterator = tf.Variable(1)
        # simulate multiple scattering through object
        with tf.name_scope('Fwd_Propagate'):
            for pz in range(0, self.mysize[0]):
                self.tf_iterator += self.tf_iterator
                with tf.name_scope('Refract'):
                    TF_f_phase = tf.cast(self.TF_obj_phase_do[pz, :, :],
                                         tf.complex64)
                    # The imaginary unit is already part of TF_RefrEffect.
                    # Multiplying by 1j again would make the exponent real and
                    # turn the phase delay into an amplitude attenuation.
                    self.TF_f = tf.exp(self.TF_RefrEffect * TF_f_phase)
                    self.TF_A_prop = self.TF_A_prop * self.TF_f  # refraction step

                with tf.name_scope('Propagate'):
                    self.TF_A_prop = tf_helper.my_ift2d(
                        tf_helper.my_ft2d(self.TF_A_prop) *
                        self.TF_myprop)  # diffraction step

        # Bring the slice back to focus - does this make any sense?!
        # NOTE(review): this multiplies the one-slice propagator by the scalar
        # -Nz/2 instead of applying it -Nz/2 times (exp(-1j*dphi*Nz/2));
        # TODO confirm the intended refocusing.
        print('----------> Bringing field back to focus')
        self.TF_A_prop = tf_helper.my_ift2d(
            tf_helper.my_ft2d(self.TF_A_prop) *
            (-self.Nz / 2 * self.TF_myprop))  # diffraction step

        # in a final step limit this to the detection NA, including the
        # Zernike aberration phase exp(1j * sum_k c_k * Z_k):
        self.TF_Po_aberr = tf.exp(1j * tf.cast(
            tf.reduce_sum(self.TF_zernikefactors * self.TF_Zernikes, axis=2),
            tf.complex64)) * self.TF_Po
        self.TF_A_prop = tf_helper.my_ift2d(
            tf_helper.my_ft2d(self.TF_A_prop) * self.TF_Po_aberr)

        self.TF_myAllSlicePropagator = tf.constant(self.myAllSlicePropagator,
                                                   dtype=tf.complex64)
        self.kzcoord = np.reshape(self.kz[self.Ic > 0], [1, 1, 1, self.Nc])

        # create Z-Stack by backpropagating Information in BFP to Z-Position
        # Centre voxel (z, x, y) used as the global-phase reference below.
        # Use the built-in int: the np.int alias was removed in NumPy 1.24.
        self.mid3D = ([
            int(self.mysize[0] // 2),
            int(self.A_input.shape[0] // 2),
            int(self.A_input.shape[1] // 2)
        ])

        with tf.name_scope('Back_Propagate'):
            for pillu in range(0, self.Nc):
                with tf.name_scope('Back_Propagate_Step'):
                    with tf.name_scope('Adjust'):
                        #    fprintf('BackpropaAngle no: #d\n',pillu);
                        OneAmp = tf.expand_dims(self.TF_A_prop[pillu, :, :], 0)

                        # Fancy backpropagation assuming what would be measured if the sample was moved under oblique illumination:
                        # The trick is: First use conceptually the normal way
                        # and then apply the XYZ shift using the Fourier shift theorem (corresponds to physically shifting the object volume, scattered field stays the same):
                        self.TF_AdjustKXY = tf.squeeze(
                            tf.conj(self.TF_A_input[pillu, :, :, ])
                        )  # tf.transpose(tf.conj(TF_A_input[pillu, :,:,]), [2, 1, 0]) # Maybe a bit of a dirty hack, but we first need to shift the zero coordinate to the center
                        self.TF_AdjustKZ = tf.transpose(
                            tf.constant(np.exp(
                                2 * np.pi * 1j * self.dz *
                                np.reshape(np.arange(0, self.mysize[0], 1),
                                           [1, 1, self.mysize[0]]) *
                                self.kzcoord[:, :, :, pillu]),
                                        dtype=tf.complex64), [2, 1, 0])
                        self.TF_allAmp = tf_helper.my_ift2d(
                            tf_helper.my_ft2d(OneAmp) *
                            self.TF_myAllSlicePropagator
                        ) * self.TF_AdjustKZ * self.TF_AdjustKXY  # * (TF_AdjustKZ);  # 2x bfxfun.  Propagates a single amplitude pattern back to the whole stack
                        # NOTE(review): multiplying by exp(+1j*angle(centre))
                        # doubles rather than removes the centre phase; the
                        # disabled variant below uses -1j -- confirm the sign.
                        self.TF_allAmp = self.TF_allAmp * tf.exp(
                            1j * tf.cast(
                                tf.angle(self.TF_allAmp[
                                    self.mid3D[0], self.mid3D[1],
                                    self.mid3D[2]]), tf.complex64)
                        )  # Global Phases need to be adjusted at this step!  Use the zero frequency

                    if (0):
                        with tf.name_scope('Propagate'):
                            self.TF_allAmp_3dft = tf.fft3d(
                                tf.expand_dims(self.TF_allAmp, axis=0))
                            self.TF_allAmp = self.TF_allAmp * tf.exp(
                                -1j * tf.cast(
                                    tf.angle(self.TF_allAmp_3dft[
                                        self.mid3D[2], self.mid3D[1],
                                        self.mid3D[0]]), tf.complex64))
                            # Global Phases need to be adjusted at this step!  Use the zero frequency

                    with tf.name_scope(
                            'Sum_Amps'
                    ):  # Normalize amplitude by condenser intensity
                        self.TF_allSumAmp = self.TF_allSumAmp + self.TF_allAmp  #/ self.intensityweights[pillu];  # Superpose the Amplitudes

        # Normalize the image such that the values do not depend on the fineness of
        # the source grid.
        self.TF_allSumAmp = self.TF_allSumAmp / self.Nc  #/tf.cast(tf.reduce_max(tf.abs(self.TF_allSumAmp)), tf.complex64)

        # Following is the normalization according to Martin's book. It ensures
        # that a transparent specimen is imaged with unit intensity.
        # normfactor=abs(Po).^2.*abs(Ic); We do not use it, because it leads to
        # divide by zero for dark-field system. Instead, through normalizations
        # performed above, we ensure that image of a point under matched
        # illumination is unity. The brightness of all the other configurations is
        # relative to this benchmark.

        # negate padding: keep only the central X/Y window
        if self.is_padding:
            self.TF_allSumAmp = self.TF_allSumAmp[:, self.Nx // 2 -
                                                  self.Nx // 4:self.Nx // 2 +
                                                  self.Nx // 4, self.Ny // 2 -
                                                  self.Ny // 4:self.Ny // 2 +
                                                  self.Ny // 4]

        return self.TF_allSumAmp
    def computemodel(self, is_resnet=False, is_forcepos=False):
        '''Build the TensorFlow graph of the multiple-scattering forward model.

        Steps:
        1.) Multiple scattering is performed by slice-wise propagation of the
            E-field through the sample (refraction step, then diffraction step,
            for every z-slice).
        2.) Each exit field is filtered by the detection pupil (including the
            Zernike aberration phase).
        3.) A focus stack is created per illumination angle by propagating the
            filtered field to every z-position; the complex amplitudes of all
            angles are then summed.

        All illumination angles (coming from the illumination NA) are built
        into the same graph.

        Args:
            is_resnet: if True, the real and imaginary object parts are passed
                through ``self.residual_block`` before propagation.
            is_forcepos: if True, both object parts are wrapped in
                ``tf_reg.PreMonotonicPos``.

        Returns:
            ``self.TF_allSumAmp``, a complex64 tensor of shape (Nz, Nx, Ny).
            It is center-cropped to half size in X/Y when ``self.is_padding``
            is set.

            When ``self.is_tomo`` is set, the method instead returns early with
            the pupil-filtered exit fields ``self.TF_A_prop``, phase-shifted by
            ``Nz / 2`` propagation steps.
        '''
        self.is_resnet = is_resnet
        self.is_forcepos = is_forcepos

        print("Buildup Q-PHASE Model ")
        # The illumination index is used as the leading ("batch") dimension
        # throughout.

        ## propagate the field through the entire object for all angles simultaneously
        # Move the last axis of the 4D input field (illumination index) to the
        # front so that it becomes the leading dimension.
        A_prop = np.transpose(self.A_input, [3, 0, 1, 2])

        if (self.is_tomo):
            print('Experimentally using the tomographic scheme!')
            A_prop = np.conj(A_prop)

        # Per-slice propagator; components with dphi <= 0 are zeroed, which
        # excludes the near-field (non-propagating) components in each step.
        myprop = np.exp(1j * self.dphi) * (self.dphi > 0)
        myprop = tf_helper.repmat4d(np.fft.fftshift(myprop), self.Nc)
        # Same reordering as for A_prop: illumination index first.
        myprop = np.transpose(myprop, [3, 0, 1, 2])

        print('--------> ATTENTION: I added a pi factor - is this correct?!')
        # Precalculate the oblique effect on OPD to speed it up
        # (one complex factor per illumination angle).
        RefrEffect = np.squeeze(1j * np.pi * self.dz * self.k0 * self.RefrCos)

        # Dimensions are oriented as (alpha_illu, x, y[, z]) because TensorFlow
        # treats the first dimension as batch size.
        with tf.name_scope('Variable_assignment'):
            self.TF_A_input = tf.constant(A_prop, dtype=tf.complex64)
            self.TF_RefrEffect = tf.reshape(
                tf.constant(RefrEffect, dtype=tf.complex64), [self.Nc, 1, 1])
            self.TF_myprop = tf.constant(np.squeeze(myprop),
                                         dtype=tf.complex64)
            self.TF_Po = tf.cast(tf.constant(self.Po), tf.complex64)
            self.TF_Zernikes = tf.constant(self.myzernikes, dtype=tf.float32)
            self.TF_myAllSlicePropagator = tf.constant(
                self.myAllSlicePropagator, dtype=tf.complex64)

            # Only update those factors which are really necessary (e.g.
            # defocus is not very likely!).
            self.TF_zernikefactors = tf.Variable(self.zernikefactors,
                                                 dtype=tf.float32,
                                                 name='var_zernikes')
            # Positions where zernikemask > 0 are the trainable coefficients.
            indexes = tf.cast(tf.where(tf.constant(self.zernikemask) > 0),
                              tf.int32)
            updates = tf.gather_nd(self.TF_zernikefactors, indexes)
            # part_X holds only the masked entries (zeros elsewhere).
            part_X = tf.scatter_nd(indexes, updates,
                                   tf.shape(self.TF_zernikefactors))
            # Straight-through trick: the forward value equals
            # TF_zernikefactors, but gradients flow only through part_X,
            # i.e. only to the masked coefficients.
            self.TF_zernikefactors_filtered = part_X + tf.stop_gradient(
                -part_X + self.TF_zernikefactors)

        # TODO: Introduce the averaged RI along Z - MWeigert
        self.TF_A_prop = tf.squeeze(self.TF_A_input)
        self.U_z_list = []

        # Initialize memory
        self.allSumAmp = 0
        self.TF_allSumAmp = tf.zeros([self.mysize[0], self.Nx, self.Ny],
                                     dtype=tf.complex64)
        # Optionally add a RESNET layer between RI and microscope to correct
        # for model discrepancy.
        if (self.is_resnet):
            with tf.variable_scope('res_real', reuse=False):
                TF_real_3D = self.residual_block(
                    tf.expand_dims(tf.expand_dims(self.TF_obj, 3), 0), 1, True)
                TF_real_3D = tf.squeeze(TF_real_3D)
            with tf.variable_scope('res_imag', reuse=False):
                TF_imag_3D = self.residual_block(
                    tf.expand_dims(tf.expand_dims(self.TF_obj_absorption, 3),
                                   0), 1, True)
                TF_imag_3D = tf.squeeze(TF_imag_3D)
        else:
            # NOTE(review): self.TF_obj / self.TF_obj_absorption are not set
            # in this method; presumably created by an earlier setup call.
            TF_real_3D = self.TF_obj
            TF_imag_3D = self.TF_obj_absorption

        # Wrapper for force-positivity on the RI instead of penalizing it.
        if (self.is_forcepos):
            print('----> ATTENTION: We add the PreMonotonicPos')
            TF_real_3D = tf_reg.PreMonotonicPos(TF_real_3D)
            TF_imag_3D = tf_reg.PreMonotonicPos(TF_imag_3D)

        TF_A_prop_illu_list = []
        # Simulate multiple scattering through the object.
        with tf.name_scope('Fwd_Propagate'):
            # First iterate over all illumination angles
            for pillu in range(0, self.Nc):
                # Second iterate over all z-slices
                TF_A_prop_z_tmp = self.TF_A_prop[pillu, :, :]
                for pz in range(0, self.mysize[0]):
                    # NOTE(review): index -pz visits slice 0 first, then
                    # N-1, N-2, ..., 1 (since -0 == 0). Confirm this order is
                    # intended; a plain reversal would be -pz - 1.
                    if (self.is_padding):
                        # NOTE(review): self.mysize_old is not assigned in the
                        # visible code; the padding path depends on it being
                        # set elsewhere.
                        tf_paddings = tf.constant([[
                            self.mysize_old[1] // 2, self.mysize_old[1] // 2
                        ], [self.mysize_old[2] // 2, self.mysize_old[2] // 2]])
                        TF_real = tf.pad(TF_real_3D[-pz, :, :],
                                         tf_paddings,
                                         mode='CONSTANT',
                                         name='TF_obj_real_pad')
                        TF_imag = tf.pad(TF_imag_3D[-pz, :, :],
                                         tf_paddings,
                                         mode='CONSTANT',
                                         name='TF_obj_imag_pad')
                    else:
                        TF_real = (TF_real_3D[-pz, :, :])
                        TF_imag = (TF_imag_3D[-pz, :, :])

                    # Complex transmission of this slice for this angle.
                    self.TF_f = tf.exp(self.TF_RefrEffect[pillu, :, :] *
                                       tf.complex(TF_real, TF_imag))
                    with tf.name_scope('Refract'):
                        # beware the "i" is in TF_RefrEffect already!
                        TF_A_prop_z_tmp = TF_A_prop_z_tmp * self.TF_f  # refraction step

                    with tf.name_scope('Propagate'):
                        TF_A_prop_z_tmp = tf.ifft2d(
                            tf.fft2d(TF_A_prop_z_tmp) *
                            self.TF_myprop[pillu, :, :])  # diffraction step

                TF_A_prop_illu_list.append(TF_A_prop_z_tmp)
            self.TF_A_prop = tf.stack(TF_A_prop_illu_list)

        # In a final step limit this to the detection NA; the Zernike
        # coefficients (weighted sum over axis 2) add an aberration phase.
        self.TF_Po_aberr = tf.exp(1j * tf.cast(
            tf.reduce_sum(self.TF_zernikefactors_filtered * self.TF_Zernikes,
                          axis=2), tf.complex64)) * self.TF_Po
        # NOTE(review): TF_Po_aberr already contains TF_Po, so the pupil is
        # applied twice here (harmless only if Po is a 0/1 mask) - confirm.
        self.TF_A_prop = tf.ifft2d(
            tf.fft2d(self.TF_A_prop) * self.TF_Po * self.TF_Po_aberr)

        # Experimenting with pseudo tomographic data?
        if self.is_tomo:
            print('Only Experimental! Tomographic data?!')
            # Bring the slice back to focus: undo Nz/2 propagation steps.
            print('----------> Bringing field back to focus')
            TF_centerprop = tf.exp(
                -1j *
                tf.cast(self.Nz / 2 * tf.angle(self.TF_myprop), tf.complex64))
            self.TF_A_prop = tf.ifft2d(
                tf.fft2d(self.TF_A_prop) * TF_centerprop)  # diffraction step
            return self.TF_A_prop

        # Create Z-stack by backpropagating information in BFP to Z-position.
        self.kzcoord = np.reshape(self.kz[self.Ic > 0], [1, 1, 1, self.Nc])

        # Center indices (z, x, y). Builtin int is used because the np.int
        # alias was removed in NumPy 1.24.
        self.mid3D = ([
            int(self.mysize[0] // 2),
            int(self.A_input.shape[0] // 2),
            int(self.A_input.shape[1] // 2)
        ])

        with tf.name_scope('Back_Propagate'):
            print('----------> ATTENTION: PHASE SCRAMBLING!')
            for pillu in range(0, self.Nc):
                TF_allAmp_list = []
                OneAmp = tf.expand_dims(self.TF_A_prop[pillu, :, :], 0)
                # Fancy backpropagation assuming what would be measured if the
                # sample was moved under oblique illumination. The trick is:
                # first use conceptually the normal way, then apply the XYZ
                # shift using the Fourier shift theorem (corresponds to
                # physically shifting the object volume; the scattered field
                # stays the same).
                # Maybe a bit of a dirty hack, but we first need to shift the
                # zero coordinate to the center.
                self.TF_AdjustKXY = tf.squeeze(
                    tf.conj(self.TF_A_input[pillu, :, :, ]))
                # Axial phase ramp exp(2*pi*i*dz*z*kz) for z = 0..Nz-1, moved
                # to (z, 1, 1) so it broadcasts over X/Y. We want to start from
                # the first Z-slice, then go to the last, which faces the
                # objective lens.
                self.TF_AdjustKZ = tf.cast(
                    tf.transpose(
                        np.exp(2 * np.pi * 1j * self.dz * np.reshape(
                            np.arange(0, self.mysize[0], 1),
                            [1, 1, self.mysize[0]]) *
                               self.kzcoord[:, :, :, pillu]),
                        [2, 1, 0]),
                    tf.complex64)

                for pz in range(0, self.Nz):
                    with tf.name_scope('Back_Propagate_Step'):
                        # Propagates a single amplitude pattern back to the
                        # whole stack.
                        TF_allAmp_list.append(
                            tf.squeeze(
                                tf.ifft2d(
                                    tf.fft2d(OneAmp) *
                                    self.TF_myAllSlicePropagator[pz, :, :])) *
                            self.TF_AdjustKZ[pz, :, :] *
                            self.TF_AdjustKXY[:, :])

                TF_allAmp = tf.stack(TF_allAmp_list)

                with tf.name_scope('Adjust_Global_Phase'):
                    # Global phases need to be adjusted at this step: remove
                    # the phase of the zero-frequency (DC) term of the 3D FFT.
                    self.TF_allAmp_3dft = tf.fft3d(
                        tf.expand_dims(TF_allAmp, axis=0))
                    tf_global_phase = tf.angle(self.TF_allAmp_3dft[0, 0, 0, 0])
                    tf_global_phase = tf.cast(tf_global_phase, tf.complex64)

                    TF_allAmp = TF_allAmp * tf.exp(-1j * tf_global_phase)

                with tf.name_scope('Sum_Amps'):
                    # Superpose the amplitudes of all illumination angles.
                    self.TF_allSumAmp += TF_allAmp

        # Normalize the image such that the values do not depend on the
        # fineness of the source grid.
        self.TF_allSumAmp = self.TF_allSumAmp / self.Nc
        # Martin's book normalizes so that a transparent specimen is imaged
        # with unit intensity (normfactor=abs(Po).^2.*abs(Ic)). We do not use
        # it, because it leads to a divide by zero for dark-field systems.
        # Instead, the normalizations performed above ensure that the image of
        # a point under matched illumination is unity; the brightness of all
        # other configurations is relative to this benchmark.

        # Negate padding: crop the central half in X/Y.
        if self.is_padding:
            self.TF_allSumAmp = self.TF_allSumAmp[:, self.Nx // 2 -
                                                  self.Nx // 4:self.Nx // 2 +
                                                  self.Nx // 4, self.Ny // 2 -
                                                  self.Ny // 4:self.Ny // 2 +
                                                  self.Ny // 4]

        return self.TF_allSumAmp