예제 #1
0
    def _build_graph(self, x_train, hm_reg_scale, init_gamma, height_map_noise):
        """Simulate a depth-dependent single-lens camera, then reconstruct with a U-Net.

        :param x_train: pair ``(input_img, depth_map)``.
        :param hm_reg_scale: weight of the Laplace-L1 regularizer applied to the
            Fourier-parameterized height map.
        :param init_gamma: not used by this variant.
        :param height_map_noise: forwarded to ``SingleLensSetup`` as
            ``height_tolerance``.
        :return: graph node holding the U-Net output image.
        """
        input_img, depth_map = x_train

        with tf.device('/device:GPU:0'):
            hm_regularizer = optics.laplace_l1_regularizer(hm_reg_scale)
            height_map = optics.get_fourier_height_map(
                self.wave_resolution[0],
                0.75,
                height_map_regularizer=hm_regularizer)

            # Camera model: learned height map in front of a sensor, with PSFs
            # computed per depth bin (target_distance is left unset here).
            lens = optics.SingleLensSetup(
                height_map=height_map,
                wave_resolution=self.wave_resolution,
                wave_lengths=self.wave_lengths,
                sensor_distance=self.sensor_distance,
                sensor_resolution=(self.patch_size, self.patch_size),
                input_sample_interval=self.sampling_interval,
                refractive_idcs=self.refractive_idcs,
                height_tolerance=height_map_noise,
                use_planar_incidence=False,
                depth_bins=self.depth_bins,
                upsample=False,
                psf_resolution=self.wave_resolution,
                target_distance=None)

            # noise_sigma=None: presumably no sensor noise is simulated here
            # (confirm in optics.SingleLensSetup.get_sensor_img).
            sensor_img = lens.get_sensor_img(input_img=input_img,
                                             noise_sigma=None,
                                             depth_dependent=True,
                                             depth_map=depth_map)

            reconstruction = net.U_Net().build(sensor_img)

            optics.attach_summaries('output_image', reconstruction,
                                    image=True, log_image=False)
            return reconstruction
    def _build_graph(self, x_train, global_step, hm_reg_scale,
                     height_map_noise):
        """Build the forward optics model for this experiment.

        A unit-amplitude planar wave passes a learned height-map phase plate and
        a circular aperture. It is then Fresnel-propagated to the sensor. The
        resulting intensity is area-downsampled to ``self.patch_size``,
        normalized to unit sum over axes 1 and 2, and used as the PSF that is
        convolved with the input image.

        :param x_train: graph node of the input image.
        :param global_step: unused by this model.
        :param hm_reg_scale: coefficient of the Laplace-L1 regularizer on the
            phase-plate height map.
        :param height_map_noise: height tolerance passed to the height-map
            element (models manufacturing deviations).
        :return: graph node for the simulated sensor image.
        """
        rows, cols = self.wave_res[0], self.wave_res[1]
        n_wavelengths = len(self.wave_lengths)

        with tf.device('/device:GPU:0'):
            # Unit-amplitude planar wave, one channel per wavelength.
            planar_wave = tf.ones((1, rows, cols, n_wavelengths))

            # Phase shift from the learned plate, then the aperture stop.
            field = optics.height_map_element(
                planar_wave,
                wave_lengths=self.wave_lengths,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale),
                height_map_initializer=None,
                height_tolerance=height_map_noise,
                refractive_idcs=self.refractive_idcs,
                name='height_map_optics')
            field = optics.circular_aperture(field)

            # Free-space propagation from aperture to sensor.
            field = optics.propagate_fresnel(
                field,
                distance=self.sensor_distance,
                sampling_interval=self.sample_interval,
                wave_lengths=self.wave_lengths)

            # PSF = sensor-plane intensity, downsampled to image resolution
            # and normalized to sum to 1 over axes 1 and 2.
            intensity = optics.area_downsampling_tf(
                optics.get_intensities(field), self.patch_size)
            energy = tf.reduce_sum(intensity, axis=[1, 2], keepdims=True)
            psfs = tf.div(intensity, energy)
            optics.attach_summaries('PSF', psfs, image=True, log_image=True)

            # Move axis 0 to position 2 (presumably (1, H, W, C) -> (H, W, 1, C),
            # the kernel layout img_psf_conv expects; confirm).
            kernel = tf.transpose(psfs, [1, 2, 0, 3])
            sensor_img = tf.cast(optics.img_psf_conv(x_train, kernel),
                                 tf.float32)
            optics.attach_summaries('output_image',
                                    sensor_img,
                                    image=True,
                                    log_image=False)

            # NOTE(review): this adds ONE random scalar in [0.001, 0.02) to every
            # pixel (a uniform offset, not per-pixel noise), and it happens after
            # the summary above was recorded. Confirm this is intended.
            sensor_img += tf.random_uniform(minval=0.001,
                                            maxval=0.02,
                                            shape=[])
            return sensor_img
예제 #3
0
    def _get_data_loss(self, model_output, ground_truth, margin=10):
        """Mean squared error between output and ground truth, ignoring a border.

        Both tensors are cast to float32. ``margin`` pixels are dropped from
        each end of axes 1 and 2 before averaging. The same crop of
        ``model_output`` is logged as the 'output_image' summary.

        :param model_output: predicted image tensor (indexed as 4-D below).
        :param ground_truth: target tensor, same shape as ``model_output``.
        :param margin: number of pixels cropped from each side of axes 1 and 2.
            ``0`` disables cropping.
        :return: scalar float32 loss tensor.
        """
        model_output = tf.cast(model_output, tf.float32)
        ground_truth = tf.cast(ground_truth, tf.float32)

        # `x[margin:-margin]` collapses to the empty slice `x[0:0]` when
        # margin == 0, which made the loss the mean of an empty tensor (NaN).
        # Use an open end in that case so margin=0 means "no crop".
        crop = slice(margin, -margin if margin else None)

        loss = tf.reduce_mean(
            tf.square(model_output - ground_truth)[:, crop, crop, :])

        optics.attach_summaries('output_image',
                                model_output[:, crop, crop, :],
                                image=True,
                                log_image=False)
        return loss
예제 #4
0
    def _build_graph(self, x_train, hm_reg_scale, init_gamma,
                     height_map_noise):
        """Baseline without optics: feed the input image straight into a U-Net.

        :param x_train: pair ``(input_img, depth_map)``; only the image is used.
        :param hm_reg_scale: unused here.
        :param init_gamma: unused here.
        :param height_map_noise: unused here.
        :return: graph node holding the U-Net output image.
        """
        image, _ = x_train

        with tf.device('/device:GPU:0'):
            unet = net.U_Net()
            prediction = unet.build(image)
            optics.attach_summaries('output_image', prediction,
                                    image=True, log_image=False)

        return prediction
예제 #5
0
    def _build_graph(self, x_train, global_step, hm_reg_scale, init_gamma, height_map_noise, learned_target_depth, hm_init_type='random_normal'):
        """Depth-dependent single-lens camera with a learned focus distance, followed by inverse-filter deconvolution.

        :param x_train: pair ``(input_img, depth_map)``.
        :param global_step: unused here.
        :param hm_reg_scale: weight of the Laplace-L1 height-map regularizer.
        :param init_gamma: initial regularization weight passed to
            ``deconv.inverse_filter``.
        :param height_map_noise: forwarded as ``height_tolerance``.
        :param learned_target_depth: NOTE(review): unused; the target depth
            variable is always created with ``trainable=True``.
        :param hm_init_type: unused here.
        :return: graph node for the deconvolved image.
        """
        input_img, depth_map = x_train

        with tf.device('/device:GPU:0'):
            with tf.variable_scope("optics"):
                height_map = optics.get_fourier_height_map(
                    self.wave_resolution[0],
                    0.625,
                    height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale))

                # Focus distance is the square of a trainable scalar
                # (initialized to 1), which keeps it non-negative.
                depth_param = tf.get_variable(name="target_depth",
                                              shape=(),
                                              dtype=tf.float32,
                                              trainable=True,
                                              initializer=tf.constant_initializer(1.))
                target_depth = tf.square(depth_param)
                tf.summary.scalar('target_depth', target_depth)

                optical_system = optics.SingleLensSetup(
                    height_map=height_map,
                    wave_resolution=self.wave_resolution,
                    wave_lengths=self.wave_lengths,
                    sensor_distance=self.distance,
                    sensor_resolution=(self.patch_size, self.patch_size),
                    input_sample_interval=self.input_sample_interval,
                    refractive_idcs=self.refractive_idcs,
                    height_tolerance=height_map_noise,
                    use_planar_incidence=False,
                    depth_bins=self.depth_bins,
                    upsample=False,
                    psf_resolution=self.wave_resolution,
                    target_distance=target_depth)

                # Noise level drawn once per graph evaluation from [0.001, 0.02).
                noise_sigma = tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
                sensor_img = optical_system.get_sensor_img(input_img=input_img,
                                                           noise_sigma=noise_sigma,
                                                           depth_dependent=True,
                                                           depth_map=depth_map)
                blurred = tf.cast(sensor_img, tf.float32)

            # Deconvolve: mirror-pad both spatial axes by half the image height,
            # apply the learned inverse filter for the target PSF, then crop back.
            pad = blurred.shape.as_list()[1] // 2
            paddings = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
            padded = tf.pad(blurred, paddings, mode='SYMMETRIC')
            restored = deconv.inverse_filter(padded, padded,
                                             optical_system.target_psf,
                                             init_gamma=init_gamma)
            output_image = restored[:, pad:-pad, pad:-pad, :]

            optics.attach_summaries('output_image', output_image, image=True, log_image=False)

            return output_image
예제 #6
0
def optical_conv_layer(input_field,
                       hm_reg_scale,
                       r_NA,
                       n=1.48,
                       wavelength=532e-9,
                       activation=None,
                       coherent=False,
                       amplitude_mask=False,
                       zernike=False,
                       fourier=False,
                       binarymask=False,
                       n_modes=1024,
                       freq_range=.5,
                       initializer=None,
                       zernike_file='zernike_volume_256.npy',
                       binary_mask_np=None,
                       name='optical_conv'):
    """Simulate a convolution layer implemented by a learnable Fourier-plane mask.

    The mask defines the amplitude transfer function ``atf``. It is one of an
    amplitude map, a Zernike phase, a Fourier-parameterized phase, or a
    free-form height map. The PSF is obtained from the centered inverse FFT of
    ``atf`` through ``optics.Sensor`` and normalized to unit sum.

    * Incoherent (default): the input is convolved with the PSF by FFT
      convolution on a zero-padded copy, then cropped back.
    * Coherent: the centered FFT of the input is multiplied by ``atf`` and
      transformed back, then passed through ``optics.Sensor``.

    Args:
        input_field: input tensor. Its static shape is read with
            ``get_shape().as_list()``, and axes 1 and 2 give the mask size
            (presumably NHWC; confirm).
        hm_reg_scale: currently unused (the height-map regularizer below is
            commented out).
        r_NA: passed to the amplitude and Zernike elements. In Zernike mode it
            is also the ``max_val`` of the circular aperture.
        n: refractive index of the mask material.
        wavelength: wavelength passed to the phase elements.
        activation: optional callable applied to the output.
        coherent: use the coherent imaging path described above.
        amplitude_mask, zernike, fourier: mask selection, checked in this
            order. If none is set, a height-map phase mask is used.
        binarymask: multiply ``atf`` by ``binary_mask_np``.
        n_modes: number of Zernike modes taken from ``zernike_file``.
        freq_range: ``frequency_range`` for the Fourier element.
        initializer: mask initializer. Defaults to uniform in
            [0.999e-4, 1.001e-4]. Not used by the Fourier element.
        zernike_file: ``.npy`` file of Zernike modes, indexed as
            (mode, H, W).
        binary_mask_np: 2-D array, required when ``binarymask`` is True.
        name: variable scope name.

    Returns:
        The output image tensor (after ``activation`` if given).

    Raises:
        ValueError: if ``binarymask`` is True but ``binary_mask_np`` is None.
    """
    # Fail fast with a clear message instead of an opaque conversion error
    # halfway through graph construction.
    if binarymask and binary_mask_np is None:
        raise ValueError('binary_mask_np must be provided when binarymask=True')

    dims = input_field.get_shape().as_list()
    with tf.variable_scope(name):

        if initializer is None:
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)

        if amplitude_mask:
            # Build an amplitude mask, zero-centered
            mask = optics.amplitude_map_element(
                [1, dims[1], dims[2], 1],
                r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')

        else:
            # Build a phase mask, zero-centered
            if zernike:
                # Resize every (H, W) mode to the input's spatial size; the
                # mode axis acts as the batch axis for resize_images.
                zernike_modes = np.load(zernike_file)[:n_modes, :, :]
                zernike_modes = tf.expand_dims(zernike_modes, -1)
                zernike_modes = tf.image.resize_images(zernike_modes,
                                                       size=(dims[1], dims[2]))
                zernike_modes = tf.squeeze(zernike_modes, -1)
                mask = optics.zernike_element(zernike_modes,
                                              'zernike_element',
                                              wavelength,
                                              n,
                                              r_NA,
                                              zernike_initializer=initializer)
            elif fourier:
                # NOTE: `initializer` is not forwarded to the Fourier element.
                mask = optics.fourier_element(
                    [1, dims[1], dims[2], 1],
                    'fourier_element',
                    wave_lengths=wavelength,
                    refractive_index=n,
                    frequency_range=freq_range,
                    height_map_regularizer=None,
                    height_tolerance=None,  # Default height tolerance is 2 nm.
                )

            else:
                # Free-form height map. The regularizer is disabled, which is
                # why `hm_reg_scale` is currently unused.
                mask = optics.height_map_element(
                    [1, dims[1], dims[2], 1],
                    wave_lengths=wavelength,
                    height_map_initializer=initializer,
                    #height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
                    name='phase_mask_height',
                    refractive_index=n)

        # Get ATF: all-ones pupil (circularly limited in Zernike mode), then
        # apply the mask.
        atf = tf.ones([1, dims[1], dims[2], 1])
        if zernike:
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        # apply any additional binary amplitude mask [1, dim, dim, 1]
        if binarymask:
            binary_mask = tf.convert_to_tensor(binary_mask_np,
                                               dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            # NOTE(review): logged under the name 'atf' but this shows the
            # binary mask itself, not the masked ATF; confirm intent.
            optics.attach_img('atf', tf.abs(binary_mask))

        # PSF: centered inverse FFT of the ATF, converted by the sensor model
        # (input_is_intensities=False) and normalized to unit sum.
        psfc = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psf = optics.Sensor(input_is_intensities=False,
                            resolution=(dims[1], dims[2]))(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        # Get the output image
        if coherent:
            input_field = tf.cast(input_field, tf.complex128)
            # Zero-centered fft2 of input field
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))
            field = atf * field
            # NOTE(review): log(0) yields -inf wherever the spectrum vanishes.
            tf.summary.image('field', tf.log(tf.square(tf.abs(field))))
            output_img = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(dims[1],
                                                   dims[2]))(output_img)
            # TODO: unlike the incoherent branch, no zero-padding is applied
            # here, so the spectral product acts as a circular convolution.
            # A padded alternative:
            # psfc = tf.expand_dims(tf.expand_dims(tf.squeeze(psfc), -1), -1)
            # padamt = dims[1] // 2
            # output_img = optics.fft_conv2d(fftpad(input_field, padamt), fftpad_psf(psfc, padamt), adjoint=False)
            # output_img = fftunpad(output_img, padamt)
            # output_img = optics.Sensor(input_is_intensities=False, resolution=(dims[1],dims[2]))(output_img)

        else:
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)

            # Pad by half the height before the FFT convolution (presumably to
            # avoid circular wrap-around), then crop back.
            padamt = dims[1] // 2
            output_img = tf.abs(
                optics.fft_conv2d(fftpad(input_field, padamt),
                                  fftpad_psf(psf, padamt),
                                  adjoint=False))
            output_img = fftunpad(output_img, padamt)

        # Apply nonlinear activation
        if activation is not None:
            output_img = activation(output_img)

        return output_img
예제 #7
0
def train(params, summary_every=100, save_every=2000, verbose=True):
    """Train a 9-class MNIST classifier (samples labelled 0 removed) with TensorBoard logging.

    Hard-coded flags select the model. One option is a single full-size
    convolution (optical or electronic), whose 3x3 grid of max-pooled cells
    gives the 9 logits. The other, active by default, is a single
    fully-connected layer on the raw image.

    Args:
        params: dict of options. Used keys: ``isNonNeg`` (square the weights
            so they stay non-negative) and ``activation`` (default
            ``tf.nn.relu``). ``addBias``, ``logtrans`` and ``numIters`` are
            read but not used; the iteration count comes from
            ``FLAGS.num_iters``.
        summary_every: write train summaries and evaluate/print accuracies
            every N steps.
        save_every: save a checkpoint every N steps (including step 0).
        verbose: print progress on summary steps.

    Relies on ``FLAGS``, ``mnist``, ``weight_variable``, ``conv2d``,
    ``optical_conv_layer`` and ``optics`` from elsewhere in the module.
    """
    # Unpack params
    isNonNeg = params.get('isNonNeg', False)
    # NOTE(review): the next three values are never used below.
    addBias = params.get('addBias', True)
    doLogTrans = params.get('logtrans', False)
    numIters = params.get('numIters', 1000)
    activation = params.get('activation', tf.nn.relu)

    # constraint helpers
    def nonneg(input_tensor):
        # Squaring reparameterizes a weight so it is always >= 0.
        return tf.square(input_tensor) if isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True))

    # input placeholders
    classes = 9  # get_feed drops samples labelled 0 and strips that column
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.float32, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        padamt = 0
        dim = 28 + 2 * padamt
        paddings = tf.constant([[0, 0], [padamt, padamt], [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        tf.summary.image('input', x_image, 3)

    # build model
    with tf.device('/device:GPU:2'):

        doOpticalConv = False
        doConv = False
        if doConv:
            if doOpticalConv:
                doAmplitudeMask = True
                hm_reg_scale = 1e-2
                r_NA = 35
                h_conv1 = optical_conv_layer(x_image, hm_reg_scale, r_NA, n=1.48, wavelength=532e-9,
                                             activation=activation, amplitude_mask=doAmplitudeMask,
                                             name='opt_conv1')
                h_conv1 = tf.cast(h_conv1, dtype=tf.float32)
            else:
                W_conv1 = weight_variable([dim, dim, 1, 1], name='W_conv1')
                W_conv1 = nonneg(W_conv1)
                W_conv1_im = tf.expand_dims(tf.expand_dims(tf.squeeze(W_conv1), 0), 3)
                optics.attach_summaries("W_conv1", W_conv1_im, image=True)

                h_conv1 = activation(conv2d(x_image, W_conv1))

            optics.attach_summaries("h_conv1", h_conv1, image=True)
            h_conv1_drop = tf.nn.dropout(h_conv1, keep_prob)

            # Cut the feature map into a 3x3 grid of cells; the max over each
            # cell is one class logit -> y_out has shape [batch, 9].
            split_1d = tf.split(h_conv1_drop, num_or_size_splits=3, axis=1)
            h_conv1_split = tf.concat([tf.split(split_1d[0], num_or_size_splits=3, axis=2),
                                       tf.split(split_1d[1], num_or_size_splits=3, axis=2),
                                       tf.split(split_1d[2], num_or_size_splits=3, axis=2)], 0)
            y_out = tf.transpose(tf.reduce_max(h_conv1_split, axis=[2, 3, 4]))

        else:
            with tf.name_scope('fc'):
                fcsize = dim * dim
                W_fc1 = weight_variable([fcsize, classes], name='W_fc1')
                W_fc1 = nonneg(W_fc1)

                # Visualize the FC weights as a 3x3 mosaic of dim x dim images,
                # one per class. Uses `dim` (not a hard-coded 28) so this
                # reshape stays valid if padamt changes.
                W_fc1_split = tf.reshape(tf.transpose(W_fc1), [classes, dim, dim])
                W_fc1_split = tf.split(W_fc1_split, classes, axis=0)
                W_fc1_tiled = tf.concat([tf.concat(W_fc1_split[:3], axis=2),
                                         tf.concat(W_fc1_split[3:6], axis=2),
                                         tf.concat(W_fc1_split[6:9], axis=2)], axis=1)
                tf.summary.image("W_fc1", tf.expand_dims(W_fc1_tiled, 3))

                h_conv1_flat = tf.reshape(x_image, [-1, fcsize])
                y_out = tf.matmul(h_conv1_flat, W_fc1)

        tf.summary.image('output', tf.reshape(y_out, [-1, 3, 3, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_out)
        mean_loss = tf.reduce_mean(total_loss)
        tf.summary.scalar('loss', mean_loss)

    train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(mean_loss)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')

    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    # MNIST feed dict
    def get_feed(train):
        if train:
            x, y = mnist.train.next_batch(50)
        else:
            x = mnist.test.images
            y = mnist.test.labels

        # remove "0"s: drop those samples and the first label column
        indices = ~np.equal(y[:, 0], 1)
        x_filt = np.squeeze(x[indices])
        y_filt = np.squeeze(y[indices, 1:])

        return x_filt, y_filt

    x_test, y_test = get_feed(train=False)

    # Defaults so the final report is well-defined even when
    # FLAGS.num_iters == 0 (previously a NameError on `i`/`train_accuracy`).
    i = -1
    train_accuracy = float('nan')
    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True)
        _, loss = sess.run([train_step, mean_loss], feed_dict={x: x_train, y_: y_train, keep_prob: FLAGS.dropout})
        losses.append(loss)

        if i % summary_every == 0:
            train_summary, train_accuracy = sess.run([merged, accuracy],
                                                     feed_dict={
                x: x_train, y_: y_train, keep_prob: FLAGS.dropout})
            train_writer.add_summary(train_summary, i)

            # Test summaries are never written, so evaluate only the
            # accuracy. Fetch `merged` as well if test_writer.add_summary is
            # re-enabled.
            test_accuracy = sess.run(accuracy,
                                     feed_dict={
                x: x_test, y_: y_test, keep_prob: 1.0})

            if verbose:
                print('step %d: loss %g, train acc %g, test acc %g' %
                      (i, loss, train_accuracy, test_accuracy))

        if i % save_every == 0:
            print("Saving model...")
            saver.save(sess, save_path, global_step=i)

    # NOTE: train_accuracy is from the most recent summary step, not
    # necessarily the final step.
    test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    print('final step %d, train accuracy %g, test accuracy %g' %
          (i, train_accuracy, test_acc))
    #sess.close()

    train_writer.close()
    test_writer.close()
예제 #8
0
    def _build_graph(self,
                     x_train,
                     init_gamma,
                     hm_reg_scale,
                     noise_sigma,
                     height_map_noise,
                     hm_init_type='random_normal'):
        """Depth-independent learned optics followed by inverse-filter deconvolution.

        :param x_train: graph node of the input image.
        :param init_gamma: initial regularization weight for
            ``deconv.inverse_filter``.
        :param hm_reg_scale: weight of the Laplace-L1 height-map regularizer.
        :param noise_sigma: sensor noise level. If None, it is drawn uniformly
            from [0.001, 0.02).
        :param height_map_noise: forwarded as ``height_tolerance``.
        :param hm_init_type: unused here.
        :return: graph node for the deconvolved image.
        """
        print("build graph", x_train.get_shape())

        def forward_model(input_field):
            # Learned phase plate -> aperture stop -> Fresnel propagation.
            phase_shifted = optics.height_map_element(
                input_field,
                wave_lengths=self.wave_lengths,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale),
                height_map_initializer=None,
                height_tolerance=height_map_noise,
                refractive_idcs=self.refractive_idcs,
                name='height_map_optics')
            apertured = optics.circular_aperture(phase_shifted)
            return optics.propagate_fresnel(
                apertured,
                distance=self.distance,
                input_sample_interval=self.input_sample_interval,
                wave_lengths=self.wave_lengths)

        with tf.device('/device:GPU:0'):
            # The original note says psf_resolution "equals wave resolution";
            # presumably wave_resolution == (patch_size, patch_size) here.
            optical_system = optics.OpticalSystem(
                forward_model,
                upsample=False,
                wave_resolution=self.wave_resolution,
                wave_lengths=self.wave_lengths,
                sensor_resolution=(self.patch_size, self.patch_size),
                psf_resolution=(self.patch_size, self.patch_size),
                discretization_size=self.input_sample_interval,
                use_planar_incidence=True)

            if noise_sigma is None:
                noise_sigma = tf.random_uniform(minval=0.001,
                                                maxval=0.02,
                                                shape=[])

            blurred = tf.cast(
                optical_system.get_sensor_img(input_img=x_train,
                                              noise_sigma=noise_sigma,
                                              depth_dependent=False),
                tf.float32)

            # Zero-pad (tf.pad default mode) both spatial axes by half the image
            # height, invert with the first PSF, then crop the padding away.
            half = blurred.shape.as_list()[1] // 2
            padded = tf.pad(blurred,
                            [[0, 0], [half, half], [half, half], [0, 0]])
            restored = deconv.inverse_filter(padded,
                                             padded,
                                             optical_system.psfs[0],
                                             init_gamma=init_gamma)
            output_image = restored[:, half:-half, half:-half, :]

            optics.attach_summaries('output_image',
                                    output_image,
                                    image=True,
                                    log_image=False)
            return output_image
예제 #9
0
def train(params,
          summary_every=100,
          print_every=250,
          save_every=1000,
          verbose=True):
    """Train a 9-class MNIST classifier (samples labelled 0 removed) whose first layer is an optical convolution.

    The 28x28 inputs are zero-padded to 84x84. The active configuration uses
    one optical conv layer, initialized from a saved phase mask. The 84x84
    output is cut into a 3x3 grid, and the max of each cell is one class logit.

    Args:
        params: dict of options. Used keys: ``isNonNeg`` and ``activation``
            (default ``tf.nn.relu``). ``numIters`` is read but not used; the
            iteration count comes from ``FLAGS.num_iters``.
        summary_every: write train summaries every N steps.
        print_every: print loss/accuracy every N steps (if ``verbose``).
        save_every: save a checkpoint every N steps (excluding step 0).
        verbose: enable progress printing.

    Relies on ``FLAGS``, ``mnist``, ``weight_variable``, ``conv2d``,
    ``optical_conv_layer`` and ``optics`` from elsewhere in the module.
    """
    # Unpack params
    isNonNeg = params.get('isNonNeg', False)
    # addBias = params.get('addBias', True)
    numIters = params.get('numIters', 1000)  # NOTE(review): unused; loop uses FLAGS.num_iters
    activation = params.get('activation', tf.nn.relu)

    # constraint helpers
    def nonneg(input_tensor):
        # Squaring reparameterizes a weight so it is always >= 0.
        return tf.square(input_tensor) if isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    classes = 9  # get_feed drops samples labelled 0 and strips that column
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.float32, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        padamt = 28
        dim = 84
        paddings = tf.constant([[0, 0], [padamt, padamt], [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        # 28 + 2 * 28 == 84 == dim, so this resize does not change the size.
        x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

    # build model
    doOpticalConv = True
    if doOpticalConv:
        doAmplitudeMask = False
        hm_reg_scale = 1e-2
        r_NA = 35

        # initialize with optimized phase mask
        mask = np.load(
            'maskopt/opticalcorrelator_w-conv1_height-map-sqrt.npy')
        initializer = tf.constant_initializer(mask)

        h_conv1 = optical_conv_layer(x_image,
                                     hm_reg_scale,
                                     r_NA,
                                     n=1.48,
                                     wavelength=532e-9,
                                     activation=activation,
                                     amplitude_mask=doAmplitudeMask,
                                     initializer=initializer,
                                     name='opt_conv1')
        # Cast like the other optical_conv_layer call sites so the logits
        # match the float32 labels (a no-op if already float32).
        h_conv1 = tf.cast(h_conv1, dtype=tf.float32)
    else:
        conv1dim = dim
        W_conv1 = weight_variable([conv1dim, conv1dim, 1, 1],
                                  name='W_conv1')
        W_conv1_flip = tf.reverse(W_conv1, axis=[0, 1])
        W_conv1_im = tf.expand_dims(tf.expand_dims(tf.squeeze(W_conv1), 0),
                                    3)
        optics.attach_summaries("W_conv1", W_conv1_im, image=True)
        h_conv1 = activation(conv2d(x_image, nonneg(W_conv1_flip)))

    optics.attach_summaries("h_conv1", h_conv1, image=True)

    doFC = False
    if doFC:
        with tf.name_scope('fc'):
            h_conv1 = tf.cast(h_conv1, dtype=tf.float32)
            fcsize = dim * dim
            W_fc1 = weight_variable([fcsize, classes], name='W_fc1')
            h_conv1_flat = tf.reshape(h_conv1, [-1, fcsize])
            y_out = tf.matmul(h_conv1_flat, nonneg(W_fc1))
    else:
        doConv2 = False
        if doConv2:
            if doOpticalConv:
                h_conv2 = optical_conv_layer(h_conv1,
                                             hm_reg_scale,
                                             r_NA,
                                             n=1.48,
                                             wavelength=532e-9,
                                             activation=activation,
                                             name='opt_conv2')
                h_conv2 = tf.cast(h_conv2, dtype=tf.float32)
                # Only one extra optical layer exists, so it is the last
                # feature map. Previously this path referenced an undefined
                # `h_conv3` (NameError).
                h_conv_last = h_conv2
            else:
                W_conv2 = weight_variable([dim, dim, 1, 1])
                W_conv2_flip = tf.reverse(W_conv2, axis=[0, 1])
                W_conv2_im = tf.expand_dims(
                    tf.expand_dims(tf.squeeze(W_conv2), 0), 3)
                optics.attach_summaries("W_conv2", W_conv2_im, image=True)
                h_conv2 = activation(conv2d(h_conv1, nonneg(W_conv2_flip)))

                W_conv3 = weight_variable([dim, dim, 1, 1])
                W_conv3_flip = tf.reverse(W_conv3, axis=[0, 1])
                W_conv3_im = tf.expand_dims(
                    tf.expand_dims(tf.squeeze(W_conv3), 0), 3)
                optics.attach_summaries("W_conv3", W_conv3_im, image=True)
                h_conv3 = activation(conv2d(h_conv2, nonneg(W_conv3_flip)))
                h_conv_last = h_conv3

            tf.summary.image("h_conv2", h_conv2)
            if not doOpticalConv:
                tf.summary.image("h_conv3", h_conv3)
            split_1d = tf.split(h_conv_last, num_or_size_splits=3, axis=1)
        else:
            split_1d = tf.split(h_conv1, num_or_size_splits=3, axis=1)

        # Cut the feature map into a 3x3 grid of cells; the max over each
        # cell is one class logit -> y_out has shape [batch, 9].
        h_conv_split = tf.concat([
            tf.split(split_1d[0], num_or_size_splits=3, axis=2),
            tf.split(split_1d[1], num_or_size_splits=3, axis=2),
            tf.split(split_1d[2], num_or_size_splits=3, axis=2)
        ], 0)
        y_out = tf.transpose(tf.reduce_max(h_conv_split, axis=[2, 3, 4]))

    tf.summary.image('output', tf.reshape(y_out, [-1, 3, 3, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                             logits=y_out)
        mean_loss = tf.reduce_mean(total_loss)
        tf.summary.scalar('loss', mean_loss)

    # NOTE(review): with rho=1.0 Adadelta's accumulators never move from 0,
    # so each update reduces to learning_rate * gradient (plain SGD).
    # train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(mean_loss)
    train_step = tf.train.AdadeltaOptimizer(FLAGS.learning_rate,
                                            rho=1.0).minimize(mean_loss)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')

    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    def get_feed(train):
        if train:
            x, y = mnist.train.next_batch(50)
        else:
            x = mnist.test.images
            y = mnist.test.labels

        # remove "0"s: drop those samples and the first label column
        indices = ~np.equal(y[:, 0], 1)
        x_filt = np.squeeze(x[indices])
        y_filt = np.squeeze(y[indices, 1:])

        return x_filt, y_filt

    # Only used by the (disabled) test evaluation below.
    x_test, y_test = get_feed(train=False)

    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True)
        write_summary = i % summary_every == 0

        # Evaluate the merged summaries (including image summaries) only on
        # steps where they are actually written.
        fetches = [train_step, mean_loss, accuracy]
        if write_summary:
            fetches.append(merged)
        results = sess.run(fetches,
                           feed_dict={
                               x: x_train,
                               y_: y_train,
                               keep_prob: FLAGS.dropout
                           })
        loss, train_accuracy = results[1], results[2]
        losses.append(loss)

        if write_summary:
            train_writer.add_summary(results[3], i)

        if i > 0 and i % save_every == 0:
            saver.save(sess, save_path, global_step=i)

            # test_summary, test_accuracy = sess.run([merged, accuracy],
            #                                        feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
            # test_writer.add_summary(test_summary, i)
            # if verbose:
            #     print('step %d: test acc %g' % (i, test_accuracy))

        if i % print_every == 0:
            if verbose:
                print('step %d: loss %g, train acc %g' %
                      (i, loss, train_accuracy))

    # test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    # print('final step %d, train accuracy %g, test accuracy %g' %
    #       (i, train_accuracy, test_acc))
    #sess.close()

    train_writer.close()
    test_writer.close()