Example #1
0
    def _get_training_queue(self, batch_size, num_threads=4):
        """Build the input queue yielding (image, PSF-blurred image) batches.

        PNG files matching a fixed glob are decoded as single-channel images,
        scaled by 1/255, cropped to a ``self.dim`` x ``self.dim`` glimpse and
        convolved (via FFT, with padding) with the sum-normalized PSF loaded
        from ``self.psf_file``. The PSF is also attached as the ``'gt_psf'``
        image summary.

        Args:
            batch_size: Number of examples per batch.
            num_threads: Number of threads used by ``tf.train.batch`` to
                enqueue examples. This argument used to be ignored (4 threads
                were hard-coded); the default of 4 keeps the old behavior.

        Returns:
            Tuple ``(image_batch, convolved_img_batch)``; each example is
            declared to the batching queue with shape ``[dim, dim, 1]``.
        """
        dim = self.dim

        file_list = tf.matching_files(
            '/media/data/onn/quickdraw16_192/im_*.png')
        filename_queue = tf.train.string_input_producer(file_list)

        image_reader = tf.WholeFileReader()

        _, image_file = image_reader.read(filename_queue)
        image = tf.image.decode_png(image_file, channels=1, dtype=tf.uint8)
        image = tf.cast(image, tf.float32)  # Shape [height, width, 1]
        image = tf.expand_dims(image, 0)  # Add a leading batch axis
        image /= 255.

        # Ratio of the patch size to the smallest side of the image. Only the
        # disabled random-offset variant below consumes it.
        img_height_width = tf.cast(tf.shape(image)[1:3], tf.float32)
        size_ratio = dim / tf.reduce_min(img_height_width)

        # Extract a glimpse from the image. minval == maxval == 0, so the
        # sampled offset is always zero; with centered=True that is a crop
        # around the image center. The random variant is kept for reference:
        #offset_center = tf.random_uniform([1,2], minval=0.0 + size_ratio/2, maxval=1.0-size_ratio/2, dtype=tf.float32)
        offset_center = tf.random_uniform([1, 2],
                                          minval=0,
                                          maxval=0,
                                          dtype=tf.float32)
        offset_center = offset_center * img_height_width

        image = tf.image.extract_glimpse(image,
                                         size=[dim, dim],
                                         offsets=offset_center,
                                         centered=True,
                                         normalized=False)
        image = tf.squeeze(image, 0)

        convolved_image = tf.expand_dims(image, 0)
        psf = tf.convert_to_tensor(np.load(self.psf_file), tf.float32)
        psf /= tf.reduce_sum(psf)  # PSF sums to one
        optics.attach_img(
            'gt_psf', tf.expand_dims(tf.expand_dims(tf.squeeze(psf), 0), -1))

        # Append two trailing singleton axes to the squeezed PSF so it can be
        # passed to fft_conv2d as a kernel.
        psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
        # psf = tf.transpose(psf, [1,2,0,3])

        # Pad by half the patch size before the FFT convolution and strip the
        # padding again afterwards.
        pad = int(dim / 2)
        convolved_image = tf.abs(
            optics.fft_conv2d(fftpad(convolved_image, pad),
                              fftpad_psf(psf, pad),
                              adjoint=True))
        convolved_image = fftunpad(convolved_image, pad)
        convolved_image = tf.squeeze(convolved_image, axis=0)
        convolved_image /= tf.reduce_sum(convolved_image)

        image_batch, convolved_img_batch = tf.train.batch(
            [image, convolved_image],
            shapes=[[dim, dim, 1], [dim, dim, 1]],
            batch_size=batch_size,
            num_threads=num_threads,  # was hard-coded to 4
            capacity=4 * batch_size)

        return image_batch, convolved_img_batch
Example #2
0
    def _optical_conv_layer(self, input_field, hm_reg_scale, activation=None, coherent=False,
                            name='optical_conv'):
        """Build one optical convolution layer inside variable scope ``name``.

        A phase mask is created with ``optics.height_map_element``; a PSF is
        derived from it, normalized to unit sum and attached as the ``'psf'``
        image summary. The output is then formed either coherently (aperture
        and mask applied to the centered spectrum of the input, followed by a
        centered inverse transform) or incoherently (magnitude of the FFT
        convolution of the input with the PSF).

        Args:
            input_field: Input tensor; it is cast to complex128 before use.
            hm_reg_scale: Not used in this method (the height-map regularizer
                is not wired in here).
            activation: Optional callable applied to the output.
            coherent: Selects the coherent path when true.
            name: Variable scope name.

        Returns:
            The layer output, after ``activation`` if one was given.
        """
        with tf.variable_scope(name):
            sensor_resolution = self.wave_resolution
            input_field = tf.cast(input_field, tf.complex128)

            # Centered spectrum of the input: fftshift(fft2(ifftshift(x))).
            field = optics.ifftshift2d_tf(input_field)
            field = optics.transp_fft2d(field)
            field = optics.fftshift2d_tf(field)

            # Phase mask element, initialized in a narrow band around 1e-4.
            initializer = tf.random_uniform_initializer(minval=0.999e-4, maxval=1.001e-4)
            mask_height, mask_width = self.wave_resolution[0], self.wave_resolution[1]
            phase_mask = optics.height_map_element(
                [1, mask_height, mask_width, 1],
                wave_lengths=self.wavelength,
                height_map_initializer=initializer,
                name='phase_mask_height',
                refractive_index=self.n)

            # Pupil: all-ones field -> circular aperture -> phase mask.
            pupil = tf.ones([1, mask_height, mask_width, 1])
            pupil = optics.circular_aperture(pupil, max_val=self.r_NA)
            pupil = phase_mask(pupil)

            # PSF: centered inverse transform of the pupil, read by the sensor.
            psf = optics.ifftshift2d_tf(pupil)
            psf = optics.transp_ifft2d(psf)
            psf = optics.fftshift2d_tf(psf)
            sensor = optics.Sensor(input_is_intensities=False, resolution=sensor_resolution)
            psf = sensor(psf)
            psf /= tf.reduce_sum(psf)  # conservation of energy
            psf = tf.cast(psf, tf.float32)
            optics.attach_img('psf', psf)

            if not coherent:
                # Incoherent imaging: convolve the input with the PSF.
                kernel = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
                output_img = tf.abs(optics.fft_conv2d(input_field, kernel))
            else:
                # Coherent imaging: filter the spectrum, then transform back.
                field = optics.circular_aperture(field, max_val=self.r_NA)
                field = phase_mask(field)
                tf.summary.image('field', tf.square(tf.abs(field)))
                output_img = optics.ifftshift2d_tf(field)
                output_img = optics.transp_ifft2d(output_img)
                output_img = optics.fftshift2d_tf(output_img)

            # Optional nonlinearity.
            return output_img if activation is None else activation(output_img)
    def _build_graph(self,
                     x_train,
                     hm_reg_scale,
                     hm_init_type='random_normal'):
        """Build the forward imaging graph on ``/device:GPU:0``.

        The input is normalized to unit sum, a phase mask with a Laplace-L1
        height-map regularizer is created, the resulting PSF is attached as
        the ``'recon_psf'`` summary, and the sensor image is computed. The
        ``coherent`` switch is fixed to ``False`` in the body, so the
        incoherent (PSF convolution) path is the one that runs.

        Args:
            x_train: Input image tensor; logged as ``'input_image'``.
            hm_reg_scale: Scale handed to ``optics.laplace_l1_regularizer``.
            hm_init_type: Not used in this method.

        Returns:
            The unit-sum float32 output image, also logged as
            ``'output_image'``.
        """
        with tf.device('/device:GPU:0'):
            sensor_resolution = (self.dim, self.dim)

            # Work on the unit-sum input; the raw input is what gets logged.
            input_img = x_train / tf.reduce_sum(x_train)
            tf.summary.image('input_image', x_train)

            # Centered spectrum: fftshift(fft2(ifftshift(FIELD))).
            field = optics.ifftshift2d_tf(input_img)
            field = optics.transp_fft2d(field)
            field = optics.fftshift2d_tf(field)

            # Phase mask element, initialized in a narrow band around 1e-4.
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)
            mask_height, mask_width = (self.wave_resolution[0],
                                       self.wave_resolution[1])
            phase_mask = optics.height_map_element(
                [1, mask_height, mask_width, 1],
                wave_lengths=self.wave_length,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale),
                height_map_initializer=initializer,
                name='phase_mask_height',
                refractive_index=self.n)

            # Pupil: all-ones field -> circular aperture -> phase mask.
            pupil = tf.ones([1, mask_height, mask_width, 1])
            pupil = optics.circular_aperture(pupil, max_val=self.r_NA)
            pupil = phase_mask(pupil)

            # PSF: centered inverse transform of the pupil, read by the sensor.
            psf = optics.ifftshift2d_tf(pupil)
            psf = optics.transp_ifft2d(psf)
            psf = optics.fftshift2d_tf(psf)
            psf = optics.Sensor(input_is_intensities=False,
                                resolution=sensor_resolution)(psf)
            psf /= tf.reduce_sum(psf)  # sum or max?
            psf = tf.cast(psf, tf.float32)
            optics.attach_img('recon_psf', psf)

            # Form the output image. The switch is fixed, so only the
            # incoherent branch is ever taken.
            coherent = False
            if not coherent:
                kernel = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1),
                                        -1)
                output_img = tf.abs(optics.fft_conv2d(input_img, kernel))
                output_img = optics.Sensor(
                    input_is_intensities=True,
                    resolution=sensor_resolution)(output_img)
            else:
                field = optics.circular_aperture(field, max_val=self.r_NA)
                field = phase_mask(field)
                tf.summary.image('field', tf.square(tf.abs(field)))
                field = optics.ifftshift2d_tf(field)
                field = optics.transp_ifft2d(field)
                field = optics.fftshift2d_tf(field)

                output_img = optics.Sensor(
                    input_is_intensities=False,
                    resolution=sensor_resolution)(field)

            output_img /= tf.reduce_sum(output_img)  # sum or max?
            output_img = tf.cast(output_img, tf.float32)

            # Log the result.
            tf.summary.image('output_image', output_img)
            return output_img
Example #4
0
def optical_conv_layer(input_field,
                       hm_reg_scale,
                       r_NA,
                       n=1.48,
                       wavelength=532e-9,
                       activation=None,
                       coherent=False,
                       amplitude_mask=False,
                       zernike=False,
                       fourier=False,
                       binarymask=False,
                       n_modes=1024,
                       freq_range=.5,
                       initializer=None,
                       zernike_file='zernike_volume_256.npy',
                       binary_mask_np=None,
                       name='optical_conv'):
    """Build an optical convolution layer inside variable scope ``name``.

    One mask element is created (amplitude map, Zernike, Fourier or plain
    height map, in that order of precedence), applied to an all-ones field to
    obtain the ATF, and turned into a unit-sum PSF that is logged through
    ``optics.attach_summaries``. The output is either the sensor reading of
    the coherently filtered input, or the magnitude of the padded FFT
    convolution of the input with the PSF.

    Args:
        input_field: Input tensor; axes 1 and 2 of its static shape set the
            mask and sensor size.
        hm_reg_scale: Not used in this function.
        r_NA: Aperture value passed to the amplitude/Zernike elements and to
            ``optics.circular_aperture``.
        n: Refractive index passed to the phase elements.
        wavelength: Wavelength passed to the phase elements.
        activation: Optional callable applied to the output.
        coherent: Selects the coherent path when true.
        amplitude_mask: Use ``optics.amplitude_map_element`` as the mask.
        zernike: Use ``optics.zernike_element``; also enables the circular
            aperture on the ATF.
        fourier: Use ``optics.fourier_element``.
        binarymask: Multiply the ATF by ``binary_mask_np``.
        n_modes: Number of leading Zernike modes kept from ``zernike_file``.
        freq_range: Frequency range passed to the Fourier element.
        initializer: Mask initializer; defaults to a uniform initializer in
            [0.999e-4, 1.001e-4].
        zernike_file: ``.npy`` file holding the Zernike modes.
        binary_mask_np: 2-D array used when ``binarymask`` is true.
        name: Variable scope name.

    Returns:
        The layer output, after ``activation`` if one was given.
    """
    dims = input_field.get_shape().as_list()
    with tf.variable_scope(name):
        if initializer is None:
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)

        # --- Mask element (first matching option wins) ---
        if amplitude_mask:
            # Zero-centered amplitude mask.
            mask = optics.amplitude_map_element(
                [1, dims[1], dims[2], 1],
                r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')
        elif zernike:
            # Phase mask from the first n_modes Zernike modes, resized to the
            # input's spatial size.
            modes = np.load(zernike_file)[:n_modes, :, :]
            modes = tf.expand_dims(modes, -1)
            modes = tf.image.resize_images(modes, size=(dims[1], dims[2]))
            modes = tf.squeeze(modes, -1)
            mask = optics.zernike_element(modes,
                                          'zernike_element',
                                          wavelength,
                                          n,
                                          r_NA,
                                          zernike_initializer=initializer)
        elif fourier:
            mask = optics.fourier_element(
                [1, dims[1], dims[2], 1],
                'fourier_element',
                wave_lengths=wavelength,
                refractive_index=n,
                frequency_range=freq_range,
                height_map_regularizer=None,
                height_tolerance=None,  # Default height tolerance is 2 nm.
            )
        else:
            # Plain height-map phase mask (no regularizer attached).
            mask = optics.height_map_element(
                [1, dims[1], dims[2], 1],
                wave_lengths=wavelength,
                height_map_initializer=initializer,
                name='phase_mask_height',
                refractive_index=n)

        # --- ATF ---
        atf = tf.ones([1, dims[1], dims[2], 1])
        if zernike:
            # The circular aperture is only applied in Zernike mode.
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        if binarymask:
            # Extra binary amplitude mask, expanded to [1, dim, dim, 1].
            binary_mask = tf.convert_to_tensor(binary_mask_np,
                                               dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            # NOTE(review): the summary is tagged 'atf' but shows the binary
            # mask itself -- confirm whether the masked ATF was intended.
            optics.attach_img('atf', tf.abs(binary_mask))

        # --- PSF: centered inverse transform of the ATF, read by the sensor ---
        psfc = optics.ifftshift2d_tf(atf)
        psfc = optics.transp_ifft2d(psfc)
        psfc = optics.fftshift2d_tf(psfc)
        sensor = optics.Sensor(input_is_intensities=False,
                               resolution=(dims[1], dims[2]))
        psf = sensor(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        # --- Output image ---
        if not coherent:
            # Incoherent: padded FFT convolution of the input with the PSF.
            kernel = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            padamt = int(dims[1] / 2)
            output_img = tf.abs(
                optics.fft_conv2d(fftpad(input_field, padamt),
                                  fftpad_psf(kernel, padamt),
                                  adjoint=False))
            output_img = fftunpad(output_img, padamt)
        else:
            # Coherent: multiply the centered spectrum of the input by the
            # ATF, transform back and read the result with the sensor.
            input_field = tf.cast(input_field, tf.complex128)
            field = optics.ifftshift2d_tf(input_field)
            field = optics.transp_fft2d(field)
            field = optics.fftshift2d_tf(field)
            field = atf * field
            tf.summary.image('field', tf.log(tf.square(tf.abs(field))))
            output_img = optics.ifftshift2d_tf(field)
            output_img = optics.transp_ifft2d(output_img)
            output_img = optics.fftshift2d_tf(output_img)
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(dims[1],
                                                   dims[2]))(output_img)
            # Open question from the original author: does this path need
            # padding as well?

        # Optional nonlinearity.
        return output_img if activation is None else activation(output_img)
def train(params,
          summary_every=100,
          print_every=250,
          save_every=1000,
          verbose=True):
    """Build a small conv-net graph for 10-class 32x32 grayscale images and train it.

    The first layer can be a tiled FFT convolution, an optical convolution
    (``optical_conv_layer``) or a regular convolution with optional
    non-negativity / positive-minus-negative weight constraints. Up to two
    further regular convolutions and a single linear layer follow. Data come
    from ``get_CIFAR10_grayscale``; summaries and checkpoints are written
    under ``FLAGS.log_dir``.

    Besides ``params``, the function reads ``FLAGS.learning_rate``,
    ``FLAGS.learning_rate_ad``, ``FLAGS.log_dir``, ``FLAGS.num_iters`` and
    ``FLAGS.dropout``.

    Args:
        params: Mapping with a ``.get`` method holding the hyper-parameters
            and switches unpacked at the top of the body (every key has a
            default).
        summary_every: Every this many steps, write the train summary and
            evaluate/write the held-out summary.
        print_every: Every this many steps, print loss and train accuracy
            (only if ``verbose``).
        save_every: Every this many steps (except step 0), save a checkpoint.
        verbose: Enables the periodic progress prints.

    Returns:
        None. The final accuracies are printed.
    """
    # Unpack params
    wavelength = params.get('wavelength', 532e-9)
    isNonNeg = params.get('isNonNeg', False)
    # NOTE(review): numIters is read here but the training loop below runs
    # for FLAGS.num_iters iterations -- confirm which one is intended.
    numIters = params.get('numIters', 1000)
    activation = params.get('activation', tf.nn.relu)
    opt_type = params.get('opt_type', 'ADAM')

    # switches
    # NOTE(review): doMultichannelConv, doMean and doFC are read but not
    # referenced again in this function.
    doMultichannelConv = params.get('doMultichannelConv', False)
    doMean = params.get('doMean', False)
    doOpticalConv = params.get('doOpticalConv', True)
    doAmplitudeMask = params.get('doAmplitudeMask', False)
    doZernike = params.get('doZernike', False)
    doFC = params.get('doFC', False)
    doConv1 = params.get('doConv1', True)
    doConv2 = params.get('doConv2', True)
    doConv3 = params.get('doConv3', False)
    doNonnegReg = params.get('doNonnegReg', False)
    doOptNeg = params.get('doOptNeg', False)
    doTiledConv = params.get('doTiledConv', False)

    # Number of Zernike modes and kernel sizes of the three conv layers
    z_modes = params.get('z_modes', 1024)
    convdim1 = params.get('convdim1', 100)
    convdim2 = params.get('convdim2', 100)
    convdim3 = params.get('convdim3', 100)

    # Output channel counts of the three conv layers
    depth1 = params.get('depth1', 3)
    depth2 = params.get('depth2', 3)
    depth3 = params.get('depth3', 3)

    # padamt: zero padding added around the input image.
    # dim: spatial size assumed when flattening for the final linear layer.
    padamt = params.get('padamt', 0)
    dim = params.get('dim', 60)

    # Layout of the tiled kernel / activation visualizations
    buff = params.get('buff', 4)
    rows = params.get('rows', 4)
    cols = params.get('cols', 4)

    # constraint helpers
    def nonneg(input_tensor):
        # Absolute value when isNonNeg is set, identity otherwise.
        return tf.abs(input_tensor) if isNonNeg else input_tensor

    def vis_weights(W_conv, depth, buff, rows, cols, name):
        """Log the kernels in W_conv as one tiled image summary.

        W_conv is transposed with perm [2, 0, 1, 3], split into `depth`
        pieces along axis 3, each piece is zero-padded, and the pieces are
        laid out on a rows x cols grid.
        """
        kernel_list = tf.split(tf.transpose(W_conv, [2, 0, 1, 3]),
                               depth,
                               axis=3)
        kernels_pad = [
            tf.pad(kernel,
                   [[0, 0], [buff, buff], [buff + 4, buff + 4], [0, 0]])
            for kernel in kernel_list
        ]
        W_conv_tiled = tf.concat([
            tf.concat(kernels_pad[i * cols:(i + 1) * cols], axis=2)
            for i in range(rows)
        ],
                                 axis=1)
        tf.summary.image(name, W_conv_tiled, 3)

    def vis_h(h_conv, depth, rows, cols, name):
        """Log the `depth` channels of h_conv as one rows x cols tiled image summary."""
        # this was for viewing multichannel convolution
        h_conv_split = tf.split(h_conv, depth, axis=3)
        h_conv_tiled = tf.concat([
            tf.concat(h_conv_split[i * cols:(i + 1) * cols], axis=2)
            for i in range(rows)
        ],
                                 axis=1)
        tf.summary.image(name, h_conv_tiled, 3)

    # NOTE(review): this session is never closed in this function (the
    # sess.close() call at the end is commented out).
    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    classes = 10
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 32, 32])
        y_ = tf.placeholder(tf.int64, shape=[None])
        keep_prob = tf.placeholder(tf.float32)

        # Add a channel axis, then zero-pad both spatial axes by padamt.
        x_image = tf.reshape(x, [-1, 32, 32, 1])
        paddings = tf.constant([[
            0,
            0,
        ], [padamt, padamt], [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        # x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

        # if not isNonNeg and not doNonnegReg:
        #     x_image -= tf.reduce_mean(x_image)

    # regularizers
    global_step = tf.Variable(0, trainable=False)
    if doNonnegReg:
        # Regularization scale scheduled from 0. towards 6000. over 6000
        # steps of global_step.
        reg_scale = tf.train.polynomial_decay(0.,
                                              global_step,
                                              decay_steps=6000,
                                              end_learning_rate=6000.)
        psf_reg = optics_alt.nonneg_regularizer(reg_scale)
    else:
        psf_reg = None

    l2_reg = tf.contrib.layers.l2_regularizer(1e-1, scope=None)

    # build model
    # h_conv_out tracks the output of the last enabled layer; fcdepth tracks
    # its channel count for the final flatten.
    h_conv_out = x_image
    fcdepth = 1
    doVis = True

    if doConv1:
        with tf.name_scope('conv1'):
            if doTiledConv:
                # Variant 1: one large kernel of size tiled_dim x tiled_dim
                # applied with an FFT convolution; the result is cut into
                # tiles by split2d_layer.
                tiled_dim = (32) * rows
                init_vals_pos = tf.truncated_normal(
                    [tiled_dim, tiled_dim, 1, 1], stddev=0.1) + .1
                W_conv1_tiled = tf.Variable(init_vals_pos,
                                            name='W_conv1_tiled')
                W_conv1_tiled = nonneg(W_conv1_tiled)
                tf.summary.image(
                    "W_conv1_tiled",
                    tf.expand_dims(tf.squeeze(W_conv1_tiled, -1), 0))

                # For an unpadded 32x32 input (padamt == 0) this padding
                # grows the image to tiled_dim x tiled_dim.
                tile_pad = tiled_dim // 2 - 16
                tile_paddings = tf.constant([[
                    0,
                    0,
                ], [tile_pad, tile_pad], [tile_pad, tile_pad], [0, 0]])
                x_padded = tf.pad(x_image, tile_paddings)
                tf.summary.image('input', x_padded, 3)

                fftpadamt = int(tiled_dim / 2)
                h_conv_tiled = tf.abs(
                    optics.fft_conv2d(fftpad(x_padded, fftpadamt),
                                      fftpad_psf(W_conv1_tiled, fftpadamt)))
                h_conv_tiled = fftunpad(
                    tf.cast(h_conv_tiled, dtype=tf.float32), fftpadamt)

                h_conv_split2d = split2d_layer(h_conv_tiled, rows, cols)
                b_conv1 = bias_variable([depth1], 'b_conv1')
                h_conv1 = h_conv_split2d + b_conv1
            elif doOpticalConv:
                # Variant 2: optical convolution layer on the padded input.
                # Note this branch sizes the canvas from cols (the tiled
                # branch uses rows) and splits into 2 * rows by cols tiles.
                tiled_dim = (32) * cols
                tile_pad = tiled_dim // 2 - 16
                tile_paddings = tf.constant([[
                    0,
                    0,
                ], [tile_pad, tile_pad], [tile_pad, tile_pad], [0, 0]])
                x_padded = tf.pad(x_image, tile_paddings)
                tf.summary.image('input', x_padded, 3)

                r_NA = tiled_dim / 2
                hm_reg_scale = 1e-2
                # initialize with optimized phase mask
                # mask = np.load('maskopt/quickdraw9_zernike1024.npy')
                # initializer = tf.constant_initializer(mask)
                initializer = None

                h_conv1_opt = optical_conv_layer(
                    x_padded,
                    hm_reg_scale,
                    r_NA,
                    n=1.48,
                    wavelength=wavelength,
                    activation=None,
                    amplitude_mask=doAmplitudeMask,
                    zernike=doZernike,
                    n_modes=z_modes,
                    initializer=initializer,
                    name='opt_conv1_pos')

                # h_conv1_opt_neg = optical_conv_layer(x_padded, hm_reg_scale, r_NA, n=1.48, wavelength=wavelength,
                #        activation=None, amplitude_mask=doAmplitudeMask, zernike=doZernike,
                #        n_modes=z_modes, initializer=initializer, name='opt_conv1_neg')

                h_conv1_opt = tf.cast(h_conv1_opt, dtype=tf.float32)
                h_conv_split2d = split2d_layer(h_conv1_opt, 2 * rows, cols)
                b_conv1 = bias_variable([depth1], 'b_conv1')
                h_conv1 = h_conv_split2d + b_conv1

            else:
                # Variant 3: regular convolution; the weights are built in
                # one of three ways.
                if doOptNeg:
                    # Weights are the difference of two separately stored
                    # kernels, each passed through nonneg() and L2-regularized.
                    # positive weights
                    init_vals_pos = tf.truncated_normal(
                        [convdim1, convdim1, 1, depth1], stddev=0.1) + .1
                    W_conv1_pos = tf.Variable(init_vals_pos,
                                              name='W_conv1_pos')
                    # W_conv1 = weight_variable([convdim1, convdim1, 1, depth1], name='W_conv1')
                    W_conv1_pos = nonneg(W_conv1_pos)
                    #W_conv1_nonneg /= tf.reduce_sum(tf.abs(W_conv1_nonneg)) # conservation of energy
                    tf.contrib.layers.apply_regularization(
                        l2_reg,
                        weights_list=[tf.transpose(W_conv1_pos, [3, 0, 1, 2])])

                    # negative weights
                    init_vals_neg = tf.truncated_normal(
                        [convdim1, convdim1, 1, depth1], stddev=0.1) + .1
                    W_conv1_neg = tf.Variable(init_vals_neg,
                                              name='W_conv1_neg')
                    # W_conv1 = weight_variable([convdim1, convdim1, 1, depth1], name='W_conv1')
                    W_conv1_neg = nonneg(W_conv1_neg)
                    # W_conv1_nonneg /= tf.reduce_sum(tf.abs(W_conv1_nonneg)) # conservation of energy
                    tf.contrib.layers.apply_regularization(
                        l2_reg,
                        weights_list=[tf.transpose(W_conv1_neg, [3, 0, 1, 2])])

                    W_conv1 = tf.subtract(W_conv1_pos, W_conv1_neg)

                    if doVis:
                        vis_weights(W_conv1_pos, depth1, buff, rows, cols,
                                    'W_conv1_pos')
                        vis_weights(W_conv1_neg, depth1, buff, rows, cols,
                                    'W_conv1_neg')

                elif isNonNeg:
                    # Single kernel, offset by .1 and passed through nonneg().
                    init_vals = tf.truncated_normal(
                        [convdim1, convdim1, 1, depth1], stddev=0.1)
                    W_conv1 = tf.Variable(init_vals, name='W_conv1_nn') + .1
                    # W_conv1 = weight_variable([convdim1, convdim1, 1, depth1], name='W_conv1')
                    W_conv1 = nonneg(W_conv1)
                    #W_conv1_nonneg /= tf.reduce_sum(tf.abs(W_conv1_nonneg)) # conservation of energy
                else:
                    # Unconstrained kernel, optionally with the non-negativity
                    # regularizer built above.
                    W_conv1 = weight_variable([convdim1, convdim1, 1, depth1],
                                              name='W_conv1')

                    if psf_reg is not None:
                        tf.contrib.layers.apply_regularization(
                            psf_reg,
                            weights_list=[tf.transpose(W_conv1, [3, 0, 1, 2])])

                vis_weights(W_conv1, depth1, buff, rows, cols, 'W_conv1')

                W_conv1_flip = tf.reverse(W_conv1,
                                          axis=[0, 1])  # flip if using tfconv
                h_conv1 = conv2d(x_image, W_conv1_flip)
                # Scale each example by its own maximum response.
                h_conv1 /= tf.reduce_max(h_conv1,
                                         axis=[1, 2, 3],
                                         keep_dims=True)

                b_conv1 = bias_variable([depth1], 'b_conv1')
                h_conv1 = h_conv1 + b_conv1

            # Shared tail of conv1: summaries, dropout, then the activation.
            vis_h(h_conv1, depth1, rows, cols, 'h_conv1')
            variable_summaries("h_conv1", h_conv1)
            h_conv1_drop = tf.nn.dropout(h_conv1, keep_prob)

            #h_pool1 = max_pool_2x2(h_conv1)
            h_pool1 = h_conv1_drop

            if doNonnegReg:
                h_pool1 = optics_alt.shifted_relu(h_pool1)
            else:
                h_pool1 = activation(h_pool1)
            variable_summaries("h_conv1_post", h_pool1)

            h_conv_out = h_pool1
            #dim = 16
            fcdepth = depth1

    if doConv2:
        with tf.name_scope('conv2'):
            W_conv2 = weight_variable([convdim2, convdim2, depth1, depth2],
                                      name='W_conv2')
            # vis_weights(W_conv2, depth2, buff, rows, cols, 'W_conv2')
            b_conv2 = bias_variable([depth2], name='b_conv2')
            # NOTE(review): h_pool1 is only assigned inside the doConv1 block
            # above, so this line needs doConv1 to be true.
            h_conv2 = conv2d(h_pool1, W_conv2) + b_conv2

            # h_pool2 = max_pool_2x2(h_conv2)
            h_pool2 = h_conv2
            variable_summaries("h_conv2", h_pool2)

            h_conv2_drop = tf.nn.dropout(h_pool2, keep_prob)
            h_conv2_drop = activation(h_conv2_drop)
            variable_summaries("h_conv2_post", h_conv2_drop)
            h_conv_out = h_conv2_drop
            # dim = 16
            fcdepth = depth2

    if doConv3:
        with tf.name_scope('conv3'):
            W_conv3 = weight_variable([convdim3, convdim3, depth2, depth3],
                                      name='W_conv3')
            # vis_weights(W_conv3, depth3, buff, rows, cols, 'W_conv3')
            b_conv3 = bias_variable([depth3], name='b_conv3')

            # NOTE(review): h_pool2 is only assigned inside the doConv2 block,
            # and it is the conv2 output before dropout and activation (not
            # h_conv_out) -- confirm this is intended.
            h_conv3 = conv2d(h_pool2, W_conv3) + b_conv3
            h_pool3 = max_pool_2x2(h_conv3)
            variable_summaries("h_conv3", h_pool3)

            h_conv3_drop = tf.nn.dropout(h_pool3, keep_prob)
            h_conv3_drop = activation(h_conv3_drop)
            variable_summaries("h_conv3_post", h_conv3_drop)
            h_conv_out = h_conv3_drop
            fcdepth = depth3
            # Overrides the 'dim' param when conv3 (with pooling) is enabled.
            dim = 16

    # choose output layer here
    with tf.name_scope('fc'):
        h_conv_out = tf.cast(h_conv_out, dtype=tf.float32)

        # Flatten to dim * dim * fcdepth features and map them linearly to
        # `classes` logits; 'dim' must match the actual spatial size of
        # h_conv_out for the reshape to keep the batch size.
        fcsize = dim * dim * fcdepth
        hidden_dim = classes
        W_fc1 = weight_variable([fcsize, hidden_dim], name='W_fc1')
        b_fc1 = bias_variable([hidden_dim], name='b_fc1')
        h_conv_flat = tf.reshape(h_conv_out, [-1, fcsize])

        y_out = tf.matmul(h_conv_flat, W_fc1) + b_fc1

        # h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

        # W_fc2 = weight_variable([hidden_dim, classes])
        # b_fc2 = bias_variable([classes])
        # y_out = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

    # Logits shown as a 2x5 image; this relies on classes == 10.
    tf.summary.image('output', tf.reshape(y_out, [-1, 2, 5, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_data_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.one_hot(y_, classes), logits=y_out)
        data_loss = tf.reduce_mean(total_data_loss)
        # Sum of everything registered in the REGULARIZATION_LOSSES collection.
        reg_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = tf.add(data_loss, reg_loss)
        tf.summary.scalar('data_loss', data_loss)
        tf.summary.scalar('reg_loss', reg_loss)
        tf.summary.scalar('total_loss', total_loss)

    # Optimizer selection: 'ADAM', 'adadelta', anything else -> Nesterov
    # momentum. Each minimize() call also increments global_step.
    if opt_type == 'ADAM':
        train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
            total_loss, global_step)
    elif opt_type == 'adadelta':
        train_step = tf.train.AdadeltaOptimizer(FLAGS.learning_rate_ad,
                                                rho=.9).minimize(
                                                    total_loss, global_step)
    else:
        train_step = tf.train.MomentumOptimizer(FLAGS.learning_rate,
                                                momentum=0.5,
                                                use_nesterov=True).minimize(
                                                    total_loss, global_step)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), y_)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    # Collected per step below but not read again in this function.
    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')

    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    # NOTE(review): with num_validation=1000 and num_test=0, x_test / y_test
    # here are presumably the validation split -- verify against
    # get_CIFAR10_grayscale's return order.
    x_train_all, y_train_all, x_test, y_test, _, _ = get_CIFAR10_grayscale(
        num_training=49000, num_validation=1000, num_test=0)
    num_training = x_train_all.shape[0]

    def get_feed(train, batch_size=50, augmentation=False):
        """Draw a random training batch (indices sampled with replacement).

        The `train` argument is not used. With `augmentation`, the whole
        batch is rotated by one random angle in [0, 20) and passed through
        `resize` with target (32, 32).
        """
        idcs = np.random.randint(0, num_training, batch_size)
        x = x_train_all[idcs, :, :]
        y = y_train_all[idcs]

        if augmentation:
            angle = np.random.uniform(low=0.0, high=20.0)
            x = rotate(x, angle, axes=(2, 1), reshape=True)
            x = resize(x, (32, 32))

        return x, y

    # Training loop; the iteration count comes from FLAGS.num_iters.
    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True, augmentation=False)
        _, loss, reg_loss_graph, train_accuracy, train_summary = sess.run(
            [train_step, total_loss, reg_loss, accuracy, merged],
            feed_dict={
                x: x_train,
                y_: y_train,
                keep_prob: FLAGS.dropout
            })
        losses.append(loss)

        if i % summary_every == 0:
            train_writer.add_summary(train_summary, i)

            # Evaluate on the full held-out set with dropout disabled.
            test_summary, test_accuracy = sess.run([merged, accuracy],
                                                   feed_dict={
                                                       x: x_test,
                                                       y_: y_test,
                                                       keep_prob: 1.0
                                                   })
            test_writer.add_summary(test_summary, i)
            if verbose:
                print('step %d: test acc %g' % (i, test_accuracy))

        if i > 0 and i % save_every == 0:
            # print("Saving model...")
            saver.save(sess, save_path, global_step=i)

        if i % print_every == 0:
            if verbose:
                print('step %d:\t loss %g,\t reg_loss %g,\t train acc %g' %
                      (i, loss, reg_loss_graph, train_accuracy))

    #test_batches = []
    # for i in range(4):
    #     idx = i*500
    #     batch_acc = accuracy.eval(feed_dict={x: x_test[idx:idx+500, :], y_: y_test[idx:idx+500], keep_prob: 1.0})
    #     test_batches.append(batch_acc)
    # test_acc = np.mean(test_batches)

    # Final evaluation. `i` and `train_accuracy` come from the last loop
    # iteration, so this print needs FLAGS.num_iters >= 1.
    test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    print('final step %d, train accuracy %g, test accuracy %g' %
          (i, train_accuracy, test_acc))
    #sess.close()

    train_writer.close()
    test_writer.close()