def _build_graph(self, x_train, global_step, hm_reg_scale, height_map_noise):
    '''
    Builds the graph for this model.

    A planar wave passes through a learned height-map phase plate and a
    circular aperture, and is propagated to the sensor with
    ``optics.propagate_fresnel``. Its intensity, area-downsampled to
    ``self.patch_size`` and normalized, is the PSF. The PSF is convolved with
    the input image, and additive Gaussian sensor noise is applied.

    :param x_train (graph node): input image.
    :param global_step: Global step variable (unused in this model)
    :param hm_reg_scale: Regularization coefficient for laplace l1 regularizer of phaseplate
    :param height_map_noise: Noise added to height map to account for manufacturing deficiencies
    :return: graph node for output image (the noisy simulated sensor image)
    '''
    input_img = x_train

    with tf.device('/device:GPU:0'):
        # Input field is a planar wave: all ones, batch size 1, one channel per wavelength.
        input_field = tf.ones((1, self.wave_res[0], self.wave_res[1], len(self.wave_lengths)))

        # Planar wave hits aperture: phase is shifted by phaseplate
        field = optics.height_map_element(
            input_field,
            wave_lengths=self.wave_lengths,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            height_map_initializer=None,
            height_tolerance=height_map_noise,
            refractive_idcs=self.refractive_idcs,
            name='height_map_optics')
        field = optics.circular_aperture(field)

        # Propagate field from aperture to sensor
        field = optics.propagate_fresnel(
            field,
            distance=self.sensor_distance,
            sampling_interval=self.sample_interval,
            wave_lengths=self.wave_lengths)

        # The psf is the intensities of the propagated field.
        psfs = optics.get_intensities(field)

        # Downsample psf to image resolution & normalize so that each PSF
        # (per batch entry and channel) sums to 1 over spatial axes 1 and 2.
        psfs = optics.area_downsampling_tf(psfs, self.patch_size)
        psfs = tf.div(psfs, tf.reduce_sum(psfs, axis=[1, 2], keepdims=True))
        optics.attach_summaries('PSF', psfs, image=True, log_image=True)

        # Image formation: PSF is convolved with input image.
        # Move the (size-1) batch axis into position 2, presumably the
        # kernel layout img_psf_conv expects -- TODO confirm.
        psfs = tf.transpose(psfs, [1, 2, 0, 3])
        output_image = optics.img_psf_conv(input_img, psfs)
        output_image = tf.cast(output_image, tf.float32)
        optics.attach_summaries('output_image', output_image, image=True, log_image=False)

        # Sensor noise: per-pixel zero-mean Gaussian noise whose standard
        # deviation is drawn uniformly from [0.001, 0.02] for each evaluation.
        # Previously the uniform draw itself was added as a constant offset to
        # every pixel, which biases brightness instead of adding noise.
        noise_sigma = tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
        output_image += tf.random_normal(tf.shape(output_image), mean=0.0,
                                         stddev=noise_sigma, dtype=tf.float32)

        return output_image
def _build_graph(self, x_train, hm_reg_scale, init_gamma, height_map_noise):
    '''
    Simulate a depth-dependent sensor image through a single learned lens and
    reconstruct it with a U-Net.

    :param x_train: pair ``(input image, depth map)``.
    :param hm_reg_scale: Regularization coefficient for the laplace l1 regularizer of the height map
    :param init_gamma: Unused in this model (kept for a uniform _build_graph signature)
    :param height_map_noise: Height tolerance passed to the lens setup to model manufacturing deficiencies
    :return: graph node for the U-Net output image
    '''
    rgb_img, depth = x_train

    with tf.device('/device:GPU:0'):
        # Height map parameterized via optics.get_fourier_height_map (0.75 is
        # passed as its second argument), regularized with a laplace l1 penalty.
        height_map = optics.get_fourier_height_map(
            self.wave_resolution[0],
            0.75,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale))

        sensor_res = (self.patch_size, self.patch_size)
        lens = optics.SingleLensSetup(
            height_map=height_map,
            wave_resolution=self.wave_resolution,
            wave_lengths=self.wave_lengths,
            sensor_distance=self.sensor_distance,
            sensor_resolution=sensor_res,
            input_sample_interval=self.sampling_interval,
            refractive_idcs=self.refractive_idcs,
            height_tolerance=height_map_noise,
            use_planar_incidence=False,
            depth_bins=self.depth_bins,
            upsample=False,
            psf_resolution=self.wave_resolution,
            target_distance=None)

        # Depth-dependent image formation; noise_sigma is explicitly None here.
        blurred = lens.get_sensor_img(input_img=rgb_img,
                                      noise_sigma=None,
                                      depth_dependent=True,
                                      depth_map=depth)

        reconstruction = net.U_Net().build(blurred)
        optics.attach_summaries('output_image', reconstruction, image=True, log_image=False)
        return reconstruction
def _build_graph(self, x_train, global_step, hm_reg_scale, init_gamma, height_map_noise, learned_target_depth, hm_init_type='random_normal'):
    '''
    Simulate a depth-dependent sensor image through a learned lens focused at
    a learned target depth, then recover the image with an inverse filter.

    :param x_train: pair ``(input image, depth map)``.
    :param global_step: Unused in this model
    :param hm_reg_scale: Regularization coefficient for the laplace l1 regularizer of the height map
    :param init_gamma: Initial value forwarded to deconv.inverse_filter
    :param height_map_noise: Height tolerance passed to the lens setup to model manufacturing deficiencies
    :param learned_target_depth: Unused in this body -- NOTE(review): the
        target-depth variable below is always trainable regardless of this
        argument; confirm whether it was meant to control that.
    :param hm_init_type: Unused in this model
    :return: graph node for the deconvolved output image
    '''
    rgb_img, depth = x_train

    with tf.device('/device:GPU:0'):
        # NOTE(review): confirm which ops belong inside the "optics" variable
        # scope -- it determines the variable names stored in checkpoints.
        with tf.variable_scope("optics"):
            height_map = optics.get_fourier_height_map(
                self.wave_resolution[0],
                0.625,
                height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale))

            # Scalar trainable parameter, initialized to 1; it is squared so
            # the effective target depth is never negative.
            raw_depth = tf.get_variable(name="target_depth",
                                        shape=(),
                                        dtype=tf.float32,
                                        trainable=True,
                                        initializer=tf.constant_initializer(1.))
            focus_depth = tf.square(raw_depth)
            tf.summary.scalar('target_depth', focus_depth)

            lens = optics.SingleLensSetup(
                height_map=height_map,
                wave_resolution=self.wave_resolution,
                wave_lengths=self.wave_lengths,
                sensor_distance=self.distance,
                sensor_resolution=(self.patch_size, self.patch_size),
                input_sample_interval=self.input_sample_interval,
                refractive_idcs=self.refractive_idcs,
                height_tolerance=height_map_noise,
                use_planar_incidence=False,
                depth_bins=self.depth_bins,
                upsample=False,
                psf_resolution=self.wave_resolution,
                target_distance=focus_depth)

            # Noise level drawn uniformly from [0.001, 0.02] per evaluation.
            sigma = tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
            blurred = lens.get_sensor_img(input_img=rgb_img,
                                          noise_sigma=sigma,
                                          depth_dependent=True,
                                          depth_map=depth)
            blurred = tf.cast(blurred, tf.float32)

        # Now we deconvolve. Mirror-pad both spatial axes by half the image
        # height (axis 1), presumably to reduce wrap-around edge artifacts of
        # the frequency-domain filter, then crop the padding off again.
        margin = blurred.shape.as_list()[1] // 2
        padded = tf.pad(blurred,
                        [[0, 0], [margin, margin], [margin, margin], [0, 0]],
                        mode='SYMMETRIC')
        restored = deconv.inverse_filter(padded, padded, lens.target_psf, init_gamma=init_gamma)
        restored = restored[:, margin:-margin, margin:-margin, :]

        optics.attach_summaries('output_image', restored, image=True, log_image=False)
        return restored
def forward_model(input_field):
    '''
    Pass an optical field through the learned height-map element, a circular
    aperture, and Fresnel propagation over ``self.distance``.

    NOTE(review): ``self``, ``hm_reg_scale`` and ``height_map_noise`` are free
    variables here, so this function presumably lives nested inside a method
    where those names are bound -- TODO confirm.

    :param input_field: field fed into optics.height_map_element
    :return: the propagated field
    '''
    modulated = optics.height_map_element(
        input_field,
        wave_lengths=self.wave_lengths,
        height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
        height_map_initializer=None,
        height_tolerance=height_map_noise,
        refractive_idcs=self.refractive_idcs,
        name='height_map_optics')
    apertured = optics.circular_aperture(modulated)

    # NOTE(review): this passes `input_sample_interval=` whereas another
    # propagate_fresnel call in this file passes `sampling_interval=`; check
    # which keyword optics.propagate_fresnel actually accepts.
    return optics.propagate_fresnel(
        apertured,
        distance=self.distance,
        input_sample_interval=self.input_sample_interval,
        wave_lengths=self.wave_lengths)
def _build_graph(self, x_train, hm_reg_scale, hm_init_type='random_normal', coherent=False):
    '''
    Builds an imaging model with a learned phase mask in a circular pupil.

    A phase mask (height map) is applied inside a circular aperture of radius
    ``self.r_NA``. The centered inverse FFT of that pupil, read out by an
    ``optics.Sensor`` with ``input_is_intensities=False``, gives the PSF,
    which is normalized to sum to 1.

    :param x_train (graph node): input image batch.
    :param hm_reg_scale: Regularization coefficient for the laplace l1 regularizer of the phase mask
    :param hm_init_type: Unused in this model (kept for interface compatibility)
    :param coherent: If True, propagate the input's centered spectrum through
        the aperture and phase mask and sense the resulting field. If False
        (default, the original behavior), convolve the input with the PSF and
        sense it as intensities.
    :return: graph node for the output image, normalized to sum to 1
    '''
    with tf.device('/device:GPU:0'):
        sensordims = (self.dim, self.dim)

        # Start with input image, normalized by its total sum (taken over the
        # whole tensor, i.e. across the batch as well).
        input_img = x_train / tf.reduce_sum(x_train)
        tf.summary.image('input_image', x_train)

        # Build a phase mask, zero-centered; heights start in a narrow band
        # around 1e-4.
        height_map_initializer = tf.random_uniform_initializer(
            minval=0.999e-4, maxval=1.001e-4)
        pm = optics.height_map_element(
            [1, self.wave_resolution[0], self.wave_resolution[1], 1],
            wave_lengths=self.wave_length,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            height_map_initializer=height_map_initializer,
            name='phase_mask_height',
            refractive_index=self.n)

        # Get ATF and PSF: a uniform pupil, cropped by the aperture and
        # modulated by the phase mask, inverse-transformed to the image plane.
        otf = tf.ones([1, self.wave_resolution[0], self.wave_resolution[1], 1])
        otf = optics.circular_aperture(otf, max_val=self.r_NA)
        otf = pm(otf)
        psf = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
        psf = optics.Sensor(input_is_intensities=False, resolution=sensordims)(psf)
        psf /= tf.reduce_sum(psf)  # NOTE(review): normalize by sum or max?
        psf = tf.cast(psf, tf.float32)
        optics.attach_img('recon_psf', psf)

        # Get the output image
        if coherent:
            # fftshift(fft2(ifftshift( FIELD ))), zero-centered spectrum of
            # the input. Only built when needed; previously it was always
            # built although only this branch used it.
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_img)))
            field = optics.circular_aperture(field, max_val=self.r_NA)
            field = pm(field)
            tf.summary.image('field', tf.square(tf.abs(field)))
            field = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=sensordims)(field)
        else:
            # Squeeze away the size-1 axes, then append two trailing axes:
            # a 2-D kernel of shape (H, W, 1, 1) -- assumes the sensed PSF has
            # exactly two non-singleton axes; TODO confirm.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            output_img = tf.abs(optics.fft_conv2d(input_img, psf))
            output_img = optics.Sensor(input_is_intensities=True,
                                       resolution=sensordims)(output_img)

        output_img /= tf.reduce_sum(output_img)  # NOTE(review): normalize by sum or max?
        output_img = tf.cast(output_img, tf.float32)

        # Attach images to summary
        tf.summary.image('output_image', output_img)
        return output_img