def _build_graph(self, x_train, global_step, hm_reg_scale, height_map_noise):
    '''
    Builds the graph for this model.

    A planar wave passes through a height-map phase plate and a circular
    aperture, is propagated to the sensor with Fresnel propagation, and the
    intensity of the propagated field (the PSF) is convolved with the input
    image.

    :param x_train (graph node): input image.
    :param global_step: Global step variable (unused in this model)
    :param hm_reg_scale: Regularization coefficient for laplace l1 regularizer of phaseplate
    :param height_map_noise: Noise added to height map to account for manufacturing deficiencies
    :return: graph node for output image (cast to float32, plus a random scalar offset)
    '''
    input_img = x_train

    # NOTE(review): placement is hard-coded to the first GPU; on hosts without
    # a GPU this presumably relies on soft placement in the session config.
    with tf.device('/device:GPU:0'):
        # Input field is a planar wave: all-ones amplitude with shape
        # (1, wave_res[0], wave_res[1], number of wavelengths).
        input_field = tf.ones((1, self.wave_res[0], self.wave_res[1],
                               len(self.wave_lengths)))

        # Planar wave hits aperture: phase is shifted by phaseplate
        field = optics.height_map_element(
            input_field,
            wave_lengths=self.wave_lengths,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            height_map_initializer=None,
            height_tolerance=height_map_noise,
            refractive_idcs=self.refractive_idcs,
            name='height_map_optics')
        field = optics.circular_aperture(field)

        # Propagate field from aperture to sensor
        field = optics.propagate_fresnel(
            field,
            distance=self.sensor_distance,
            sampling_interval=self.sample_interval,
            wave_lengths=self.wave_lengths)

        # The psf is the intensities of the propagated field.
        psfs = optics.get_intensities(field)

        # Downsample psf to image resolution, then normalize so that each PSF
        # sums to 1 over the spatial axes 1 and 2.
        psfs = optics.area_downsampling_tf(psfs, self.patch_size)
        # tf.divide replaces the deprecated tf.div (same result for floats).
        psfs = tf.divide(psfs, tf.reduce_sum(psfs, axis=[1, 2], keepdims=True))
        optics.attach_summaries('PSF', psfs, image=True, log_image=True)

        # Image formation: PSF is convolved with input image.
        # Moves axis 0 to position 2, presumably (1, H, W, C) -> (H, W, 1, C)
        # as a convolution-kernel layout -- TODO confirm against img_psf_conv.
        psfs = tf.transpose(psfs, [1, 2, 0, 3])
        output_image = optics.img_psf_conv(input_img, psfs)
        output_image = tf.cast(output_image, tf.float32)
        # The summary is taken before the offset below is added.
        optics.attach_summaries('output_image', output_image, image=True, log_image=False)

        # Add one random scalar (shape=[]) drawn uniformly from [0.001, 0.02)
        # to every pixel of the output image.
        output_image += tf.random_uniform(minval=0.001, maxval=0.02, shape=[])

        return output_image
def forward_model(input_field):
    '''
    Applies the height-map phase plate, a circular aperture and Fresnel
    propagation to an input field.

    NOTE(review): `self`, `hm_reg_scale` and `height_map_noise` are not
    parameters of this function; it presumably is a closure defined inside a
    method that provides them -- confirm against the enclosing scope.

    :param input_field (graph node): field incident on the phase plate.
    :return: graph node for the field after propagation over self.distance.
    '''
    field = optics.height_map_element(
        input_field,
        wave_lengths=self.wave_lengths,
        height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
        height_map_initializer=None,
        height_tolerance=height_map_noise,
        refractive_idcs=self.refractive_idcs,
        name='height_map_optics')
    field = optics.circular_aperture(field)
    # The other propagate_fresnel call in this file passes the sampling
    # interval as `sampling_interval`; the previous `input_sample_interval`
    # keyword did not match it.
    field = optics.propagate_fresnel(
        field,
        distance=self.distance,
        sampling_interval=self.input_sample_interval,
        wave_lengths=self.wave_lengths)
    return field