def _build_graph(self, x_train, global_step, hm_reg_scale, height_map_noise):
        '''
        Assemble the optical forward model and return the simulated image.

        A field of ones is passed through a height-map element and a
        circular aperture, propagated to the sensor with Fresnel
        propagation, and turned into a PSF that is convolved with the input.

        :param x_train (graph node): input image that gets convolved with the PSF.
        :param global_step: Global step variable; not referenced by this method.
        :param hm_reg_scale: Coefficient handed to the laplace l1 regularizer of the height map.
        :param height_map_noise: Forwarded as ``height_tolerance`` of the height map element.
        :return: graph node for the float32 output image with a random offset added.
        '''
        with tf.device('/device:GPU:0'):
            # A field of ones stands in for the incoming planar wave, with one
            # channel per entry of self.wave_lengths.
            wave_shape = (1, self.wave_res[0], self.wave_res[1],
                          len(self.wave_lengths))
            plane_wave = tf.ones(wave_shape)

            # Phase plate followed by the aperture.
            regularizer = optics.laplace_l1_regularizer(hm_reg_scale)
            modulated = optics.height_map_element(
                plane_wave,
                wave_lengths=self.wave_lengths,
                height_map_regularizer=regularizer,
                height_map_initializer=None,
                height_tolerance=height_map_noise,
                refractive_idcs=self.refractive_idcs,
                name='height_map_optics')
            apertured = optics.circular_aperture(modulated)

            # Carry the field from the aperture plane to the sensor plane.
            at_sensor = optics.propagate_fresnel(
                apertured,
                distance=self.sensor_distance,
                sampling_interval=self.sample_interval,
                wave_lengths=self.wave_lengths)

            # PSF = intensity of the propagated field, downsampled to the
            # patch size and scaled so axes 1 and 2 sum to one.
            intensities = optics.get_intensities(at_sensor)
            coarse_psfs = optics.area_downsampling_tf(intensities,
                                                      self.patch_size)
            psf_energy = tf.reduce_sum(coarse_psfs, axis=[1, 2], keepdims=True)
            unit_psfs = tf.div(coarse_psfs, psf_energy)
            optics.attach_summaries('PSF', unit_psfs, image=True,
                                    log_image=True)

            # Move the leading axis behind the two spatial axes before the
            # convolution (presumably the kernel layout img_psf_conv expects
            # -- TODO confirm against optics.img_psf_conv).
            conv_kernels = tf.transpose(unit_psfs, perm=[1, 2, 0, 3])
            blurred = tf.cast(optics.img_psf_conv(x_train, conv_kernels),
                              tf.float32)
            optics.attach_summaries('output_image', blurred,
                                    image=True, log_image=False)

            # One scalar drawn uniformly between 0.001 and 0.02 is added to
            # every pixel of the result.
            offset = tf.random_uniform(shape=[], minval=0.001, maxval=0.02)
            return blurred + offset
# Example #2
0
 def forward_model(input_field):
     '''
     Send a field through the height map, the aperture and Fresnel propagation.

     Reads ``self``, ``hm_reg_scale`` and ``height_map_noise`` from the
     enclosing scope.

     :param input_field: field handed to the height map element.
     :return: the field after propagation over ``self.distance``.
     '''
     # Phase plate: height map with a laplace l1 regularizer.
     regularizer = optics.laplace_l1_regularizer(hm_reg_scale)
     modulated = optics.height_map_element(
         input_field,
         wave_lengths=self.wave_lengths,
         height_map_regularizer=regularizer,
         height_map_initializer=None,
         height_tolerance=height_map_noise,
         refractive_idcs=self.refractive_idcs,
         name='height_map_optics')

     # Aperture, then propagation to the target plane.
     apertured = optics.circular_aperture(modulated)
     return optics.propagate_fresnel(
         apertured,
         distance=self.distance,
         input_sample_interval=self.input_sample_interval,
         wave_lengths=self.wave_lengths)