def tf_FSPAS_FFT(self, field, wlength, z, dx, dy, ridx, theta_max):
    """Angular-spectrum propagation of a coherent wave field, with band limiting.

    The centred angular spectrum of ``field`` is multiplied by the transfer
    function ``Q * exp(i * 2*pi * z * sqrt(1/lam**2 - fx**2 - fy**2))``, where
    ``lam = wlength / ridx`` is the effective wavelength in the medium and
    ``Q = Qd * Qbw``:

    * ``Qd`` keeps only propagating components (``fx**2 + fy**2 < 1/lam**2``);
      evanescent waves are always removed.
    * ``Qbw`` is the anti-aliasing band limit returned by
      ``self.adaptive_bandlimit(..., theta_max)``.

    The propagated field is multiplied by ``sqrt(sum|A|^2 / sum|A*Qd|^2)``
    (per batch/channel), compensating the power lost to the evanescent mask.
    Power removed by the band limit is reported separately.

    Args:
        field: complex wave field in the space domain. ``tf.fft2d`` acts on its
            last two axes; presumably shaped [B, C, M, N] to match
            ``output_shape`` (TODO confirm).
        wlength: vacuum wavelength(s); expanded to [..., 1, 1] so one value per
            channel broadcasts over the frequency grid (multi-wavelength).
        z: propagation distance.
        dx, dy: spatial sampling intervals along the M and N axes.
        ridx: refractive index of the propagation medium.
        theta_max: band-limit parameter forwarded to
            ``self.adaptive_bandlimit``; the original notes describe it as an
            angle in degrees.

    The simulation window size is read from ``self.B, self.C, self.M, self.N``.

    Returns:
        output: propagated complex field in the space domain, energy-rescaled.
        scattered_pwr: per-(batch, channel) power (reduced over the last two
            axes, kept as size-1 dims) of the propagating components removed by
            the band limit (``Qd == 1`` but ``Qbw == 0``), divided by ``M * N``.
            A field with no propagating power yields a scale of 1 (not NaN).
    """
    B, C, M, N = self.B, self.C, self.M, self.N
    output_shape = tf.stack([B, C, tf.cast(M, tf.int32), tf.cast(N, tf.int32)])
    wlengtheff = wlength / ridx

    # frequency sampling intervals and centred frequency grids:
    # fx is a column [M, 1], fy a row [1, N]; fx2/fy2 are their squares on [M, N]
    dfx = 1 / dx / M
    dfy = 1 / dy / N
    fx = tf.expand_dims((tf.range(M, dtype=tf.float32) - (M) / 2) * dfx, -1)
    fy = tf.expand_dims((tf.range(N, dtype=tf.float32) - (N) / 2) * dfy, 0)
    fx2 = tf.matmul(fx**2, tf.ones((1, N), dtype=tf.float32))
    fy2 = tf.matmul(tf.ones((M, 1), dtype=tf.float32), fy**2)
    wlengtheff = tf.expand_dims(tf.expand_dims(wlengtheff, -1), -1)

    # Diffraction limit: 1 for propagating components, 0 for evanescent ones
    f0 = 1.0 / wlengtheff
    Qd = tf.cast(tf.less(tf.expand_dims((fx2 + fy2), 0), (f0**2)), tf.float32)
    # Propagation anti-aliasing band limit.
    # NOTE(review): the original author flagged this call with "CHECK THIS".
    Qbw = self.adaptive_bandlimit(fx, fy, z, dfx, dfy, wlengtheff, theta_max)
    Q = Qd * Qbw

    # W = lam^2 (fx^2 + fy^2) inside the pass band (0 elsewhere), so 1 - W >= 0
    W = Q * tf.expand_dims((fx2 + fy2), 0) * (wlengtheff**2)
    phase_term_sqrt = (tf.ones((C, tf.cast(M, tf.int32), tf.cast(N, tf.int32))) - W)**(0.5)
    Hphase = 2 * np.pi / wlengtheff * z * phase_term_sqrt
    HFSP = tf.complex(Q * tf.cos(Hphase), Q * tf.sin(Hphase))

    # propagate in the (centred) angular-spectrum domain
    ASpectrum = tf.fft2d(field)
    ASpectrum = self.tf_fft_shift_2d(ASpectrum)
    ASpectrum_z = self.tf_ifft_shift_2d(tf.multiply(HFSP, ASpectrum))
    output = tf.ifft2d(ASpectrum_z)

    Qd = tf.cast(Qd, tf.complex64)
    Q = tf.cast(Q, tf.complex64)
    # energy compensation for the evanescent mask; guard the ratio so an input
    # with no propagating power (e.g. an all-zero field) does not produce 0/0 = NaN
    total_pwr = tf.reduce_sum(tf.square(tf.abs(ASpectrum)), axis=[2, 3], keepdims=True)
    kept_pwr = tf.reduce_sum(tf.square(tf.abs(ASpectrum * Qd)), axis=[2, 3], keepdims=True)
    has_pwr = tf.greater(kept_pwr, 0.0)
    safe_kept_pwr = tf.where(has_pwr, kept_pwr, tf.ones_like(kept_pwr))
    unitary_prop_cnst = tf.where(has_pwr,
                                 tf.divide(total_pwr, safe_kept_pwr),
                                 tf.ones_like(kept_pwr))
    unitary_cnst = tf.complex(tf.sqrt(unitary_prop_cnst), tf.zeros_like(unitary_prop_cnst))
    output = tf.multiply(output, unitary_cnst)
    output = tf.slice(output, tf.stack([0, 0, 0, 0]), output_shape)

    # Qd * (Qd - Q) is 1 exactly where a propagating component was cut by Qbw
    scattered_pwr = tf.reduce_sum(
        tf.square(tf.abs(unitary_cnst * ASpectrum * Qd * (Qd - Q))),
        axis=[2, 3], keepdims=True) / M / N
    return output, scattered_pwr
def tf_SE3_to_se3(SE3):
    """Map a batch of rigid transforms from SE(3) to se(3) twist coordinates.

    Args:
        SE3: tensor [N, 4, 4] of homogeneous transforms; the rotation R is the
            top-left 3x3 block and the translation T the top-right column.

    Returns:
        Tensor [N, 6] ordered ``[v, w]``, where ``w = vee(log R)`` and
        ``v = V^{-1} T`` with
        ``V^{-1} = I - w_hat/2 + c(theta) * w_hat @ w_hat`` and
        ``c(theta) = (1 - theta sin(theta) / (2 (1 - cos theta))) / theta^2``.
    """
    R = SE3[:, 0:3, 0:3]  # [N, 3, 3]
    T = SE3[:, 0:3, 3]  # [N, 3]

    # rotation angle from trace(R) = 1 + 2 cos(theta); clip first so round-off in
    # R (e.g. a trace slightly above 3) cannot push acos outside its domain (NaN)
    cos_theta = tf.clip_by_value((tf.trace(R) - 1) / 2, -1.0, 1.0)
    # [N, 1, 1]; the small offset keeps theta strictly positive as in the original
    theta = tf.expand_dims(tf.expand_dims(tf.acos(cos_theta), 1), 2) + 1e-5

    # NOTE(review): an earlier comment warned that tf.linalg.logm has no gradient
    # graph, yet it is what computes w_hat here; confirm whether gradients
    # through R are required by callers.
    w_hat = tf.cast(tf.linalg.logm(tf.cast(R, tf.complex64)), tf.float32)  # [N, 3, 3]
    w = tf_vee(w_hat)  # [N, 3]

    # c(theta) evaluated directly cancels catastrophically for small angles: in
    # float32, 1 - cos(1e-5) == 0, so identity rotations produced inf * 0 = NaN.
    # Below the threshold use its Taylor series 1/12 + t^2/720 + t^4/30240.
    # theta_safe keeps the unused exact branch finite so tf.where gradients stay clean.
    small_angle = 0.5
    is_small = tf.less(theta, small_angle)
    theta_safe = tf.where(is_small, tf.ones_like(theta), theta)
    exact = (1 / (theta_safe * theta_safe)) * (
        1 - (theta_safe * tf.sin(theta_safe)) / (2 * (1 - tf.cos(theta_safe))))
    theta_sq = theta * theta
    series = 1.0 / 12.0 + theta_sq / 720.0 + theta_sq * theta_sq / 30240.0
    coeff = tf.where(is_small, series, exact)  # [N, 1, 1]

    V_inv = tf.eye(3) - 0.5 * w_hat + coeff * tf.matmul(w_hat, w_hat)
    v = tf.matmul(V_inv, tf.expand_dims(T, 2))[:, :, 0]  # [N, 3]
    se3 = tf.concat([v, w], axis=1)
    return se3
def to_complex(self, x):
    """Return ``x`` as a complex tensor, choosing precision from its dtype.

    Complex inputs (complex64/complex128) are returned untouched; float64
    inputs become complex128 and every other dtype becomes complex64.
    """
    source_dtype = self.dtype(x)
    if source_dtype in (np.complex64, np.complex128):
        return x
    target_dtype = tf.complex128 if source_dtype == np.float64 else tf.complex64
    return tf.cast(x, target_dtype)
def call(self, x):
    """Build a Hermitian matrix from the layer input (Keras ``Layer.call``).

    The input is offset by the zero-voltage parameters ``self.H0``, cast to
    complex64, multiplied element-wise by ``self.complex_operator`` and then
    made Hermitian as ``H + H^dagger``.

    Args:
        x: tensor passed automatically by TensorFlow. The transpose below fixes
            its rank at 4: two leading axes (batch and time steps, per the
            original notes) followed by the two matrix axes.

    Returns:
        complex64 tensor shaped like ``x``, Hermitian in its last two axes.
    """
    # add the zero-voltage parameters, then switch to complex arithmetic
    H = tf.cast(x + self.H0, tf.complex64)
    # element-wise complex mask; per the original design it turns the
    # lower-triangular entries pure imaginary (this depends on the contents of
    # self.complex_operator, which is not visible here)
    complex_operator = tf.constant(self.complex_operator, dtype=tf.complex64)
    # size-1 batch and time axes: broadcasting applies the same mask to every
    # (batch, time) slice, so no explicitly tiled copy is needed
    complex_operator = tf.expand_dims(tf.expand_dims(complex_operator, 0), 0)
    H = tf.multiply(H, complex_operator)
    # H + H^dagger (conjugate transpose of the two matrix axes) is Hermitian
    return tf.add(H, tf.transpose(H, [0, 1, 3, 2], conjugate=True))
def stft_tf(wav, win_length, hop_length, n_fft, window='hann', mode='REFLECT'):
    """Short-time Fourier transform in TensorFlow mirroring ``librosa.stft(center=True)``.

    Adapted from https://github.com/zhang-wy15/stft_from_librosa_to_tensorflow;
    that project reports agreement with librosa to within about 1e-6.

    Args:
        wav: 1-D waveform tensor (the padding spec below is rank-1).
        win_length: window length in samples; ``None`` means ``n_fft``.
        hop_length: hop size in samples; ``None`` means ``win_length // 4``.
        n_fft: FFT size, which is also the frame length.
        window: window spec accepted by ``scipy.signal.get_window``.
        mode: ``tf.pad`` mode used to centre the frames.

    Returns:
        complex64 tensor ``[num_frames, 1 + n_fft // 2]`` of non-negative
        frequency bins.

    Raises:
        ValueError: if ``win_length > n_fft``.
    """
    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft
    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)
    if win_length > n_fft:
        raise ValueError(
            'win_length ({}) must not exceed n_fft ({})'.format(win_length, n_fft))

    fft_window = scipy.signal.get_window(window, win_length, fftbins=True)
    # Centre the window inside n_fft samples like librosa.util.pad_center: when
    # n_fft - win_length is odd the extra zero goes on the right, so the padded
    # window always has exactly n_fft samples and matches the frame length.
    lpad = (n_fft - win_length) // 2
    rpad = n_fft - win_length - lpad
    fft_window = np.pad(fft_window, (lpad, rpad), mode='constant',
                        constant_values=(0, 0))

    # center=True: pad so that frame t is centred on sample t * hop_length
    wav = tf.pad(wav, [[n_fft // 2, n_fft // 2]], mode=mode)
    # Window the time series: [num_frames, n_fft]
    frames = tf.contrib.signal.frame(wav, n_fft, hop_length, pad_end=False)
    # full FFT of every windowed frame at once, keeping the non-negative bins
    linear = tf.spectral.fft(tf.cast(frames * fft_window, tf.complex64))
    return linear[:, :int(1 + n_fft // 2)]
def tf_ifft2c(kspace):
    """Centred, orthonormally scaled 2-D inverse FFT over the last two axes.

    ``kspace`` is shifted with ``tf_shift2d``, inverse-transformed, multiplied by
    ``sqrt(H * W)`` (undoing the 1/(H*W) of ``ifft2d`` up to the orthonormal
    factor), and shifted again.

    NOTE(review): the same ``tf_shift2d`` is used before and after the
    transform; for odd-sized axes fftshift and ifftshift differ, so confirm
    ``tf_shift2d`` is correct as a pre-shift for the sizes in use.
    """
    dims = tf.shape(kspace)
    n_pixels = dims[-2] * dims[-1]
    norm = tf.cast(tf.sqrt(tf.cast(n_pixels, tf.float32)), tf.complex64)
    image = tf.spectral.ifft2d(tf_shift2d(kspace)) * norm
    return tf_shift2d(image)
def call(self, x):
    """Evolve the stored initial state and return basis-state probabilities.

    This is the Keras ``Layer.call`` hook; ``x`` is passed automatically by
    TensorFlow. The first two axes of ``x`` are treated as batch and time steps.

    Computes ``U = expm(x * self.length)`` and ``psi = U @ psi_0`` for every
    (batch, time) slot and returns ``|psi|^2`` element-wise (last axis squeezed).
    Per the original notes this is meant to be ``expm(-i*H*l)``; presumably
    ``self.length`` carries the ``-i`` factor (TODO confirm).
    """
    # complex64 is required for expm/matmul to train correctly
    generator = tf.cast(x, tf.complex64) * self.length
    propagator = tf.linalg.expm(generator)

    # give the initial ket leading time and batch axes, then repeat it once per
    # (batch, time) slot: multiples are [d1, d2, 1, 1]
    ket = tf.expand_dims(tf.expand_dims(self.initial_state, 0), 0)
    reps = tf.concat(
        [tf.shape(x)[0:2], tf.constant(np.array([1, 1], dtype=np.int32))], 0)
    ket = tf.cast(tf.tile(ket, reps), tf.complex64)

    # apply U and drop the trailing column axis of the ket
    amplitudes = tf.squeeze(tf.matmul(propagator, ket), -1)
    return tf.square(tf.abs(amplitudes))
def _after_czs(self, v: tf.Tensor, pairs: tf.Tensor) -> tf.Tensor:
    """Apply a layer of controlled-Z sign flips to ``v``.

    For each row ``(i, j)`` of ``pairs``, every basis index in
    ``range(self.grouping.system_size())`` that has both bits ``i`` and ``j``
    set has its parity toggled. ``v`` is finally multiplied by ``+1`` (even
    parity) or ``-1`` (odd parity) per basis index. A row whose first entry is
    ``-1`` is treated as padding and affects nothing.

    Args:
        v: complex tensor that broadcasts against a [system_size] sign vector
            (presumably the state amplitudes — TODO confirm).
        pairs: integer tensor [K, 2] of bit/qubit index pairs.

    Returns:
        ``v`` with the accumulated CZ signs applied.
    """
    basis = tf.range(self.grouping.system_size())
    parity = tf.constant(0, dtype=tf.int32)
    for row in range(pairs.shape[0]):
        q0 = pairs[row, 0]
        q1 = pairs[row, 1]
        pair_mask = tf.bitwise.bitwise_or(tf.bitwise.left_shift(1, q0),
                                          tf.bitwise.left_shift(1, q1))
        # padding rows use mask -1: basis & -1 == basis, which never equals -1
        # for the non-negative indices in `basis`, so no parity is toggled
        mask = tf.cond(tf.math.equal(q0, -1), lambda: -1, lambda: pair_mask)
        has_both = tf.math.equal(mask, tf.bitwise.bitwise_and(basis, mask))
        parity = tf.bitwise.bitwise_xor(parity, tf.cast(has_both, tf.int32))
    # parity 0 -> +1, parity 1 -> -1
    v *= 1 - tf.cast(parity, tf.complex64) * 2
    return v
def generate_background(self):
    """Build the aperture mask and phase ramps for the illumination background.

    Reads ``self.k_illum_vectors`` (x/y/z components stacked on axis 0) and sets:

    * ``self.miss_aper_mask``: float mask, 1 where the illumination's transverse
      wavevector lies inside ``k_illum * NA``, shaped ``[num LEDs, 1]`` (taken
      from the first illumination).
    * ``self.k_fft_shift``: ``exp(i 2 pi (x kx + y ky))`` over the flattened
      camera grid, squeezed. It is already batched because ``k_illum_vectors``
      derives from ``xyz_LED_batch``.

    If ``self.force_pass_thru_DC`` is set, the wavevectors are first offset by
    ``self.DC_adjust`` (in k-space pixel units) and kx/ky are rescaled to
    magnitude ``k_illum``. kz is not rescaled because it is not used afterwards.
    """
    # split off x, y, z components and drop the length-1 split axis;
    # each is [num illum, num LEDs]
    kx_illum, ky_illum, kz_illum = [
        component[0] for component in tf.split(self.k_illum_vectors, 3, axis=0)]

    # 1 inside the NA-limited aperture, 0 for illuminations that miss it
    inside_aper = tf.less(kx_illum**2 + ky_illum**2, (self.k_illum * self.NA)**2)
    self.miss_aper_mask = tf.cast(inside_aper, tf.float32)[0, :, None]

    if self.force_pass_thru_DC:
        # DC_adjust is [num LEDs, 3]; convert its pixel offsets to k units
        kx_illum += self.DC_adjust[None, :, 0] * self.k_max[0] * 2 / self.side_k[0]
        ky_illum += self.DC_adjust[None, :, 1] * self.k_max[1] * 2 / self.side_k[1]
        kz_illum += self.DC_adjust[None, :, 2] * self.k_max[2] * 2 / self.side_k[2]
        # renormalise so that |k| = k_illum again
        rescale = self.k_illum / tf.sqrt(kx_illum**2 + ky_illum**2 + kz_illum**2)
        kx_illum *= rescale
        ky_illum *= rescale

    # centred image coordinates; in-place ops keep the float32 dtype
    coords = np.arange(self.xy_cap_n, dtype=np.float32)
    coords -= np.ceil(self.xy_cap_n / 2)
    coords *= self.dxy_sample
    grid_x, grid_y = tf.meshgrid(coords, coords)
    grid_x = tf.reshape(grid_x, [-1])
    grid_y = tf.reshape(grid_y, [-1])

    # phase argument, shape [num illum, num LEDs, camx*camy]
    phase_arg = (grid_x[None, None, :] * kx_illum[:, :, None]
                 + grid_y[None, None, :] * ky_illum[:, :, None])
    # single illumination assumed for now, hence the squeeze
    self.k_fft_shift = tf.squeeze(
        tf.exp(1j * 2 * np.pi * tf.cast(phase_arg, tf.complex64)))
def call(self, x):
    """Evolve the stored initial state and return interferometric intensities.

    This is the Keras ``Layer.call`` hook; ``x`` is passed automatically by
    TensorFlow. The first two axes of ``x`` are treated as batch and time steps.

    With ``U = expm(x * self.length)`` and ``psi = U @ psi_0``, this returns the
    concatenation along the last axis of ``|(1 + psi)/2|^2`` and
    ``|(1j + psi)/2|^2``, i.e. interference with real and imaginary references.
    Per the original notes the exponent is meant to be ``-i*H*l``; presumably
    ``x * self.length`` already carries the ``-i`` factor (TODO confirm).
    """
    propagator = tf.linalg.expm(x * self.length)

    # initial ket with leading time and batch axes, repeated per (batch, time) slot
    ket = tf.expand_dims(tf.expand_dims(self.initial_state, 0), 0)
    reps = tf.concat(
        [tf.shape(x)[0:2], tf.constant(np.array([1, 1], dtype=np.int32))], 0)
    ket = tf.cast(tf.tile(ket, reps), tf.complex64)

    psi_t = tf.squeeze(tf.matmul(propagator, ket), -1)

    # reference 1 -> power distribution, reference 1j -> interferometer output
    readouts = [tf.square(tf.abs(0.5 * (reference + psi_t))) for reference in (1, 1j)]
    return tf.concat(readouts, -1)
def to_complex(self, x):
    """Return ``x`` cast to ``tf.complex64``, whatever its input dtype."""
    target_dtype = tf.complex64
    return tf.cast(x, target_dtype)
def reconstruct_with_born(self):
    """Build the first-Born forward model from intensity (no phase) data.

    Steps:

    1. Initialise the trainable reconstruction, in k-space or real space.
    2. Derive ``self.RI`` from the scattering potential.
    3. Generate caps and apertures, and subtract the illumination.
    4. Sample the object spectrum on the caps with the ``kz`` prefactor, the
       aperture and an optional trainable pupil.
    5. Inverse-FFT to the image plane.
    6. Add the unscattered background and predict amplitudes ``|field|``.

    Finishes with ``self.generate_train_ops()``.

    The ``'pupil_phase_function'`` variable is always created; it only enters
    the model when ``self.pupil_function`` is set.
    """
    # tf variables live in k-space or in real space
    if self.optimize_k_directly:
        self.initialize_k_space_domain()
    else:
        self.initialize_space_space_domain()
    # DT_recon is the scattering potential; convert it to refractive index
    self.RI = self.V_to_RI(self.DT_recon)

    # k-spherical caps, apertures and illumination handling
    self.generate_cap()
    self.generate_apertures()
    self.subtract_illumination()
    # already batched, because derived from xyz_LED_batch
    self.k_fft_shift_batch = self.k_fft_shift
    self.xyz_caps_batch = self.xyz_caps

    self.pupil_phase = tf.Variable(
        np.zeros((self.xy_cap_n, self.xy_cap_n)),
        dtype=tf.float32,
        name='pupil_phase_function')
    pupil = tf.exp(1j * tf.cast(self.pupil_phase, tf.complex64))

    # sample the object's spectrum on the caps
    cap_spectrum = self.tf_gather_nd3(
        tf.transpose(self.k_space, [1, 0, 2]), self.xyz_caps_batch)
    # prefactor: it's 1i*kz/pi, but kz here is not in angular frequency
    cap_spectrum /= tf.complex(0., self.kz_cap[None]) * 2
    # unflatten so the last two axes can be inverse-FFT'd
    cap_spectrum = tf.reshape(
        cap_spectrum,
        (-1, len(self.k_illum), self.xy_cap_n, self.xy_cap_n))
    # zero out Fourier support outside the aperture before fftshift
    cap_spectrum *= tf.complex(self.aperture_mask[None], 0.)
    if self.pupil_function:
        cap_spectrum *= pupil

    scattered = self.tf_ifftshift2(tf.ifft2d(self.tf_fftshift2(cap_spectrum)))
    # fft phase factor compensation
    scattered *= tf.cast(self.dxy**2 * self.dz / self.dxy_sample**2, tf.complex64)
    self.forward_pred = tf.reshape(scattered, (-1, self.points_per_cap))
    # handle on the scattered E field for diagnostic purposes
    self.field = tf.identity(self.forward_pred)

    unscattered = self.DC_batch * self.k_fft_shift_batch * tf.exp(
        1j * tf.cast(self.illumination_phase_batch[:, None], tf.complex64))
    if self.zero_out_background_if_outside_aper:
        # drop the background for illumination angles that miss the aperture
        self.miss_aper_mask_batch = tf.cast(self.miss_aper_mask, tf.complex64)
        background = unscattered * self.miss_aper_mask_batch
    else:
        background = unscattered
    self.forward_pred_field = self.DC_batch * self.forward_pred + background
    self.forward_pred = tf.abs(self.forward_pred_field)
    self.generate_train_ops()
def reconstruct_with_multislice(self):
    """Build the multislice (slice-by-slice propagation) forward model.

    Only two parameterizations are supported: direct index reconstruction or a
    deep-image-prior (DIP) index reconstruction.

    Pipeline:

    1. Background illumination, apodized by a Gaussian window translated per LED.
    2. Propagation through the index volume one z-slice at a time (``tf.scan``):
       Fresnel kernel ``self.F``, then the slice transmittance.
    3. Propagation from the last slice to the trainable focal plane, through the
       aperture and the trainable pupil.
    4. Amplitude prediction ``|field|``.

    Side effects: sets many attributes on ``self`` (e.g. ``k_illum_vectors``,
    ``RI``, ``F``, ``illumination``, ``focus``, ``gausswin``, ``propped``,
    ``pupil_phase``, ``F_to_focus``, ``field``, ``forward_pred``). It creates
    the ``'focal_position'`` and ``'pupil_phase_function'`` variables, rebinds
    ``self.data_batch`` to a windowed version, and calls
    ``self.generate_train_ops()``.
    """
    # only two parameterization options: direct index recon, or DIP index recon;
    assert self.force_pass_thru_DC is False  # bowls are not generated, so this can't be done
    assert self.optimize_k_directly is False  # we are not using k-spheres
    # per-LED illumination wavevectors: batched LED directions scaled by k_illum
    self.k_illum_vectors = self.xyz_LED_batch[:, None, :] * self.k_illum[
        None, :, None]
    self.generate_background(
    )  # generates the variables needed for the background illumination
    # conjugated phase ramp is used as the illumination's lateral phase here
    self.k_fft_shift_batch = tf.conj(self.k_fft_shift)
    self.initialize_space_space_domain()
    self.RI = self.DT_recon + self.n_back  # no reference to scattering potential
    # choose the index-contrast volume that the field is propagated through
    if self.use_spatial_patching:
        self.spatial_patching()
        if self.use_deep_image_prior:
            # DT recon is already generated from the spatially cropped input to DIP
            DT_recon = self.DT_recon
        else:
            DT_recon = self.DT_recon_sbatch
    else:
        DT_recon = self.DT_recon
    # fresnel propagation kernel for one slice of thickness dz:
    # F = exp(-i 2 pi k^2 dz / (kn + sqrt(kn^2 - k^2))), evanescent part masked
    # fix the squeezing in the future if using more than one color
    k0 = np.squeeze(self.k_vacuum)
    kn = np.squeeze(self.k_illum)
    self.generate_k_coordinates()
    kx = tf.to_complex64(tf.squeeze(self.kx_cap))
    ky = tf.to_complex64(tf.squeeze(self.ky_cap))
    self.k_2 = kx**2 + ky**2
    self.F = tf.exp(-1j * 2 * np.pi * self.k_2 * self.dz / (kn + tf.sqrt(kn**2 - self.k_2)))
    self.F *= tf.squeeze(
        tf.to_complex64(self.evanescent_mask)
    )  # technically not needed, but due to numerical instabilities...
    self.F = self.tf_fftshift2(self.F)
    # NOTE(review): if self.F is already complex64, the cast is skipped and the
    # name may not be applied — confirm whether the name is relied upon.
    self.F = tf.to_complex64(
        self.F, name='fresnel_kernel')  # shape: xy_cap_n by xy_cap_n
    # shape: num caps by points per cap:
    self.illumination = self.DC_batch * self.k_fft_shift_batch  # called unscattered in reconstruct_with_born
    self.illumination = tf.reshape(self.illumination,
                                   [-1, self.xy_cap_n, self.xy_cap_n])
    # incorporate additional defocus factor to account for unknown focal position after propagating through sample;
    # 0 corresponds to the center of the sample; distance in um;
    # change the initial position of the beam so that after refocusing, the beam is at the center of the fov;
    self.focus = tf.Variable(self.focus_init, dtype=tf.float32, name='focal_position')
    # create apodizing Gaussian window:
    # use tf.contrib.image.translate rather than recompute for every LED to save time/memory:
    k_max_radius = 1 / 2 / self.dxy_sample  # max possible radius
    # compute shifts (using LED positions):
    # NOTE(review): these shifts use (focus - sample_thickness / 2), while
    # F_to_focus below propagates by (-focus - sample_thickness / 2); confirm
    # the sign conventions agree.
    x_shift = -(self.focus - self.sample_thickness / 2) * self.xyz_LED[0] / self.xyz_LED[2]
    y_shift = -(self.focus - self.sample_thickness / 2) * self.xyz_LED[1] / self.xyz_LED[2]
    self.xy_shift = tf.stack([x_shift, y_shift], axis=1) / self.dxy  # convert to pixel
    # centered, unshifted gaussian window (k_2 is built from kx/ky, presumably
    # real-valued, so the float cast keeps its value)
    gausswin0 = tf.exp(-tf.to_float(self.k_2) / 2 / (k_max_radius * self.apod_frac)**2)
    # one copy per cap, each translated by its LED's pixel shift
    gausswin = tf.tile(gausswin0[None], (self.num_caps, 1, 1))
    gausswin = tf.contrib.image.translate(gausswin[:, :, :, None], self.xy_shift, 'bilinear')
    self.gausswin = gausswin[:, :, :, 0]  # get rid of color channels
    self.gausswin_batch = tf.gather(self.gausswin, self.batch_inds)
    self.illumination *= tf.to_complex64(
        self.gausswin_batch)  # gaussian window

    # forward propagation:
    def propagate_1layer(field, t_i):
        # field: the input field;
        # t_i, the 2D object transmittance function at the current (ith) plane, referenced to background index;
        # propagate by one slice with the Fresnel kernel, then apply the slice
        return tf.ifft2d(tf.fft2d(field) * self.F) * t_i

    dN = tf.transpose(DT_recon, [2, 0, 1])  # make z the leading dim
    t = tf.exp(1j * 2 * np.pi * k0 * dN * self.dz)  # transmittance function
    # scan over z slices; stacks the field after every slice on a new leading axis
    self.propped = tf.scan(propagate_1layer, initializer=self.illumination, elems=t, swap_memory=True)
    self.propped = tf.transpose(self.propped, [1, 2, 3, 0])  # num ill, x, y, z
    self.pupil_phase = tf.Variable(np.zeros(
        (self.xy_cap_n, self.xy_cap_n)), dtype=tf.float32, name='pupil_phase_function')
    pupil = tf.exp(1j * tf.to_complex64(self.pupil_phase))
    limiting_aperture = tf.squeeze(tf.to_complex64(self.aperture_mask))
    k_2 = self.k_2 * limiting_aperture  # to prevent values far away from origin from being too large
    # kernel that propagates from the last slice to the focal plane
    self.F_to_focus = tf.exp(
        -1j * 2 * np.pi * k_2 * tf.to_complex64(-self.focus - self.sample_thickness / 2) / (kn + tf.sqrt(kn**2 - k_2)))
    # restrict to the experimental aperture
    self.F_to_focus *= limiting_aperture
    self.F_to_focus *= pupil  # to account for aberrations common to all
    self.F_to_focus = self.tf_fftshift2(self.F_to_focus)
    self.F_to_focus = tf.to_complex64(self.F_to_focus, name='fresnel_kernel_prop_to_focus')
    # field at the focal plane, from the last propagated slice
    self.field = tf.ifft2d(
        tf.fft2d(self.propped[:, :, :, -1]) * self.F_to_focus[None])
    self.forward_pred = tf.abs(self.field)
    self.forward_pred = tf.reshape(self.forward_pred, [-1, self.xy_cap_n**2])
    # NOTE(review): the data is windowed with the unshifted gausswin0, while the
    # illumination used the per-LED shifted window — confirm this is intended.
    self.data_batch *= tf.reshape(
        gausswin0, [-1])[None]  # since prediction is windowed, also window data
    self.generate_train_ops()
def format_DT_data(self, stack, DC=None):
    """Load measured intensities and build the batching pipeline.

    Args:
        stack: raw intensity array of shape
            (num aper, num LEDs, num illum, camx, camy). The square root is
            taken here, so pass intensities, not amplitudes.
        DC: optional per-cap amplitude of the unscattered light. If None, it is
            initialised to the per-cap median of the square-rooted data.

    Side effects: sets ``num_caps``, ``points_per_cap``, ``data_stack``, ``DC``,
    ``illumination_phase``, the dataset/iterator and ``batch_inds``. Without
    spatial patching it also sets ``data_batch``; with spatial patching it sets
    ``spatial_batch_inds``. It then sets ``DC_batch``,
    ``illumination_phase_batch`` and ``xyz_LED_batch``, and calls
    ``self.generate_LED_positions_flat_array()``.

    NOTE(review): ``data_stack`` is reshaped to (num_caps, camx**2), which only
    succeeds when the num-illum axis has size 1 — confirm multi-illumination
    stacks are not expected here.
    """
    s = stack.shape
    assert self.num_apers == s[0], 'stack axis 0 must equal num_apers'
    if not self.use_spatial_patching:
        assert s[3] == s[4] == self.xy_cap_n, 'frames must be xy_cap_n x xy_cap_n'
    else:
        # with spatial patching the full frames are larger than xy_cap_n
        assert s[3] == s[4] == self.xy_full_n, 'frames must be xy_full_n x xy_full_n'
    self.num_caps = s[0] * s[1]  # number of spherical caps (aper*LED)
    self.points_per_cap = s[2] * self.xy_cap_n**2  # for every color
    self.data_stack = np.reshape(stack, (self.num_caps, s[3]**2))
    # take sqrt once here so that it isn't redone for each new batch
    self.data_stack = np.sqrt(self.data_stack)
    # DC due to unscattered light, potentially different for every angle:
    if DC is None:  # initialize from data
        DC = np.median(self.data_stack, 1)
    self.DC = tf.Variable(DC, dtype=np.float32, name='DC')
    self.illumination_phase = tf.Variable(tf.zeros(self.num_caps, dtype=tf.float32),
                                          name='illumination_phase', trainable=False)
    self.generate_LED_positions_flat_array()

    if self.use_spatial_patching:
        # this implementation doesn't finish all the LEDs in one spatial crop
        # before moving to another.
        # Upper-left corner of the crop: valid corners are
        # 0..(xy_full_n - xy_cap_n) inclusive, and maxval is exclusive, hence
        # the +1. Without it, and assuming the crop spans
        # [corner, corner + xy_cap_n), the last pixel row/column is never
        # sampled, and xy_full_n == xy_cap_n raises.
        self.spatial_batch_inds = tf.random_uniform(
            shape=(2, 1), minval=0,
            maxval=self.xy_full_n - self.xy_cap_n + 1, dtype=tf.int32)
        # batch along LED dimension:
        self.dataset = (tf.data.Dataset.range(self.num_caps)
                        .shuffle(self.num_caps)
                        .batch(self.batch_size)
                        .repeat(None)
                        .make_one_shot_iterator())
        self.batch_inds = self.dataset.get_next()
        # reshape so that we can crop:
        self.data_stack = self.data_stack.reshape(self.num_caps, self.xy_full_n,
                                                  self.xy_full_n)
    else:
        # generate dataset for batching:
        self.dataset = tf.data.Dataset.from_tensor_slices(
            (self.data_stack, tf.range(self.num_caps)))
        if self.batch_size != self.num_caps:
            # if all examples are present, don't shuffle
            self.dataset = self.dataset.shuffle(self.num_caps)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.repeat(None)  # go forever
        self.batcher = self.dataset.make_one_shot_iterator()
        (self.data_batch, self.batch_inds) = self.batcher.get_next()

    if self.data_ignore is not None:
        # drop batch entries flagged in data_ignore
        keep_inds = tf.gather(~self.data_ignore, self.batch_inds)
        self.batch_inds = tf.boolean_mask(self.batch_inds, keep_inds)
        if not self.use_spatial_patching:
            # with spatial patching, the data batch is generated elsewhere from batch_inds
            self.data_batch = tf.boolean_mask(self.data_batch, keep_inds)

    self.DC_batch = tf.gather(self.DC, self.batch_inds)
    self.DC_batch = tf.cast(self.DC_batch[:, None], tf.complex64)
    self.illumination_phase_batch = tf.gather(self.illumination_phase, self.batch_inds)
    # transpose because first dim is 3 for xyz
    self.xyz_LED_batch = tf.transpose(
        tf.gather(tf.transpose(self.xyz_LED), self.batch_inds))
x_0 = tf.ones([tf.shape(x)[0], 1]) Feature = tf.concat([tf.multiply(a[0], x_0), tf.multiply(a[1], x)], axis=1) for i in range(p): h = np.random.randint(low=0, high=pro_dim, size=[input_dim, 1]) s = np.random.randint(low=0, high=2, size=[input_dim, 1]) * 2 - 1 M_ = np.zeros(shape=[pro_dim, input_dim], dtype=np.float32) for j in range(input_dim): M_[h[j, 0], j] = s[j, 0] M = tf.transpose(M_) CountSketch = tf.to_complex64(tf.matmul(x, M)) P = tf.multiply(P, tf.fft2d(CountSketch)) Feature_ = tf.multiply(a[i], tf.real(tf.ifft2d(P))) if i > 1: Feature = tf.concat([Feature, Feature_], axis=1) print("Feature shape: ") print(Feature.shape) W = tf.Variable( tf.random_normal([1 + input_dim + (p - 2) * pro_dim, 1], stddev=0.35)) b = tf.Variable(tf.zeros([1]))