def _build_graph(self, x_train, global_step, hm_reg_scale,
                     height_map_noise):
        '''
        Build the image-formation graph for a learned height-map optic.

        A planar wave is phase-shifted by a learnable height map, clipped by a
        circular aperture and propagated to the sensor with
        optics.propagate_fresnel. The intensity at the sensor is the PSF; it is
        area-downsampled to ``self.patch_size``, normalized to sum to 1 over
        the spatial axes (1, 2), and convolved with the input image.

        :param x_train (graph node): input image.
        :param global_step: Global step variable (unused in this model)
        :param hm_reg_scale: Regularization coefficient for laplace l1 regularizer of phaseplate
        :param height_map_noise: Noise added to height map to account for manufacturing deficiencies
        :return: graph node for output image
        '''
        n_channels = len(self.wave_lengths)

        with tf.device('/device:GPU:0'):
            # Unit-amplitude planar wave, one channel per wavelength.
            plane_wave = tf.ones((1, self.wave_res[0], self.wave_res[1],
                                  n_channels))

            # Phase-shift the wave with the learnable height map, then clip it
            # to the circular aperture.
            hm_regularizer = optics.laplace_l1_regularizer(hm_reg_scale)
            wavefront = optics.circular_aperture(
                optics.height_map_element(
                    plane_wave,
                    wave_lengths=self.wave_lengths,
                    height_map_regularizer=hm_regularizer,
                    height_map_initializer=None,
                    height_tolerance=height_map_noise,
                    refractive_idcs=self.refractive_idcs,
                    name='height_map_optics'))

            # Carry the field from the aperture plane to the sensor plane.
            sensor_field = optics.propagate_fresnel(
                wavefront,
                distance=self.sensor_distance,
                sampling_interval=self.sample_interval,
                wave_lengths=self.wave_lengths)

            # PSF = sensor intensity, downsampled to patch resolution and
            # normalized so it sums to 1 over axes 1 and 2.
            psf = optics.area_downsampling_tf(
                optics.get_intensities(sensor_field), self.patch_size)
            psf_total = tf.reduce_sum(psf, axis=[1, 2], keepdims=True)
            psf = tf.div(psf, psf_total)
            optics.attach_summaries('PSF', psf, image=True, log_image=True)

            # Permute the PSF axes ([1, 2, 0, 3]) into the kernel layout that
            # img_psf_conv expects, then convolve it with the input image.
            kernel = tf.transpose(psf, [1, 2, 0, 3])
            blurred = tf.cast(optics.img_psf_conv(x_train, kernel),
                              tf.float32)
            optics.attach_summaries('output_image',
                                    blurred,
                                    image=True,
                                    log_image=False)

            # NOTE(review): this adds ONE random scalar, drawn uniformly from
            # [0.001, 0.02), to every pixel. That is a global brightness
            # offset, not per-pixel noise. If sensor noise was intended, the
            # value probably should be a noise sigma instead; confirm.
            offset = tf.random_uniform(minval=0.001,
                                       maxval=0.02,
                                       shape=[])
            blurred += offset

            return blurred
예제 #2
0
 def forward_model(input_field):
     """Send an input field through the height-map optic and propagate it.

     The field is phase-modulated by a learnable height map, clipped to a
     circular aperture, and passed to optics.propagate_fresnel.

     NOTE(review): ``self``, ``hm_reg_scale`` and ``height_map_noise`` are
     free variables here. Presumably they come from an enclosing method that
     is not visible in this snippet; confirm.

     :param input_field: field entering the optic.
     :return: the propagated field.
     """
     regularizer = optics.laplace_l1_regularizer(hm_reg_scale)
     modulated = optics.height_map_element(
         input_field,
         wave_lengths=self.wave_lengths,
         height_map_regularizer=regularizer,
         height_map_initializer=None,
         height_tolerance=height_map_noise,
         refractive_idcs=self.refractive_idcs,
         name='height_map_optics')
     # NOTE(review): other call sites pass ``sampling_interval=`` to
     # propagate_fresnel. Confirm that this optics version accepts
     # ``input_sample_interval``.
     return optics.propagate_fresnel(
         optics.circular_aperture(modulated),
         distance=self.distance,
         input_sample_interval=self.input_sample_interval,
         wave_lengths=self.wave_lengths)
예제 #3
0
    def _optical_conv_layer(self, input_field, hm_reg_scale, activation=None, coherent=False,
                            name='optical_conv'):
        """Apply a simulated optical convolution with a learnable phase mask.

        A unit pupil is clipped to a circular aperture (``max_val=self.r_NA``)
        and modulated by a learnable height-map phase mask. In coherent mode,
        the zero-centered spectrum of the input is filtered by that aperture
        and mask and then inverse-transformed. Otherwise the input is
        convolved with the sensed PSF of the pupil, normalized to unit sum.

        Args:
            input_field: input tensor; it is cast to complex128 before use.
            hm_reg_scale: currently unused (no height-map regularizer is
                attached to the mask).
            activation: optional callable applied to the output.
            coherent: if True, the output is the inverse transform of the
                filtered spectrum, returned without a Sensor or abs step.
            name: variable scope that wraps the layer.

        Returns:
            The output tensor, passed through ``activation`` if one is given.
        """
        def centered(transform, tensor):
            # Zero-centered transform: fftshift(transform(ifftshift(tensor))).
            return optics.fftshift2d_tf(
                transform(optics.ifftshift2d_tf(tensor)))

        with tf.variable_scope(name):
            resolution = self.wave_resolution
            rows, cols = resolution[0], resolution[1]

            input_field = tf.cast(input_field, tf.complex128)
            spectrum = centered(optics.transp_fft2d, input_field)

            # Learnable phase mask with a near-constant initial height (~1e-4).
            init = tf.random_uniform_initializer(minval=0.999e-4,
                                                 maxval=1.001e-4)
            phase_mask = optics.height_map_element(
                [1, rows, cols, 1],
                wave_lengths=self.wavelength,
                height_map_initializer=init,
                name='phase_mask_height',
                refractive_index=self.n)

            # Pupil: ones clipped to the aperture, then phase-modulated.
            pupil = optics.circular_aperture(tf.ones([1, rows, cols, 1]),
                                             max_val=self.r_NA)
            pupil = phase_mask(pupil)

            # PSF: sensed intensity of the pupil's centered inverse
            # transform, normalized to unit sum (conservation of energy).
            amplitude_psf = centered(optics.transp_ifft2d, pupil)
            psf = optics.Sensor(input_is_intensities=False,
                                resolution=resolution)(amplitude_psf)
            psf = tf.cast(psf / tf.reduce_sum(psf), tf.float32)
            optics.attach_img('psf', psf)

            if coherent:
                # Filter the input spectrum by aperture + mask directly.
                spectrum = phase_mask(
                    optics.circular_aperture(spectrum, max_val=self.r_NA))
                tf.summary.image('field', tf.square(tf.abs(spectrum)))
                output_img = centered(optics.transp_ifft2d, spectrum)
            else:
                # Squeeze the PSF to its non-singleton axes, append two
                # singleton axes to form a conv kernel, and convolve.
                kernel = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
                output_img = tf.abs(optics.fft_conv2d(input_field, kernel))

            if activation is None:
                return output_img
            return activation(output_img)
예제 #4
0
    def _build_graph(self,
                     x_train,
                     hm_reg_scale,
                     hm_init_type='random_normal',
                     coherent=False):
        '''
        Build the optical imaging graph: input image -> phase mask -> sensor.

        A unit pupil is clipped to a circular aperture (``max_val=self.r_NA``)
        and modulated by a learnable height-map phase mask; this gives the ATF.
        Its zero-centered inverse transform, measured by a Sensor and
        normalized to unit sum, is the PSF. By default the output is the
        incoherent convolution of the unit-sum input with that PSF. With
        ``coherent=True``, the input spectrum is filtered by aperture + mask
        instead.

        :param x_train: input image tensor; logged as the 'input_image' summary.
        :param hm_reg_scale: coefficient of the Laplace-L1 regularizer on the
            height map.
        :param hm_init_type: NOTE(review): currently ignored. The height map
            is always initialized uniformly in [0.999e-4, 1.001e-4].
        :param coherent: select the coherent imaging path. This was previously
            hard-coded to False inside the method, and the default keeps that
            behavior.
        :return: float32 output image, normalized to sum to 1.
        '''
        with tf.device('/device:GPU:0'):
            sensordims = (self.dim, self.dim)
            # Start with the input image, normalized to unit sum
            input_img = x_train / tf.reduce_sum(x_train)
            tf.summary.image('input_image', x_train)

            # Build a phase mask, zero-centered, with a near-constant
            # initial height (~1e-4).
            height_map_initializer = tf.random_uniform_initializer(
                minval=0.999e-4, maxval=1.001e-4)
            pm = optics.height_map_element(
                [1, self.wave_resolution[0], self.wave_resolution[1], 1],
                wave_lengths=self.wave_length,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale),
                height_map_initializer=height_map_initializer,
                name='phase_mask_height',
                refractive_index=self.n)

            # Get ATF and PSF
            otf = tf.ones(
                [1, self.wave_resolution[0], self.wave_resolution[1], 1])
            otf = optics.circular_aperture(otf, max_val=self.r_NA)
            otf = pm(otf)
            psf = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
            psf = optics.Sensor(input_is_intensities=False,
                                resolution=sensordims)(psf)
            psf /= tf.reduce_sum(psf)  # sum or max?
            psf = tf.cast(psf, tf.float32)
            optics.attach_img('recon_psf', psf)

            # Get the output image
            if coherent:
                # Zero-centered spectrum of the input. Only this path needs it,
                # so it is built here rather than unconditionally.
                # NOTE(review): the complex128 cast mirrors the other optical
                # conv layers. This path was previously unreachable, so
                # confirm the dtypes against `optics`.
                field = optics.fftshift2d_tf(
                    optics.transp_fft2d(optics.ifftshift2d_tf(
                        tf.cast(input_img, tf.complex128))))
                field = optics.circular_aperture(field, max_val=self.r_NA)
                field = pm(field)
                tf.summary.image('field', tf.square(tf.abs(field)))
                field = optics.fftshift2d_tf(
                    optics.transp_ifft2d(optics.ifftshift2d_tf(field)))

                output_img = optics.Sensor(input_is_intensities=False,
                                           resolution=sensordims)(field)
            else:
                # Squeeze the PSF, append two singleton axes to form a conv
                # kernel, convolve, then sense the resulting intensities.
                psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
                output_img = tf.abs(optics.fft_conv2d(input_img, psf))
                output_img = optics.Sensor(input_is_intensities=True,
                                           resolution=sensordims)(output_img)

            output_img /= tf.reduce_sum(output_img)  # sum or max?
            output_img = tf.cast(output_img, tf.float32)

            # Attach images to summary
            tf.summary.image('output_image', output_img)
            return output_img
예제 #5
0
def optical_conv_layer(input_field,
                       hm_reg_scale,
                       r_NA,
                       n=1.48,
                       wavelength=532e-9,
                       activation=None,
                       coherent=False,
                       amplitude_mask=False,
                       zernike=False,
                       fourier=False,
                       binarymask=False,
                       n_modes=1024,
                       freq_range=.5,
                       initializer=None,
                       zernike_file='zernike_volume_256.npy',
                       binary_mask_np=None,
                       name='optical_conv'):
    """Simulated optical convolution layer with a learnable Fourier-plane mask.

    A mask (amplitude, Zernike, Fourier-parameterized or free height-map
    phase) is applied to a unit pupil to form the ATF. The PSF is the sensed
    intensity of the ATF's zero-centered inverse transform, normalized to unit
    sum. The input is then either filtered coherently by the ATF or convolved
    incoherently with the PSF.

    Args:
        input_field: 4-D input tensor. Axes 1 and 2 set the mask/PSF size
            (assumed [batch, height, width, channels]; TODO confirm).
        hm_reg_scale: accepted but currently unused, because the height-map
            regularizer is not applied.
        r_NA: aperture radius; passed as ``max_val`` to circular_aperture
            and to the amplitude and Zernike elements.
        n: refractive index passed to the phase elements.
        wavelength: wavelength passed to the phase elements (the 532e-9
            default suggests metres).
        activation: optional callable applied to the output.
        coherent: if True, filter the input spectrum by the ATF; otherwise
            convolve the input with the PSF.
        amplitude_mask: learn an amplitude mask instead of a phase mask.
        zernike: parameterize the phase mask by ``n_modes`` modes loaded from
            ``zernike_file``. In this mode the pupil is also clipped to the
            circular aperture.
        fourier: use the Fourier-parameterized phase element, limited to
            ``freq_range``.
        binarymask: multiply the ATF by ``binary_mask_np``.
        n_modes: number of Zernike modes taken from axis 0 of the file.
        freq_range: frequency range for the Fourier element.
        initializer: mask initializer. Defaults to uniform in
            [0.999e-4, 1.001e-4]. The Fourier element does not use it.
        zernike_file: path to a ``.npy`` array of Zernike modes.
        binary_mask_np: 2-D array, required when ``binarymask`` is True.
        name: variable scope that wraps the layer.

    Returns:
        The output image tensor, passed through ``activation`` if one is given.

    Raises:
        ValueError: if ``binarymask`` is True but ``binary_mask_np`` is None.
    """
    # Fail early with a clear message instead of deep inside
    # tf.convert_to_tensor(None) after the mask variables were created.
    if binarymask and binary_mask_np is None:
        raise ValueError(
            'binary_mask_np must be provided when binarymask=True')

    dims = input_field.get_shape().as_list()
    with tf.variable_scope(name):

        if initializer is None:
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)

        if amplitude_mask:
            # Build an amplitude mask, zero-centered
            mask = optics.amplitude_map_element(
                [1, dims[1], dims[2], 1],
                r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')

        else:
            # Build a phase mask, zero-centered
            if zernike:
                # Take the first n_modes modes and resize each to the input's
                # spatial size. A channel axis is added for resize_images and
                # then removed.
                zernike_modes = np.load(zernike_file)[:n_modes, :, :]
                zernike_modes = tf.expand_dims(zernike_modes, -1)
                zernike_modes = tf.image.resize_images(zernike_modes,
                                                       size=(dims[1], dims[2]))
                zernike_modes = tf.squeeze(zernike_modes, -1)
                mask = optics.zernike_element(zernike_modes,
                                              'zernike_element',
                                              wavelength,
                                              n,
                                              r_NA,
                                              zernike_initializer=initializer)
            elif fourier:
                mask = optics.fourier_element(
                    [1, dims[1], dims[2], 1],
                    'fourier_element',
                    wave_lengths=wavelength,
                    refractive_index=n,
                    frequency_range=freq_range,
                    height_map_regularizer=None,
                    height_tolerance=None,  # Default height tolerance is 2 nm.
                )

            else:
                # NOTE(review): the Laplace-L1 regularizer
                # optics.laplace_l1_regularizer(hm_reg_scale) is not applied
                # here, so hm_reg_scale has no effect.
                mask = optics.height_map_element(
                    [1, dims[1], dims[2], 1],
                    wave_lengths=wavelength,
                    height_map_initializer=initializer,
                    name='phase_mask_height',
                    refractive_index=n)

        # Get ATF and PSF. Only the Zernike path clips the pupil to the
        # circular aperture before applying the mask.
        atf = tf.ones([1, dims[1], dims[2], 1])
        if zernike:
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        # Apply an additional binary amplitude mask, reshaped to [1, dim, dim, 1]
        if binarymask:
            binary_mask = tf.convert_to_tensor(binary_mask_np,
                                               dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            # NOTE(review): the summary is tagged 'atf', but it logs only the
            # binary mask.
            optics.attach_img('atf', tf.abs(binary_mask))

        # PSF: sensed intensity of the ATF's zero-centered inverse transform.
        psfc = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psf = optics.Sensor(input_is_intensities=False,
                            resolution=(dims[1], dims[2]))(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        # Get the output image
        if coherent:
            input_field = tf.cast(input_field, tf.complex128)
            # Zero-centered fft2 of the input field, filtered by the ATF
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))
            field = atf * field
            # Log-intensity summary; entries where the field is exactly zero
            # become -inf.
            tf.summary.image('field', tf.log(tf.square(tf.abs(field))))
            output_img = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(dims[1],
                                                   dims[2]))(output_img)
            # NOTE(review): unlike the incoherent path, no fftpad/fftunpad is
            # applied here, so the filtering is presumably circular. Confirm
            # this is intended.

        else:
            # Incoherent: pad input and PSF by half the height, convolve in
            # the Fourier domain, take the magnitude, then crop the padding.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            padamt = int(dims[1] / 2)
            output_img = tf.abs(
                optics.fft_conv2d(fftpad(input_field, padamt),
                                  fftpad_psf(psf, padamt),
                                  adjoint=False))
            output_img = fftunpad(output_img, padamt)

        # Apply nonlinear activation
        if activation is not None:
            output_img = activation(output_img)

        return output_img