Example #1
    def tf_FSPAS_FFT(self, field, wlength, z, dx, dy, ridx, theta_max):
    
        """
        Angular Spectrum Propagation of Coherent Wave Fields
        with optional filtering
        
        INPUTS :
            field : wave-field in the space domain
            wlength : wavelength(s) of the optical wave (multiple wavelengths supported)
            z : propagation distance
            dx, dy : sampling intervals in space
            ridx : refractive index of the propagation medium
            theta_max : optional bandwidth limit in DEGREES
                (if no filtering is desired, only the evanescent waves are filtered)
        
        OUTPUT :
            output : propagated wave-field in the space domain
            scattered_pwr : power removed by the band-limiting filter
        
        """
        
        B, C, M, N = self.B, self.C, self.M, self.N
        output_shape = tf.stack([B, C, tf.to_int32(M),tf.to_int32(N)])
        wlengtheff = wlength/ridx
        dfx = 1/dx/M
        dfy = 1/dy/N
        fx = tf.expand_dims((tf.range(M,dtype=tf.float32)-(M)/2)*dfx,-1)
        fy = tf.expand_dims((tf.range(N,dtype=tf.float32)-(N)/2)*dfy,0)
        fx2 = tf.matmul(fx**2,tf.ones((1,N),dtype=tf.float32))
        fy2 = tf.matmul(tf.ones((M,1),dtype=tf.float32),fy**2)
        
        wlengtheff = tf.expand_dims(tf.expand_dims(wlengtheff,-1),-1)
        # Diffraction limit
        f0 = 1.0/wlengtheff
        Qd = tf.to_float(tf.less(tf.expand_dims((fx2+fy2),0),(f0**2)))

        # Prop Anti-aliasing
        Qbw = self.adaptive_bandlimit(fx,fy,z,dfx,dfy,wlengtheff,theta_max) #!!! CHECK THIS

        Q = Qd*Qbw
            
        W = Q*tf.expand_dims((fx2+fy2),0)*(wlengtheff**2)
        phase_term_sqrt = (tf.ones((C,tf.to_int32(M),tf.to_int32(N)))-W)**(0.5)
        Hphase = 2*np.pi/wlengtheff*z*phase_term_sqrt
        HFSP = tf.complex(Q*tf.cos(Hphase),Q*tf.sin(Hphase))
        ASpectrum = tf.fft2d(field)
        ASpectrum = self.tf_fft_shift_2d(ASpectrum)    
        ASpectrum_z = self.tf_ifft_shift_2d(tf.multiply(HFSP,ASpectrum))
        output = tf.ifft2d(ASpectrum_z)
        
        Qd = tf.to_complex64(Qd)
        Q = tf.to_complex64(Q)
        unitary_prop_cnst = tf.divide(
            tf.reduce_sum(tf.square(tf.abs(ASpectrum)), axis=[2, 3], keepdims=True),
            tf.reduce_sum(tf.square(tf.abs(ASpectrum * Qd)), axis=[2, 3], keepdims=True))
        unitary_cnst = tf.complex(tf.sqrt(unitary_prop_cnst),tf.zeros_like(unitary_prop_cnst))
        output = tf.multiply(output,unitary_cnst)
        output = tf.slice(output, tf.stack([0, 0, 0, 0]), output_shape)
        scattered_pwr = tf.reduce_sum(
            tf.square(tf.abs(unitary_cnst * ASpectrum * Qd * (Qd - Q))),
            axis=[2, 3], keepdims=True) / M / N
        return output, scattered_pwr
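
The adaptive_bandlimit helper is not shown. A minimal sketch of one common choice, the band-limited angular spectrum criterion of Matsushima & Shimobaba (Opt. Express 2009) combined with the explicit angular cutoff theta_max; the standalone function name and signature here are assumptions inferred from the call above:

import numpy as np
import tensorflow as tf  # TF 1.x, matching the snippet above

def adaptive_bandlimit(fx, fy, z, dfx, dfy, wlengtheff, theta_max):
    # Band-limited ASM cutoffs: frequencies beyond these alias for distance z.
    fx_limit = 1.0 / (tf.sqrt((2.0 * dfx * z)**2 + 1.0) * wlengtheff)
    fy_limit = 1.0 / (tf.sqrt((2.0 * dfy * z)**2 + 1.0) * wlengtheff)
    # Hard angular cutoff in degrees: f = sin(theta) / lambda.
    f_theta = tf.sin(theta_max * np.pi / 180.0) / wlengtheff
    fx_max = tf.minimum(fx_limit, f_theta)
    fy_max = tf.minimum(fy_limit, f_theta)
    # Separable rectangular pass-band, broadcast over the wavelength axis.
    Qx = tf.to_float(tf.less(tf.abs(tf.expand_dims(fx, 0)), fx_max))
    Qy = tf.to_float(tf.less(tf.abs(tf.expand_dims(fy, 0)), fy_max))
    return Qx * Qy  # shape [C, M, N] after broadcasting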
Example #2
def tf_SE3_to_se3(SE3):
    '''
    input : SE3 [-1, 4, 4]
    output: se3 [-1,6]
    '''
    # note: tf.linalg.logm() has no registered gradient, so it blocks backprop

    R = SE3[:, 0:3, 0:3]  # [N,3,3]
    T = SE3[:, 0:3, 3]  # [N, 3]
    theta = tf.expand_dims(tf.acos((tf.trace(R) - 1) / 2), 1)  # rotation angle
    theta = tf.expand_dims(theta, 2) + 1e-5  # avoid division by zero below

    R = tf.to_complex64(R)
    w_hat = tf.linalg.logm(R)  # [N,3,3]
    w_hat = tf.to_float(w_hat)
    w = tf_vee(w_hat)  # [N,3]

    # v = V^{-1} T, with V^{-1} = I - w_hat/2
    #     + (1/theta^2) * (1 - theta*sin(theta) / (2*(1 - cos(theta)))) * w_hat^2
    v = tf.matmul(
        tf.eye(3) - (1 / 2 * (w_hat)) +
        (((1 / (theta * theta)) *
          (1 - ((theta * tf.sin(theta)) /
                (2 * (1 - tf.cos(theta)))))) * tf.matmul(w_hat, w_hat)),
        tf.expand_dims(T, 2))
    v = v[:, :, 0]
    se3 = tf.concat([v, w], axis=1)
    return se3
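
tf_vee is referenced but not shown; a minimal sketch of the standard vee (inverse hat) operator, consistent with the usage above:

import tensorflow as tf

def tf_vee(w_hat):
    '''
    input : skew-symmetric w_hat [-1, 3, 3], where
            w_hat = [[  0, -w3,  w2],
                     [ w3,   0, -w1],
                     [-w2,  w1,   0]]
    output: w [-1, 3]
    '''
    return tf.stack([w_hat[:, 2, 1], w_hat[:, 0, 2], w_hat[:, 1, 0]], axis=1)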
Example #3
 def to_complex(self, x):
     if self.dtype(x) in (np.complex64, np.complex128):
         return x
     if self.dtype(x) == np.float64:
         return tf.to_complex128(x)
     else:
         return tf.to_complex64(x)
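
The self.dtype helper is not shown; a plausible sketch (an assumption) that returns the element type as a NumPy type class for both arrays and tensors:

import numpy as np
import tensorflow as tf

 def dtype(self, x):
     # return the element type as a NumPy type class, e.g. np.float64
     if isinstance(x, tf.Tensor):
         return x.dtype.as_numpy_dtype
     return np.asarray(x).dtype.type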
Example #4
    def call(self, x):
        """
        This method must be defined for any custom layer, it is where the calculations are done.   
        
        x: a tensor representing the inputs to the layer. This is passed automatically by tensorflow. 
        """

        # add zero voltage parameters to the input
        H = tf.to_complex64(x + self.H0)

        # retrieve the complex operator
        complex_operator = tf.constant(self.complex_operator,
                                       dtype=tf.complex64)

        # add two extra dimensions for batch and time
        complex_operator = tf.expand_dims(complex_operator, 0)
        complex_operator = tf.expand_dims(complex_operator, 0)

        # construct a tensor in the form of a row vector whose elements are [d1,d2,1,1], where d1 and d2 correspond to the
        # number of examples and number of time steps of the input
        temp_shape = tf.concat(
            [tf.shape(x)[0:2],
             tf.constant(np.array([1, 1], dtype=np.int32))], 0)

        # repeat the operator along the batch and time dimensions
        complex_operator = tf.tile(complex_operator, temp_shape)

        # apply the complex operator to convert the lower triangular part into pure imaginary
        H = tf.multiply(H, complex_operator)

        # convert to a Hermitian matrix via H + H^dagger (permute indices 3 and 2 with conjugation)
        H = tf.add(H, tf.transpose(H, [0, 1, 3, 2], conjugate=True))

        return H
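
The complex_operator attribute is built elsewhere; a minimal NumPy sketch (an assumption consistent with the comments above) that leaves the upper triangle and diagonal real and makes the strict lower triangle imaginary, so that H + H^dagger is Hermitian:

import numpy as np

n = 4  # hypothetical matrix dimension
complex_operator = (np.triu(np.ones((n, n)))               # real upper triangle + diagonal
                    + 1j * np.tril(np.ones((n, n)), k=-1)  # imaginary strict lower triangle
                    ).astype(np.complex64)
# With H = x * complex_operator, entry (i, j) of H + H^dagger for i > j is
# x[j, i] + 1j * x[i, j], and entry (j, i) is its conjugate: Hermitian by construction.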
Example #5
import numpy as np
import scipy.signal
import tensorflow as tf

def stft_tf(wav, win_length, hop_length, n_fft, window='hann', mode='REFLECT'):
    '''
    Implements STFT in TensorFlow.
    The output matches librosa.stft with center=True to within ~1e-6.
    link: https://github.com/zhang-wy15/stft_from_librosa_to_tensorflow
    '''
    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    window = scipy.signal.get_window(window, win_length, fftbins=True)

    # Pad the window out to n_fft size
    window = np.pad(window,
                    ((n_fft - win_length) // 2, (n_fft - win_length) // 2),
                    mode='constant',
                    constant_values=(0, 0))

    # Reshape so that the window can be broadcast
    # We don't need this
    # window = window.reshape((-1,1))

    # Pad the time series so that frames are centered
    center = True
    if center:
        wav = tf.pad(wav, [[n_fft // 2, n_fft // 2]], mode=mode)

    # Window the time series.
    f = tf.contrib.signal.frame(wav, n_fft, hop_length, pad_end=False)

    # fft method 1: divide into blocks and calculate the fft separately
    # fft method 2: feed whole frames to tf.spectral.fft
    # results are the same, but method 2 is faster

    # method 1:
    '''
    linear = tf.zeros((f.shape[0],int(1 + n_fft // 2)))
    MAX_MEM_BLOCK = 2**8 * 2**10
    itemsize = 8
    n_columns = int(MAX_MEM_BLOCK / (int(1 + n_fft // 2) * itemsize))
    for bl_s in range(0, linear.shape[0], n_columns):
        bl_t = min(bl_s + n_columns, linear.shape[0])
        temp = tf.spectral.fft(tf.to_complex64(f[bl_s:bl_t,:] * window))[:,:linear.shape[1]]
        print(temp)
        if not bl_s:
            linear_spect = temp
        else:
            linear_spect = tf.concat([linear_spect, temp],axis=0)
    '''

    # method 2:
    linear = tf.spectral.fft(tf.to_complex64(f * window))[:, :int(1 +
                                                                  n_fft // 2)]

    return linear
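
A hedged usage sketch (TF 1.x with tf.contrib; librosa only for the comparison noted in the docstring):

import numpy as np
import tensorflow as tf

wav_np = np.random.randn(22050).astype(np.float32)
spec = stft_tf(tf.constant(wav_np), win_length=1024, hop_length=256, n_fft=1024)
with tf.Session() as sess:
    spec_np = sess.run(spec)  # shape: (num_frames, 1 + n_fft // 2)
# librosa.stft(wav_np, n_fft=1024, hop_length=256).T should agree with
# spec_np to within ~1e-6, per the docstring above.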
Example #6
def tf_ifft2c(kspace):
    # centered, orthonormal 2-D inverse FFT: shift -> ifft2d -> shift, scaled by sqrt(M*N)
    shp=tf.shape(kspace)
    scale=tf.sqrt(tf.to_float(shp[-2]*shp[-1]))
    scale=tf.to_complex64(scale)
    shifted=tf_shift2d(kspace)
    xhat=tf.spectral.ifft2d(shifted)*scale
    centered=tf_shift2d(xhat)
    return centered
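
tf_shift2d is not shown; a minimal sketch assuming even-sized trailing dimensions (where fftshift and ifftshift coincide, so one helper serves both directions):

import tensorflow as tf

def tf_shift2d(x):
    # fftshift over the last two axes: roll each axis by half its length
    shp = tf.shape(x)
    x = tf.manip.roll(x, shift=shp[-2] // 2, axis=-2)
    x = tf.manip.roll(x, shift=shp[-1] // 2, axis=-1)
    return x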
Example #7
    def call(self, x):
        """
        This method must be defined for any custom layer, it is where the calculations are done.   
        
        x: The tensor representing the input to the layer. This is passed automatically by tensorflow. 
        """
        # make sure the datatype is complex64, otherwise training will not work
        Hamiltonian = tf.to_complex64(x)

        # evaluate -i*H*l
        Hamiltonian = Hamiltonian * self.length

        # evaluate U = expm(-i*H*l)
        U = tf.linalg.expm(Hamiltonian)

        # add an extra dimension to the tensor representing the initial state, to represent time
        psi_0 = tf.expand_dims(self.initial_state, 0)

        # add another dimension to represent batch
        psi_0 = tf.expand_dims(psi_0, 0)

        # construct a tensor in the form of a row vector whose elements are [d1,d2,1,1], where d1 and d2 correspond to the
        # number of examples and number of time steps of the input
        temp_shape = tf.concat(
            [tf.shape(x)[0:2],
             tf.constant(np.array([1, 1], dtype=np.int32))], 0)

        # repeat the initial ket column along the batch and time dimensions, and convert to complex64 datatype
        psi_0 = tf.tile(psi_0, temp_shape)
        psi_0 = tf.to_complex64(psi_0)

        # evaluate U \psi_0
        prob = tf.matmul(U, psi_0)

        # remove the last dimension since we have a column rather than a matrix
        prob = tf.squeeze(prob, -1)

        # calculate the probability (squared amplitude) for each entry
        prob = tf.square(tf.abs(prob))
        return prob
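
A minimal sanity check of the core step, assuming self.length already carries the -1j factor so that expm of the result is unitary:

import tensorflow as tf

H = tf.constant([[1.0, 0.5], [0.5, -1.0]])       # toy 2x2 Hermitian Hamiltonian
length = tf.constant(-0.3j, dtype=tf.complex64)  # -i * l, as assumed above
U = tf.linalg.expm(tf.to_complex64(H) * length)
identity_check = tf.matmul(U, U, adjoint_b=True)  # should be ~ the 2x2 identity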
Example #8
 def _after_czs(self, v: tf.Tensor, pairs: tf.Tensor) -> tf.Tensor:
     iota = tf.range(self.grouping.system_size())
     t = tf.constant(0, dtype=tf.int32)
     for k in range(pairs.shape[0]):
         i = pairs[k, 0]
         j = pairs[k, 1]
         # bitmask with bits i and j set: selects basis states where both qubits are 1
         index_mask = tf.bitwise.bitwise_or(tf.bitwise.left_shift(1, i),
                                            tf.bitwise.left_shift(1, j))
         # i == -1 marks a padding pair; a mask of -1 makes kept_iota all-false below
         index_mask = tf.cond(tf.math.equal(i, -1), lambda: -1,
                              lambda: index_mask)
         masked_iota = tf.bitwise.bitwise_and(iota, index_mask)
         kept_iota = tf.math.equal(index_mask, masked_iota)
         # accumulate the parity of CZ sign flips per basis state
         t = tf.bitwise.bitwise_xor(t, tf.to_int32(kept_iota))
     # states with odd parity are negated: 1 - 2t is the accumulated CZ diagonal
     negations = 1 - tf.to_complex64(t) * 2
     v *= negations
     return v
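
What the bit arithmetic computes, in a small NumPy sketch (assuming system_size() is the number of basis amplitudes): for each pair (i, j), every basis state with both qubits set picks up a sign flip, i.e. the diagonal of a CZ gate:

import numpy as np

num_amplitudes = 4  # two qubits (a hypothetical size)
iota = np.arange(num_amplitudes)
i, j = 0, 1
mask = (1 << i) | (1 << j)
t = ((iota & mask) == mask).astype(np.int32)  # 1 where both qubits are 1
negations = 1 - 2 * t                         # [1, 1, 1, -1]: the CZ diagonal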
Example #9
    def generate_background(self):
        # illumination background;

        kx_illum, ky_illum, kz_illum = tf.split(self.k_illum_vectors,
                                                3,
                                                axis=0)  # num illum, num LEDs
        kx_illum = kx_illum[0]  # remove the split dimension
        ky_illum = ky_illum[0]
        kz_illum = kz_illum[0]

        # create mask that zeros out illuminations that miss the aperture:
        # reshape to _ by 1
        self.miss_aper_mask = tf.to_float(
            tf.less(kx_illum**2 + ky_illum**2,
                    (self.k_illum * self.NA)**2))[0, :, None]

        # if shifting bowls to force passage thru DC, modify illumination kxy:
        if self.force_pass_thru_DC:
            # kx_illum is num illum x num LEDs
            # DC_adjust is num_LEDs x 3
            kx_illum += (self.DC_adjust[None, :, 0]
                         ) * self.k_max[0] * 2 / self.side_k[0]
            ky_illum += (self.DC_adjust[None, :, 1]
                         ) * self.k_max[1] * 2 / self.side_k[1]
            kz_illum += (self.DC_adjust[None, :, 2]
                         ) * self.k_max[2] * 2 / self.side_k[2]

            # renormalize magnitude to k_illum:
            k_mag = tf.sqrt(kx_illum**2 + ky_illum**2 + kz_illum**2)
            kx_illum *= self.k_illum / k_mag
            ky_illum *= self.k_illum / k_mag

        # generate 2D phase ramp, for 0-reference fft:
        xy_samp = np.arange(self.xy_cap_n, dtype=np.float32)
        xy_samp -= np.ceil(self.xy_cap_n / 2)  # center
        xy_samp *= self.dxy_sample  # image coordinates
        x_samp, y_samp = tf.meshgrid(xy_samp, xy_samp)
        x_samp = tf.reshape(x_samp, [-1])
        y_samp = tf.reshape(y_samp, [-1])
        # shape: num illum, num LEDs, camx*camy:
        self.k_fft_shift = tf.exp(
            1j * 2 * np.pi *
            tf.to_complex64(x_samp[None, None, :] * kx_illum[:, :, None] +
                            y_samp[None, None, :] * ky_illum[:, :, None]))
        # squeeze, assuming one illumination for now:
        # this is actually already batched because derived from xyz_LED_batch:
        self.k_fft_shift = tf.squeeze(self.k_fft_shift)
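
The k_fft_shift ramp is an instance of the Fourier shift theorem; a one-dimensional NumPy sketch of the identity it relies on (coordinates left uncentered for brevity):

import numpy as np

n, dx = 128, 0.5
x = np.arange(n) * dx
kx = 3 / (n * dx)  # a shift of exactly 3 frequency bins
f = np.random.randn(n)
shifted = np.fft.fft(f * np.exp(1j * 2 * np.pi * kx * x))
reference = np.roll(np.fft.fft(f), 3)
# np.allclose(shifted, reference) -> True: the ramp shifts the spectrum by kx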
Example #10
    def call(self, x):
        """
        This method must be defined for any custom layer, it is where the calculations are done.   
        
        x: The tensor representing the input to the layer. This is passed automatically by tensorflow. 
        """
        # evaluate -i*H*l
        Hamiltonian = x * self.length

        #evaluate U =expm(-i*H*l)
        U = tf.linalg.expm(Hamiltonian)

        # add an extra dimension to the tensor representing the initial state, to represent time
        psi_0 = tf.expand_dims(self.initial_state, 0)

        # add another dimension to represent batch
        psi_0 = tf.expand_dims(psi_0, 0)

        # construct a tensor in the form of a row vector whose elements are [d1,d2,1,1], where d1 and d2 correspond to the
        # number of examples and number of time steps of the input
        temp_shape = tf.concat(
            [tf.shape(x)[0:2],
             tf.constant(np.array([1, 1], dtype=np.int32))], 0)

        # repeat the initial ket column along the batch and time dimensions, and convert to complex64 datatype
        psi_0 = tf.tile(psi_0, temp_shape)
        psi_0 = tf.to_complex64(psi_0)

        # evaluate U \psi_0
        psi_t = tf.squeeze(tf.matmul(U, psi_0), -1)

        # calculate the interferometer power distribution
        power_distribution = tf.square(tf.abs(0.5 * (1 + psi_t)))
        interferometer_distribution = tf.square(tf.abs(0.5 * (1j + psi_t)))

        # concatenate the two distributions along the last axis
        output = tf.concat([power_distribution, interferometer_distribution],
                           -1)

        return output
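
Why both distributions are returned: together with |psi|^2 they determine the phase of each amplitude. A scalar NumPy sketch:

import numpy as np

psi = 0.8 * np.exp(1j * 0.7)      # toy amplitude and phase
p1 = np.abs(0.5 * (1 + psi))**2   # power_distribution entry
p2 = np.abs(0.5 * (1j + psi))**2  # interferometer_distribution entry
a = np.abs(psi)
cos_phi = (4 * p1 - 1 - a**2) / (2 * a)
sin_phi = (4 * p2 - 1 - a**2) / (2 * a)
phase = np.arctan2(sin_phi, cos_phi)  # recovers 0.7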
Example #11
 def to_complex(self, x):
     return tf.to_complex64(x)
Example #12
    def reconstruct_with_born(self):
        # use intensity (no phase) data and try to reconstruct 3D index distribution;

        if self.optimize_k_directly:  # tf variables are k space
            self.initialize_k_space_domain()
        else:  # tf variables are space domain
            self.initialize_space_space_domain()

        # DT_recon is the scattering potential; then to get RI:
        self.RI = self.V_to_RI(self.DT_recon)

        # generate k-spherical caps:
        self.generate_cap()
        self.generate_apertures()
        self.subtract_illumination()

        # already batched, because derived from xyz_LED_batch:
        self.k_fft_shift_batch = self.k_fft_shift
        self.xyz_caps_batch = self.xyz_caps

        self.pupil_phase = tf.Variable(np.zeros(
            (self.xy_cap_n, self.xy_cap_n)),
                                       dtype=tf.float32,
                                       name='pupil_phase_function')
        pupil = tf.exp(1j * tf.to_complex64(self.pupil_phase))

        # error between prediction and data:
        k_space_T = tf.transpose(self.k_space, [1, 0, 2])
        forward_fourier = self.tf_gather_nd3(k_space_T, self.xyz_caps_batch)
        forward_fourier /= tf.complex(
            0., self.kz_cap[None]
        ) * 2  # prefactor; it's 1i*kz/pi, but my kz is not in angular frequency
        forward_fourier = tf.reshape(
            forward_fourier,  # so we can do ifft
            (
                -1,
                len(self.k_illum),  # self.batch_size
                self.xy_cap_n,
                self.xy_cap_n))
        # zero out fourier support outside aperture before fftshift:
        forward_fourier *= tf.complex(self.aperture_mask[None], 0.)
        if self.pupil_function:
            forward_fourier *= pupil
        self.forward_pred = self.tf_ifftshift2(
            tf.ifft2d(self.tf_fftshift2(forward_fourier)))
        # fft phase factor compensation:
        self.forward_pred *= tf.to_complex64(self.dxy**2 * self.dz /
                                             self.dxy_sample**2)
        self.forward_pred = tf.reshape(
            self.forward_pred,  # reflatten
            (-1, self.points_per_cap))  # self.batch_size
        self.field = tf.identity(
            self.forward_pred
        )  # to monitor the E field for diagnostic purposes
        unscattered = self.DC_batch * self.k_fft_shift_batch * tf.exp(
            1j * tf.to_complex64(self.illumination_phase_batch[:, None]))

        if self.zero_out_background_if_outside_aper:
            # to zero out background from illumination angles that miss the aperture
            self.miss_aper_mask_batch = tf.to_complex64(self.miss_aper_mask)
            self.forward_pred_field = self.DC_batch * self.forward_pred + unscattered * self.miss_aper_mask_batch
            self.forward_pred = tf.abs(self.forward_pred_field)
        else:
            self.forward_pred_field = self.DC_batch * self.forward_pred + unscattered
            self.forward_pred = tf.abs(self.forward_pred_field)

        self.generate_train_ops()
    def reconstruct_with_multislice(self):
        # only two parameterization options: direct index recon, or DIP index recon;

        assert self.force_pass_thru_DC is False  # bowls are not generated, so this can't be done
        assert self.optimize_k_directly is False  # we are not using k-spheres

        self.k_illum_vectors = self.xyz_LED_batch[:, None, :] * self.k_illum[
            None, :, None]
        self.generate_background(
        )  # generates the variables needed for the background illumination
        self.k_fft_shift_batch = tf.conj(self.k_fft_shift)

        self.initialize_space_space_domain()
        self.RI = self.DT_recon + self.n_back  # no reference to scattering potential

        if self.use_spatial_patching:
            self.spatial_patching()
            if self.use_deep_image_prior:
                # DT recon is already generated from the spatially cropped input to DIP
                DT_recon = self.DT_recon
            else:
                DT_recon = self.DT_recon_sbatch
        else:
            DT_recon = self.DT_recon

        # fresnel propagation kernel:
        # fix the squeezing in the future if using more than one color
        k0 = np.squeeze(self.k_vacuum)
        kn = np.squeeze(self.k_illum)
        self.generate_k_coordinates()
        kx = tf.to_complex64(tf.squeeze(self.kx_cap))
        ky = tf.to_complex64(tf.squeeze(self.ky_cap))
        self.k_2 = kx**2 + ky**2
        self.F = tf.exp(-1j * 2 * np.pi * self.k_2 * self.dz /
                        (kn + tf.sqrt(kn**2 - self.k_2)))
        self.F *= tf.squeeze(
            tf.to_complex64(self.evanescent_mask)
        )  # technically not needed, but due to numerical instabilities...
        self.F = self.tf_fftshift2(self.F)
        self.F = tf.to_complex64(
            self.F, name='fresnel_kernel')  # shape: xy_cap_n by xy_cap_n

        # shape: num caps by points per cap:
        self.illumination = self.DC_batch * self.k_fft_shift_batch  # called unscattered in reconstruct_with_born
        self.illumination = tf.reshape(self.illumination,
                                       [-1, self.xy_cap_n, self.xy_cap_n])

        # incorporate additional defocus factor to account for unknown focal position after propagating through sample;
        # 0 corresponds to the center of the sample; distance in um;
        # change the initial position of the beam so that after refocusing, the beam is at the center of the fov;
        self.focus = tf.Variable(self.focus_init,
                                 dtype=tf.float32,
                                 name='focal_position')

        # create apodizing Gaussian window:
        # use tf.contrib.image.translate rather than recompute for every LED to save time/memory:
        k_max_radius = 1 / 2 / self.dxy_sample  # max possible radius
        # compute shifts (using LED positions):
        x_shift = -(self.focus - self.sample_thickness /
                    2) * self.xyz_LED[0] / self.xyz_LED[2]
        y_shift = -(self.focus - self.sample_thickness /
                    2) * self.xyz_LED[1] / self.xyz_LED[2]
        self.xy_shift = tf.stack([x_shift, y_shift],
                                 axis=1) / self.dxy  # convert to pixel
        # centered, unshifted gaussian window
        gausswin0 = tf.exp(-tf.to_float(self.k_2) / 2 /
                           (k_max_radius * self.apod_frac)**2)
        gausswin = tf.tile(gausswin0[None], (self.num_caps, 1, 1))
        gausswin = tf.contrib.image.translate(gausswin[:, :, :, None],
                                              self.xy_shift, 'bilinear')
        self.gausswin = gausswin[:, :, :, 0]  # get rid of color channels
        self.gausswin_batch = tf.gather(self.gausswin, self.batch_inds)

        self.illumination *= tf.to_complex64(
            self.gausswin_batch)  # gaussian window

        # forward propagation:
        def propagate_1layer(field, t_i):
            # field: the input field;
            # t_i, the 2D object transmittance function at the current (ith) plane, referenced to background index;
            return tf.ifft2d(tf.fft2d(field) * self.F) * t_i

        dN = tf.transpose(DT_recon, [2, 0, 1])  # make z the leading dim
        t = tf.exp(1j * 2 * np.pi * k0 * dN *
                   self.dz)  # transmittance function
        self.propped = tf.scan(propagate_1layer,
                               initializer=self.illumination,
                               elems=t,
                               swap_memory=True)
        self.propped = tf.transpose(self.propped,
                                    [1, 2, 3, 0])  # num ill, x, y, z

        self.pupil_phase = tf.Variable(np.zeros(
            (self.xy_cap_n, self.xy_cap_n)),
                                       dtype=tf.float32,
                                       name='pupil_phase_function')
        pupil = tf.exp(1j * tf.to_complex64(self.pupil_phase))
        limiting_aperture = tf.squeeze(tf.to_complex64(self.aperture_mask))
        k_2 = self.k_2 * limiting_aperture  # to prevent values far away from origin from being too large
        self.F_to_focus = tf.exp(
            -1j * 2 * np.pi * k_2 *
            tf.to_complex64(-self.focus - self.sample_thickness / 2) /
            (kn + tf.sqrt(kn**2 - k_2)))
        # restrict to the experimental aperture
        self.F_to_focus *= limiting_aperture
        self.F_to_focus *= pupil  # to account for aberrations common to all
        self.F_to_focus = self.tf_fftshift2(self.F_to_focus)
        self.F_to_focus = tf.to_complex64(self.F_to_focus,
                                          name='fresnel_kernel_prop_to_focus')

        self.field = tf.ifft2d(
            tf.fft2d(self.propped[:, :, :, -1]) * self.F_to_focus[None])
        self.forward_pred = tf.abs(self.field)
        self.forward_pred = tf.reshape(self.forward_pred,
                                       [-1, self.xy_cap_n**2])

        self.data_batch *= tf.reshape(
            gausswin0,
            [-1])[None]  # since prediction is windowed, also window data
        self.generate_train_ops()
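The core of the method above is the scan over propagate_1layer; the same multislice recursion in a self-contained NumPy sketch (toy sizes, stand-in kernel):

import numpy as np

n, n_z = 64, 10  # toy grid: n x n pixels, n_z slices (assumptions)
field = np.ones((n, n), dtype=np.complex64)  # incident plane wave
F = np.exp(2j * np.pi * np.random.rand(n, n)).astype(np.complex64)  # stand-in Fresnel kernel
t = np.exp(1j * 0.01 * np.random.randn(n_z, n, n))  # slice transmittance functions
for t_i in t:  # the same recursion applied by tf.scan above
    field = np.fft.ifft2(np.fft.fft2(field) * F) * t_i
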
    def format_DT_data(self, stack, DC=None):
        # expects an input stack of shape: num aper, num LEDs, num illum, camx, camy;
        # do not take sqrt of the data -- that is done here;

        s = stack.shape
        assert self.num_apers == s[0]
        if not self.use_spatial_patching:
            # if using spatial patching, then s[3]=s[4]>xy_cap_n
            assert s[3] == s[4] == self.xy_cap_n
        else:
            assert s[3] == s[4] == self.xy_full_n

        self.num_caps = s[0] * s[1]  # number of spherical caps (aper*LED)
        self.points_per_cap = s[2] * self.xy_cap_n**2  # for every color

        self.data_stack = np.reshape(stack, (self.num_caps, s[3]**2))
        self.data_stack = np.sqrt(
            self.data_stack
        )  # so that we don't have to do this for each new batch

        # DC due to unscattered light, potentially different for every angle:
        if DC is None:
            # initialize from data
            DC = np.median(self.data_stack, 1)
        self.DC = tf.Variable(DC, dtype=np.float32, name='DC')
        self.illumination_phase = tf.Variable(tf.zeros(self.num_caps,
                                                       dtype=tf.float32),
                                              name='illumination_phase',
                                              trainable=False)

        self.generate_LED_positions_flat_array()

        if self.use_spatial_patching:
            # this implementation doesn't finish all the LEDs in one spatial crop before moving to another;

            # upper left hand corner of the crop to be made:
            self.spatial_batch_inds = tf.random_uniform(shape=(2, 1),
                                                        minval=0,
                                                        maxval=self.xy_full_n -
                                                        self.xy_cap_n,
                                                        dtype=tf.int32)
            # batch along LED dimension:
            self.dataset = (tf.data.Dataset.range(self.num_caps).shuffle(
                self.num_caps).batch(
                    self.batch_size).repeat(None).make_one_shot_iterator())
            self.batch_inds = self.dataset.get_next()
            # reshape so that we can crop:
            self.data_stack = self.data_stack.reshape(self.num_caps,
                                                      self.xy_full_n,
                                                      self.xy_full_n)
        else:
            # generate dataset for batching:
            self.dataset = tf.data.Dataset.from_tensor_slices(
                (self.data_stack, tf.range(self.num_caps)))
            if self.batch_size != self.num_caps:
                # if all examples are present, don't shuffle
                self.dataset = self.dataset.shuffle(self.num_caps)
            self.dataset = self.dataset.batch(self.batch_size)
            self.dataset = self.dataset.repeat(None)  # go forever
            self.batcher = self.dataset.make_one_shot_iterator()
            (self.data_batch, self.batch_inds) = self.batcher.get_next()

        if self.data_ignore is not None:
            keep_inds = tf.gather(~self.data_ignore, self.batch_inds)
            self.batch_inds = tf.boolean_mask(self.batch_inds, keep_inds)
            if not self.use_spatial_patching:
                # data batch is generated using data_inds for spatial patching
                self.data_batch = tf.boolean_mask(self.data_batch, keep_inds)

        self.DC_batch = tf.gather(self.DC, self.batch_inds)
        self.DC_batch = tf.to_complex64(self.DC_batch[:, None])
        self.illumination_phase_batch = tf.gather(self.illumination_phase,
                                                  self.batch_inds)
        self.xyz_LED_batch = tf.transpose(  # transpose because first dim is 3 for xyz
            tf.gather(tf.transpose(self.xyz_LED), self.batch_inds))
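
The batching pattern used above, reduced to a self-contained TF 1.x sketch:

import numpy as np
import tensorflow as tf

data = np.random.randn(10, 4).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((data, tf.range(10)))
dataset = dataset.shuffle(10).batch(3).repeat(None)
data_batch, batch_inds = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    b, inds = sess.run([data_batch, batch_inds])  # a shuffled batch and its indices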
Example #13
x_0 = tf.ones([tf.shape(x)[0], 1])
Feature = tf.concat([tf.multiply(a[0], x_0), tf.multiply(a[1], x)], axis=1)

for i in range(p):

    h = np.random.randint(low=0, high=pro_dim, size=[input_dim, 1])
    s = np.random.randint(low=0, high=2, size=[input_dim, 1]) * 2 - 1
    M_ = np.zeros(shape=[pro_dim, input_dim], dtype=np.float32)

    for j in range(input_dim):
        M_[h[j, 0], j] = s[j, 0]

    M = tf.transpose(M_)

    CountSketch = tf.to_complex64(tf.matmul(x, M))

    # accumulate the product of FFTs; P is assumed to be initialized to a
    # complex64 tensor of ones earlier in the original (truncated) snippet
    P = tf.multiply(P, tf.fft2d(CountSketch))

    Feature_ = tf.multiply(a[i], tf.real(tf.ifft2d(P)))

    if i > 1:
        Feature = tf.concat([Feature, Feature_], axis=1)

    print("Feature shape:  ")
    print(Feature.shape)

W = tf.Variable(
    tf.random_normal([1 + input_dim + (p - 2) * pro_dim, 1], stddev=0.35))

b = tf.Variable(tf.zeros([1]))
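
The loop above builds Tensor Sketch features for polynomial kernels (Pham & Pagh, KDD 2013): each iteration multiplies another count-sketch FFT into P. One count-sketch projection in NumPy, matching the h/s construction above (note the canonical algorithm applies a 1-D FFT to each sketched row):

import numpy as np

input_dim, pro_dim = 16, 8
x = np.random.randn(3, input_dim)
h = np.random.randint(0, pro_dim, size=input_dim)    # target bucket per input feature
s = np.random.randint(0, 2, size=input_dim) * 2 - 1  # random +/-1 signs
M = np.zeros((input_dim, pro_dim), dtype=np.float32)
M[np.arange(input_dim), h] = s
sketch = x @ M                  # count-sketch projection, shape (3, pro_dim)
P = np.fft.fft(sketch, axis=1)  # per-row FFT; products of these sketch higher degrees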