コード例 #1
0
    def _build_graph(self, x_train, global_step, hm_reg_scale,
                     height_map_noise):
        '''
        Build the forward optical model and return the simulated image.

        A planar wave (all-ones field, one channel per wavelength) passes
        through a learnable height-map phase plate and a circular aperture.
        It is then Fresnel-propagated to the sensor. The intensity of that
        field is the PSF, which is downsampled with
        ``optics.area_downsampling_tf`` to ``self.patch_size``, normalized to
        sum to 1 over axes 1 and 2, and convolved with the input image.

        :param x_train (graph node): input image.
        :param global_step: global step variable; not used by this method.
        :param hm_reg_scale: coefficient of the Laplace-L1 regularizer applied
            to the phase-plate height map.
        :param height_map_noise: height tolerance given to the height-map
            element, modelling manufacturing inaccuracies.
        :return: graph node for the output image (cast to float32, plus a
            random scalar offset).
        '''
        with tf.device('/device:GPU:0'):
            # Planar incident wave: unit amplitude, one channel per wavelength.
            field_shape = (1, self.wave_res[0], self.wave_res[1],
                           len(self.wave_lengths))
            incident = tf.ones(field_shape)

            # Phase plate shifts the phase; the circular aperture then clips it.
            laplace_reg = optics.laplace_l1_regularizer(hm_reg_scale)
            plate_out = optics.height_map_element(
                incident,
                wave_lengths=self.wave_lengths,
                height_map_regularizer=laplace_reg,
                height_map_initializer=None,
                height_tolerance=height_map_noise,
                refractive_idcs=self.refractive_idcs,
                name='height_map_optics')
            aperture_out = optics.circular_aperture(plate_out)

            # Free-space propagation from the aperture plane to the sensor.
            sensor_field = optics.propagate_fresnel(
                aperture_out,
                distance=self.sensor_distance,
                sampling_interval=self.sample_interval,
                wave_lengths=self.wave_lengths)

            # PSF = sensor-plane intensity, downsampled to the image
            # resolution, then divided by its sum over axes 1 and 2.
            psf = optics.area_downsampling_tf(
                optics.get_intensities(sensor_field), self.patch_size)
            psf_energy = tf.reduce_sum(psf, axis=[1, 2], keepdims=True)
            psf = tf.div(psf, psf_energy)
            optics.attach_summaries('PSF', psf, image=True, log_image=True)

            # Move axis 0 to position 2 to form the convolution kernel
            # (presumably the (H, W, in, out) layout img_psf_conv expects).
            conv_kernel = tf.transpose(psf, [1, 2, 0, 3])
            blurred = tf.cast(optics.img_psf_conv(x_train, conv_kernel),
                              tf.float32)
            optics.attach_summaries('output_image',
                                    blurred,
                                    image=True,
                                    log_image=False)

            # NOTE(review): this draws ONE scalar and adds it to every pixel,
            # i.e. a random brightness offset rather than per-pixel noise.
            # The same [0.001, 0.02] range is used elsewhere as a noise
            # standard deviation -- confirm an offset is what is intended.
            offset = tf.random_uniform(minval=0.001,
                                       maxval=0.02,
                                       shape=[])
            return blurred + offset
コード例 #2
0
ファイル: optics_unet_1.py プロジェクト: ZERO2ER0/deepoptics
    def _build_graph(self, x_train, hm_reg_scale, init_gamma, height_map_noise):
        """Simulate a depth-dependent sensor image and reconstruct it with a U-Net.

        :param x_train: pair ``(input_img, depth_map)`` of graph nodes.
        :param hm_reg_scale: coefficient of the Laplace-L1 regularizer on the
            Fourier-parameterized height map.
        :param init_gamma: accepted for interface compatibility; not used here.
        :param height_map_noise: height tolerance forwarded to the lens setup
            (the ``height_tolerance`` argument of ``optics.SingleLensSetup``).
        :return: graph node produced by ``net.U_Net().build`` on the sensor image.
        """
        scene, depth = x_train

        with tf.device('/device:GPU:0'):
            # Height map parameterized in the Fourier domain. The 0.75 is
            # passed straight through; its meaning is defined by
            # optics.get_fourier_height_map.
            lens_profile = optics.get_fourier_height_map(
                self.wave_resolution[0],
                0.75,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale))

            lens = optics.SingleLensSetup(
                height_map=lens_profile,
                wave_resolution=self.wave_resolution,
                wave_lengths=self.wave_lengths,
                sensor_distance=self.sensor_distance,
                sensor_resolution=(self.patch_size, self.patch_size),
                input_sample_interval=self.sampling_interval,
                refractive_idcs=self.refractive_idcs,
                height_tolerance=height_map_noise,
                use_planar_incidence=False,
                depth_bins=self.depth_bins,
                upsample=False,
                psf_resolution=self.wave_resolution,
                target_distance=None)

            # Noise-free, depth-dependent rendering of the scene on the sensor.
            measurement = lens.get_sensor_img(input_img=scene,
                                              noise_sigma=None,
                                              depth_dependent=True,
                                              depth_map=depth)

            # Learned reconstruction from the sensor measurement.
            reconstruction = net.U_Net().build(measurement)

            optics.attach_summaries('output_image', reconstruction,
                                    image=True, log_image=False)

            return reconstruction
コード例 #3
0
ファイル: aedof_diffractive.py プロジェクト: jwgu/deepoptics
    def _build_graph(self, x_train, global_step, hm_reg_scale, init_gamma, height_map_noise, learned_target_depth, hm_init_type='random_normal'):
        """Simulate a noisy depth-dependent sensor image and deconvolve it.

        A single-lens system with a Fourier-parameterized height map renders
        the scene. The target distance is ``square`` of a trainable scalar
        (initialized to 1.0). Gaussian-style sensor noise with a random sigma
        drawn from [0.001, 0.02) is passed to ``get_sensor_img``. The result is
        then restored with ``deconv.inverse_filter`` using the system's
        ``target_psf``.

        :param x_train: pair ``(input_img, depth_map)`` of graph nodes.
        :param global_step: not used by this method.
        :param hm_reg_scale: coefficient of the Laplace-L1 height-map regularizer.
        :param init_gamma: initial regularization value passed to
            ``deconv.inverse_filter``.
        :param height_map_noise: height tolerance forwarded to the lens setup.
        :param learned_target_depth: not read here; the target depth variable
            is always created with ``trainable=True``.
        :param hm_init_type: not used by this method.
        :return: float32 graph node of the deconvolved image, cropped back to
            the sensor image size.
        """
        input_img, depth_map = x_train

        with tf.device('/device:GPU:0'):
            with tf.variable_scope("optics"):
                height_map = optics.get_fourier_height_map(self.wave_resolution[0],
                                                           0.625,
                                                           height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale))

                target_depth_initializer = tf.constant_initializer(1.)
                # NOTE(review): `learned_target_depth` is not consulted; the
                # variable is always trainable. Confirm whether it should
                # drive `trainable=`.
                target_depth = tf.get_variable(name="target_depth",
                                               shape=(),
                                               dtype=tf.float32,
                                               trainable=True,
                                               initializer=target_depth_initializer)
                # Squaring keeps the effective target distance non-negative.
                target_depth = tf.square(target_depth)
                tf.summary.scalar('target_depth', target_depth)

                optical_system = optics.SingleLensSetup(height_map=height_map,
                                                        wave_resolution=self.wave_resolution,
                                                        wave_lengths=self.wave_lengths,
                                                        sensor_distance=self.distance,
                                                        sensor_resolution=(self.patch_size, self.patch_size),
                                                        input_sample_interval=self.input_sample_interval,
                                                        refractive_idcs=self.refractive_idcs,
                                                        height_tolerance=height_map_noise,
                                                        use_planar_incidence=False,
                                                        depth_bins=self.depth_bins,
                                                        upsample=False,
                                                        psf_resolution=self.wave_resolution,
                                                        target_distance=target_depth)

                # Random noise level per graph evaluation.
                noise_sigma = tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
                sensor_img = optical_system.get_sensor_img(input_img=input_img,
                                                           noise_sigma=noise_sigma,
                                                           depth_dependent=True,
                                                           depth_map=depth_map)
                output_image = tf.cast(sensor_img, tf.float32)

            # Now we deconvolve. Pad symmetrically by half the image height on
            # both spatial axes to reduce wrap-around artifacts of the
            # filtering. Then crop the padding off again.
            pad_width = output_image.shape.as_list()[1] // 2

            output_image = tf.pad(output_image,
                                  [[0, 0], [pad_width, pad_width], [pad_width, pad_width], [0, 0]],
                                  mode='SYMMETRIC')
            output_image = deconv.inverse_filter(output_image, output_image, optical_system.target_psf, init_gamma=init_gamma)

            # `p:-p` is an empty slice when p == 0 (since -0 == 0). An end of
            # None keeps the full extent in that case and is identical to
            # `p:-p` whenever p > 0.
            crop = slice(pad_width, -pad_width or None)
            output_image = output_image[:, crop, crop, :]

            optics.attach_summaries('output_image', output_image, image=True, log_image=False)

            return output_image
コード例 #4
0
 def forward_model(input_field):
     """Pass ``input_field`` through the phase plate, aperture and free space.

     NOTE(review): reads ``self``, ``hm_reg_scale`` and ``height_map_noise``
     from an enclosing scope that is not visible here. Presumably this is
     nested inside a model method -- confirm.

     :param input_field: incident field graph node.
     :return: field after the height-map element, the circular aperture and
         Fresnel propagation over ``self.distance``.
     """
     field = optics.height_map_element(
         input_field,
         wave_lengths=self.wave_lengths,
         height_map_regularizer=optics.laplace_l1_regularizer(
             hm_reg_scale),
         height_map_initializer=None,
         height_tolerance=height_map_noise,
         refractive_idcs=self.refractive_idcs,
         name='height_map_optics')
     field = optics.circular_aperture(field)
     # The grid spacing is passed as `sampling_interval`, the keyword used by
     # the other propagate_fresnel call site. The previous
     # `input_sample_interval=` keyword disagreed with it; verify against
     # optics.propagate_fresnel's signature.
     field = optics.propagate_fresnel(
         field,
         distance=self.distance,
         sampling_interval=self.input_sample_interval,
         wave_lengths=self.wave_lengths)
     return field
コード例 #5
0
    def _build_graph(self,
                     x_train,
                     hm_reg_scale,
                     hm_init_type='random_normal',
                     coherent=False):
        """Simulate imaging through a learnable phase mask in the pupil plane.

        The input is normalized to unit sum. The OTF is an all-ones array
        clipped by a circular aperture of radius ``self.r_NA`` and then passed
        through a height-map phase mask. The PSF is its centered inverse FFT,
        read out by ``optics.Sensor`` at ``(self.dim, self.dim)`` and
        normalized to unit sum.

        :param x_train: input image graph node.
        :param hm_reg_scale: coefficient of the Laplace-L1 regularizer on the
            phase-mask height map.
        :param hm_init_type: accepted for interface compatibility; not read.
            The mask is always initialized with
            ``tf.random_uniform_initializer(0.999e-4, 1.001e-4)``.
        :param coherent: if True, the centered FFT of the input is masked,
            inverse-transformed and read out as a field. If False (default,
            the previously hard-coded behavior), the input is convolved with
            the PSF via ``optics.fft_conv2d``.
        :return: float32 graph node of the output image, normalized to unit sum.
        """
        with tf.device('/device:GPU:0'):
            sensordims = (self.dim, self.dim)
            # Start with input image
            input_img = x_train / tf.reduce_sum(x_train)
            tf.summary.image('input_image', x_train)

            # fftshift(fft2(ifftshift( FIELD ))), zero-centered.
            # Built unconditionally (even when not coherent) so the graph, and
            # therefore any seed derived from op count, matches the
            # original construction order.
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_img)))

            # Build a phase mask, zero-centered
            height_map_initializer = tf.random_uniform_initializer(
                minval=0.999e-4, maxval=1.001e-4)
            pm = optics.height_map_element(
                [1, self.wave_resolution[0], self.wave_resolution[1], 1],
                wave_lengths=self.wave_length,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale),
                height_map_initializer=height_map_initializer,
                name='phase_mask_height',
                refractive_index=self.n)

            # Get ATF and PSF
            otf = tf.ones(
                [1, self.wave_resolution[0], self.wave_resolution[1], 1])
            otf = optics.circular_aperture(otf, max_val=self.r_NA)
            otf = pm(otf)
            psf = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
            psf = optics.Sensor(input_is_intensities=False,
                                resolution=sensordims)(psf)
            psf /= tf.reduce_sum(psf)  # NOTE(review): sum or max normalization?
            psf = tf.cast(psf, tf.float32)
            optics.attach_img('recon_psf', psf)

            # Get the output image
            if coherent:
                field = optics.circular_aperture(field, max_val=self.r_NA)
                field = pm(field)
                tf.summary.image('field', tf.square(tf.abs(field)))
                field = optics.fftshift2d_tf(
                    optics.transp_ifft2d(optics.ifftshift2d_tf(field)))

                output_img = optics.Sensor(input_is_intensities=False,
                                           resolution=sensordims)(field)
            else:
                # Drop singleton axes, then append two so the PSF becomes a
                # rank-4 kernel for fft_conv2d (assumes the squeezed PSF is
                # 2-D -- TODO confirm).
                psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
                output_img = tf.abs(optics.fft_conv2d(input_img, psf))
                output_img = optics.Sensor(input_is_intensities=True,
                                           resolution=sensordims)(output_img)

            output_img /= tf.reduce_sum(output_img)  # NOTE(review): sum or max normalization?
            output_img = tf.cast(output_img, tf.float32)

            # Attach images to summary
            tf.summary.image('output_image', output_img)
            return output_img