def _build_graph(self, x_train, global_step, hm_reg_scale, height_map_noise):
    """Build the phase-plate camera forward model.

    A unit-amplitude plane wave passes through a learnable height-map
    element and a circular aperture. It is Fresnel-propagated to the
    sensor, and the resulting intensity PSF is convolved with the input
    image.

    Args:
        x_train: Graph node holding the input image.
        global_step: Global step variable; not used by this model.
        hm_reg_scale: Coefficient of the Laplacian L1 regularizer applied to
            the phase-plate height map.
        height_map_noise: Height tolerance handed to the height-map element
            to model manufacturing deficiencies.

    Returns:
        Graph node holding the simulated sensor image (float32) with a random
        scalar offset added.
    """
    image = x_train
    with tf.device('/device:GPU:0'):
        # Unit-amplitude plane wave: one channel per simulated wavelength.
        plane_wave = tf.ones(
            (1, self.wave_res[0], self.wave_res[1], len(self.wave_lengths)))

        # The phase plate delays the plane wave according to its height map.
        wavefront = optics.height_map_element(
            plane_wave,
            wave_lengths=self.wave_lengths,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            height_map_initializer=None,
            height_tolerance=height_map_noise,
            refractive_idcs=self.refractive_idcs,
            name='height_map_optics')
        wavefront = optics.circular_aperture(wavefront)

        # Fresnel propagation from the aperture plane to the sensor plane.
        wavefront = optics.propagate_fresnel(
            wavefront,
            distance=self.sensor_distance,
            sampling_interval=self.sample_interval,
            wave_lengths=self.wave_lengths)

        # PSF = sensor intensity, pooled to the image patch size, then scaled
        # so that each slice sums to one over axes 1 and 2.
        intensity = optics.get_intensities(wavefront)
        psf_stack = optics.area_downsampling_tf(intensity, self.patch_size)
        psf_energy = tf.reduce_sum(psf_stack, axis=[1, 2], keepdims=True)
        psf_stack = tf.div(psf_stack, psf_energy)
        optics.attach_summaries('PSF', psf_stack, image=True, log_image=True)

        # Move axes 1, 2 to the front, as expected by the convolution helper.
        conv_kernel = tf.transpose(psf_stack, [1, 2, 0, 3])
        sensor_image = tf.cast(optics.img_psf_conv(image, conv_kernel), tf.float32)
        optics.attach_summaries('output_image', sensor_image, image=True, log_image=False)

        # One random scalar offset in [0.001, 0.02) is added to the whole
        # image; note this happens after the summary above is recorded.
        sensor_image += tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
        return sensor_image
def forward_model(input_field):
    """Propagate ``input_field`` through the phase plate to the sensor plane.

    NOTE(review): this function reads ``self``, ``hm_reg_scale`` and
    ``height_map_noise`` as free variables, so it presumably lives inside a
    method that defines them. Confirm against the enclosing scope. Calling it
    more than once in the same variable scope would presumably try to
    re-create the 'height_map_optics' variables; verify reuse settings.

    Args:
        input_field: Field incident on the phase plate.

    Returns:
        The field after the height-map element and the circular aperture,
        Fresnel-propagated over ``self.distance``.
    """
    field = optics.height_map_element(
        input_field,
        wave_lengths=self.wave_lengths,
        height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
        height_map_initializer=None,
        height_tolerance=height_map_noise,
        refractive_idcs=self.refractive_idcs,
        name='height_map_optics')
    field = optics.circular_aperture(field)
    # optics.propagate_fresnel takes the grid pitch as `sampling_interval`,
    # which is the keyword the other propagate_fresnel call site in this
    # module uses. The previous `input_sample_interval=` keyword presumably
    # does not exist in that signature and would be rejected.
    field = optics.propagate_fresnel(
        field,
        distance=self.distance,
        sampling_interval=self.input_sample_interval,
        wave_lengths=self.wave_lengths)
    return field
def _optical_conv_layer(self, input_field, hm_reg_scale, activation=None, coherent=False, name='optical_conv'):
    """Apply a learnable phase mask as an optical convolution layer.

    The pupil is a circular aperture of radius ``self.r_NA`` multiplied by a
    learnable height-map phase mask. Its centered inverse FFT, read out by an
    ``optics.Sensor`` and normalized to unit sum, is the PSF.

    Args:
        input_field: Input tensor; it is cast to complex128 before use.
        hm_reg_scale: Unused; the height-map regularizer is disabled here.
        activation: Optional callable applied to the layer output.
        coherent: If True, the input's centered spectrum is filtered by the
            aperture and phase mask and inverse-transformed. Otherwise the
            input is convolved with the PSF via ``optics.fft_conv2d`` and
            the magnitude is taken.
        name: Variable-scope name for the layer.

    Returns:
        The complex filtered field (coherent) or the convolution magnitude
        (incoherent), passed through ``activation`` when one is given.
    """
    def centered_fft2(x):
        # fftshift(fft2(ifftshift(x))): zero frequency at the grid centre.
        return optics.fftshift2d_tf(optics.transp_fft2d(optics.ifftshift2d_tf(x)))

    def centered_ifft2(x):
        # fftshift(ifft2(ifftshift(x))): inverse of centered_fft2.
        return optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(x)))

    with tf.variable_scope(name):
        sensor_res = self.wave_resolution
        complex_input = tf.cast(input_field, tf.complex128)
        spectrum = centered_fft2(complex_input)

        # Learnable phase mask, initialized to a nearly constant height.
        mask_init = tf.random_uniform_initializer(minval=0.999e-4, maxval=1.001e-4)
        phase_mask = optics.height_map_element(
            [1, sensor_res[0], sensor_res[1], 1],
            wave_lengths=self.wavelength,
            height_map_initializer=mask_init,
            name='phase_mask_height',
            refractive_index=self.n)

        # Pupil (amplitude transfer function): aperture, then phase mask.
        pupil = tf.ones([1, sensor_res[0], sensor_res[1], 1])
        pupil = optics.circular_aperture(pupil, max_val=self.r_NA)
        pupil = phase_mask(pupil)

        psf_field = centered_ifft2(pupil)
        sensor = optics.Sensor(input_is_intensities=False, resolution=sensor_res)
        psf = sensor(psf_field)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_img('psf', psf)

        if not coherent:
            # Reshape to a (H, W, 1, 1) kernel for the FFT convolution.
            # NOTE(review): the complex128 input is fed to fft_conv2d here;
            # confirm optics.fft_conv2d accepts complex input.
            kernel = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            output_img = tf.abs(optics.fft_conv2d(complex_input, kernel))
        else:
            spectrum = optics.circular_aperture(spectrum, max_val=self.r_NA)
            spectrum = phase_mask(spectrum)
            tf.summary.image('field', tf.square(tf.abs(spectrum)))
            output_img = centered_ifft2(spectrum)

        if activation is not None:
            output_img = activation(output_img)
        return output_img
def _build_graph(self, x_train, hm_reg_scale, hm_init_type='random_normal'):
    """Build the incoherent phase-mask imaging model.

    The input is normalized by its total sum over the whole tensor. It is then
    convolved with the PSF of a learnable phase mask behind a circular
    aperture of radius ``self.r_NA``. The result is read out by an
    ``optics.Sensor``, normalized again to unit sum and logged to summaries.

    Args:
        x_train: Input image tensor (also logged as 'input_image').
        hm_reg_scale: Coefficient of the Laplacian L1 height-map regularizer.
        hm_init_type: Unused; the mask always starts from a near-constant
            uniform initializer.

    Returns:
        The normalized float32 sensor image.
    """
    def centered_fft2(x):
        # fftshift(fft2(ifftshift(x))): zero frequency at the grid centre.
        return optics.fftshift2d_tf(optics.transp_fft2d(optics.ifftshift2d_tf(x)))

    def centered_ifft2(x):
        # fftshift(ifft2(ifftshift(x))): inverse of centered_fft2.
        return optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(x)))

    with tf.device('/device:GPU:0'):
        sensor_res = (self.dim, self.dim)

        # Scale the input so that all of its entries sum to one.
        normalized = x_train / tf.reduce_sum(x_train)
        tf.summary.image('input_image', x_train)

        # Centered spectrum of the input; only the coherent branch reads it.
        spectrum = centered_fft2(normalized)

        # Learnable phase mask, initialized to a nearly constant height.
        mask_init = tf.random_uniform_initializer(minval=0.999e-4, maxval=1.001e-4)
        phase_mask = optics.height_map_element(
            [1, self.wave_resolution[0], self.wave_resolution[1], 1],
            wave_lengths=self.wave_length,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            height_map_initializer=mask_init,
            name='phase_mask_height',
            refractive_index=self.n)

        # Pupil (amplitude transfer function): aperture, then phase mask.
        pupil = tf.ones([1, self.wave_resolution[0], self.wave_resolution[1], 1])
        pupil = optics.circular_aperture(pupil, max_val=self.r_NA)
        pupil = phase_mask(pupil)

        psf_field = centered_ifft2(pupil)
        psf = optics.Sensor(input_is_intensities=False, resolution=sensor_res)(psf_field)
        psf /= tf.reduce_sum(psf)  # sum or max?
        psf = tf.cast(psf, tf.float32)
        optics.attach_img('recon_psf', psf)

        # Hard-wired: only the incoherent branch below is ever taken.
        coherent = False
        if coherent:
            spectrum = optics.circular_aperture(spectrum, max_val=self.r_NA)
            spectrum = phase_mask(spectrum)
            tf.summary.image('field', tf.square(tf.abs(spectrum)))
            sensor_field = centered_ifft2(spectrum)
            output_img = optics.Sensor(input_is_intensities=False, resolution=sensor_res)(sensor_field)
        else:
            # Reshape to a (H, W, 1, 1) kernel for the FFT convolution.
            kernel = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            blurred = tf.abs(optics.fft_conv2d(normalized, kernel))
            output_img = optics.Sensor(input_is_intensities=True, resolution=sensor_res)(blurred)
        output_img /= tf.reduce_sum(output_img)  # sum or max?

        output_img = tf.cast(output_img, tf.float32)
        tf.summary.image('output_image', output_img)
        return output_img
def optical_conv_layer(input_field, hm_reg_scale, r_NA, n=1.48, wavelength=532e-9,
                       activation=None, coherent=False, amplitude_mask=False,
                       zernike=False, fourier=False, binarymask=False,
                       n_modes=1024, freq_range=.5, initializer=None,
                       zernike_file='zernike_volume_256.npy',
                       binary_mask_np=None, name='optical_conv'):
    """Simulate an optical convolution through a learnable pupil mask.

    A mask is applied to a pupil of ones sized to ``input_field``'s spatial
    dimensions. The mask is, by priority, an amplitude mask, a Zernike phase
    mask, a Fourier phase mask or a free-form height map. The pupil may
    additionally be multiplied by a fixed binary mask. The PSF is the sensor
    reading of the pupil's centered inverse FFT, normalized to unit sum.

    Args:
        input_field: Input tensor whose static shape is fully known; indices
            1 and 2 are used as the grid height and width.
        hm_reg_scale: Unused; the height-map regularizer is disabled here.
        r_NA: Aperture radius passed to the amplitude/Zernike elements and
            to ``optics.circular_aperture``.
        n: Refractive index of the mask material.
        wavelength: Design wavelength (presumably metres; default 532e-9).
        activation: Optional callable applied to the output.
        coherent: If True, filter the input's centered spectrum with the
            pupil and read it out with a sensor. Otherwise convolve the
            (padded) input with the PSF and take the magnitude.
        amplitude_mask: Learn an amplitude mask instead of a phase mask.
        zernike: Parameterize the phase mask with the first ``n_modes``
            modes stored in ``zernike_file``. This also restricts the pupil
            to a circular aperture.
        fourier: Parameterize the phase mask with ``optics.fourier_element``.
        binarymask: Multiply the pupil by ``binary_mask_np``.
        n_modes: Number of Zernike modes to load.
        freq_range: ``frequency_range`` passed to ``optics.fourier_element``.
        initializer: Mask initializer; defaults to a near-constant uniform.
        zernike_file: ``.npy`` file holding the Zernike mode stack.
        binary_mask_np: 2-D array used when ``binarymask`` is set.
        name: Variable-scope name for the layer.

    Returns:
        The layer output, passed through ``activation`` when one is given.

    Raises:
        ValueError: If ``binarymask`` is set but ``binary_mask_np`` is None.
    """
    if binarymask and binary_mask_np is None:
        raise ValueError('binarymask=True requires binary_mask_np to be provided')

    def centered_fft2(x):
        # fftshift(fft2(ifftshift(x))): zero frequency at the grid centre.
        return optics.fftshift2d_tf(optics.transp_fft2d(optics.ifftshift2d_tf(x)))

    def centered_ifft2(x):
        # fftshift(ifft2(ifftshift(x))): inverse of centered_fft2.
        return optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(x)))

    dims = input_field.get_shape().as_list()
    height, width = dims[1], dims[2]
    with tf.variable_scope(name):
        if initializer is None:
            initializer = tf.random_uniform_initializer(minval=0.999e-4, maxval=1.001e-4)

        if amplitude_mask:
            mask = optics.amplitude_map_element(
                [1, height, width, 1], r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')
        elif zernike:
            # Resample the stored mode stack to the input grid.
            zernike_modes = np.load(zernike_file)[:n_modes, :, :]
            zernike_modes = tf.expand_dims(zernike_modes, -1)
            zernike_modes = tf.image.resize_images(zernike_modes, size=(height, width))
            zernike_modes = tf.squeeze(zernike_modes, -1)
            mask = optics.zernike_element(zernike_modes, 'zernike_element',
                                          wavelength, n, r_NA,
                                          zernike_initializer=initializer)
        elif fourier:
            mask = optics.fourier_element(
                [1, height, width, 1], 'fourier_element',
                wave_lengths=wavelength,
                refractive_index=n,
                frequency_range=freq_range,
                height_map_regularizer=None,
                height_tolerance=None,  # None overrides optics' default (noted as 2 nm).
            )
        else:
            mask = optics.height_map_element(
                [1, height, width, 1],
                wave_lengths=wavelength,
                height_map_initializer=initializer,
                name='phase_mask_height',
                refractive_index=n)

        # Amplitude transfer function (pupil).
        atf = tf.ones([1, height, width, 1])
        if zernike:
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        if binarymask:
            # Fixed binary amplitude mask, broadcast as [1, H, W, 1].
            binary_mask = tf.convert_to_tensor(binary_mask_np, dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            # The 'atf' tag is kept for TensorBoard continuity; it shows the
            # binary mask, not the full pupil.
            optics.attach_img('atf', tf.abs(binary_mask))

        psfc = centered_ifft2(atf)
        psf = optics.Sensor(input_is_intensities=False, resolution=(height, width))(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        if coherent:
            field = tf.cast(input_field, tf.complex128)
            field = atf * centered_fft2(field)
            # Epsilon keeps log() finite where the pupil is exactly zero
            # (e.g. outside the circular aperture); log(0) = -inf would
            # corrupt the image summary's normalization.
            tf.summary.image('field', tf.log(tf.square(tf.abs(field)) + 1e-12))
            output_img = centered_ifft2(field)
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(height, width))(output_img)
            # TODO: this branch filters in the Fourier domain without the
            # zero-padding used below, so results can wrap around the edges.
        else:
            # Zero-pad by half the grid height to avoid wrap-around in the
            # FFT convolution, then crop back.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            padamt = int(dims[1] / 2)
            output_img = tf.abs(optics.fft_conv2d(fftpad(input_field, padamt),
                                                  fftpad_psf(psf, padamt),
                                                  adjoint=False))
            output_img = fftunpad(output_img, padamt)

        if activation is not None:
            output_img = activation(output_img)
        return output_img