def _build_graph(self, x_train, hm_reg_scale, init_gamma, height_map_noise):
    """Build the optics-simulation + U-Net reconstruction graph.

    A Fourier-parameterized height map defines a single-lens optical
    system; the depth-dependent sensor image it produces is fed through
    a U-Net that predicts the output image.

    :param x_train: tuple of (input image node, depth map node).
    :param hm_reg_scale: weight of the Laplace-L1 height-map regularizer.
    :param init_gamma: unused here; accepted to match sibling signatures.
    :param height_map_noise: tolerance noise applied to the height map to
        model manufacturing deficiencies.
    :return: graph node for the reconstructed output image.
    """
    img, depth = x_train
    with tf.device('/device:GPU:0'):
        # Learnable phase plate, parameterized in the Fourier basis.
        hm = optics.get_fourier_height_map(
            self.wave_resolution[0],
            0.75,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale))

        # Full single-lens camera model (depth-dependent PSFs, no
        # refocusing target).
        camera = optics.SingleLensSetup(
            height_map=hm,
            wave_resolution=self.wave_resolution,
            wave_lengths=self.wave_lengths,
            sensor_distance=self.sensor_distance,
            sensor_resolution=(self.patch_size, self.patch_size),
            input_sample_interval=self.sampling_interval,
            refractive_idcs=self.refractive_idcs,
            height_tolerance=height_map_noise,
            use_planar_incidence=False,
            depth_bins=self.depth_bins,
            upsample=False,
            psf_resolution=self.wave_resolution,
            target_distance=None)

        # Noise-free, depth-dependent sensor measurement.
        measurement = camera.get_sensor_img(input_img=img,
                                            noise_sigma=None,
                                            depth_dependent=True,
                                            depth_map=depth)

        # Reconstruct with a U-Net.
        unet = net.U_Net()
        recon = unet.build(measurement)
        optics.attach_summaries('output_image', recon,
                                image=True, log_image=False)
    return recon
def _build_graph(self, x_train, global_step, hm_reg_scale, height_map_noise):
    ''' Builds the graph for this model.

    :param x_train (graph node): input image.
    :param global_step: Global step variable (unused in this model)
    :param hm_reg_scale: Regularization coefficient for laplace l1 regularizer of phaseplate
    :param height_map_noise: Noise added to height map to account for manufacturing deficiencies
    :return: graph node for output image
    '''
    img = x_train
    with tf.device('/device:GPU:0'):
        # Incident illumination is a unit-amplitude planar wave, one
        # channel per wavelength.
        planar_wave = tf.ones((1, self.wave_res[0], self.wave_res[1],
                               len(self.wave_lengths)))

        # The learnable phase plate shifts the wavefront's phase.
        wavefront = optics.height_map_element(
            planar_wave,
            wave_lengths=self.wave_lengths,
            height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            height_map_initializer=None,
            height_tolerance=height_map_noise,
            refractive_idcs=self.refractive_idcs,
            name='height_map_optics')
        wavefront = optics.circular_aperture(wavefront)

        # Fresnel propagation from the aperture plane to the sensor plane.
        wavefront = optics.propagate_fresnel(
            wavefront,
            distance=self.sensor_distance,
            sampling_interval=self.sample_interval,
            wave_lengths=self.wave_lengths)

        # PSF = intensities of the propagated field, resampled to the
        # sensor grid and normalized to unit energy per channel.
        psfs = optics.get_intensities(wavefront)
        psfs = optics.area_downsampling_tf(psfs, self.patch_size)
        psfs = tf.div(psfs, tf.reduce_sum(psfs, axis=[1, 2], keepdims=True))
        optics.attach_summaries('PSF', psfs, image=True, log_image=True)

        # Image formation: convolve the scene with the PSF.
        psfs = tf.transpose(psfs, [1, 2, 0, 3])
        output_image = optics.img_psf_conv(img, psfs)
        output_image = tf.cast(output_image, tf.float32)
        optics.attach_summaries('output_image', output_image,
                                image=True, log_image=False)

        # Random scalar offset added to the whole image; presumably models
        # a varying black-level / noise floor — NOTE(review): confirm this
        # is intended rather than per-pixel sensor noise.
        output_image += tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
    return output_image
def _get_data_loss(self, model_output, ground_truth, margin=10):
    """Mean-squared-error data loss, ignoring a `margin`-pixel border.

    Fix: the original slicing `[:, margin:-margin, margin:-margin, :]`
    produced an EMPTY tensor for margin=0 (since `0:-0` is empty), making
    the loss NaN. margin=0 now means "use the full image"; the default
    margin=10 behavior is unchanged.

    :param model_output: predicted image node (any float-castable dtype).
    :param ground_truth: target image node, same spatial shape.
    :param margin: number of border pixels to exclude from the loss.
    :return: scalar MSE loss node over the cropped region.
    """
    model_output = tf.cast(model_output, tf.float32)
    ground_truth = tf.cast(ground_truth, tf.float32)

    if margin > 0:
        cropped_output = model_output[:, margin:-margin, margin:-margin, :]
        cropped_gt = ground_truth[:, margin:-margin, margin:-margin, :]
    else:
        cropped_output = model_output
        cropped_gt = ground_truth

    # Crop-then-square equals square-then-crop; the cropped MSE matches
    # the original computation for margin > 0.
    loss = tf.reduce_mean(tf.square(cropped_output - cropped_gt))

    optics.attach_summaries('output_image', cropped_output,
                            image=True, log_image=False)
    return loss
def _build_graph(self, x_train, hm_reg_scale, init_gamma, height_map_noise):
    """U-Net-only baseline: no optics simulation — the network maps the
    raw input image directly to the output.

    :param x_train: tuple of (input image node, depth map node); the depth
        map is unused by this baseline.
    :param hm_reg_scale: unused; kept to match sibling model signatures.
    :param init_gamma: unused; kept to match sibling model signatures.
    :param height_map_noise: unused; kept to match sibling model signatures.
    :return: graph node for the output image.
    """
    img, _depth = x_train
    with tf.device('/device:GPU:0'):
        unet = net.U_Net()
        recon = unet.build(img)
        optics.attach_summaries('output_image', recon,
                                image=True, log_image=False)
    return recon
def _build_graph(self, x_train, global_step, hm_reg_scale, init_gamma,
                 height_map_noise, learned_target_depth,
                 hm_init_type='random_normal'):
    """Build an optics + inverse-filter graph with a learnable focus depth.

    A Fourier-parameterized height map and a trainable scalar target depth
    define a single-lens system; the simulated noisy, depth-dependent
    sensor image is then deconvolved with the system's target-depth PSF.

    :param x_train: tuple of (input image node, depth map node).
    :param global_step: global step variable (unused here).
    :param hm_reg_scale: weight of the Laplace-L1 height-map regularizer.
    :param init_gamma: initial regularization weight of the inverse filter.
    :param learned_target_depth: unused — the target-depth variable is
        always created. NOTE(review): confirm this flag should gate it.
    :param hm_init_type: unused here.
    :return: graph node for the deconvolved output image.
    """
    img, depth_map = x_train
    with tf.device('/device:GPU:0'):
        with tf.variable_scope("optics"):
            hm = optics.get_fourier_height_map(
                self.wave_resolution[0],
                0.625,
                height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale))

            # Trainable scalar focus depth; squared to keep it nonnegative.
            depth_init = tf.constant_initializer(1.)
            target_depth = tf.get_variable(name="target_depth",
                                           shape=(),
                                           dtype=tf.float32,
                                           trainable=True,
                                           initializer=depth_init)
            target_depth = tf.square(target_depth)
            tf.summary.scalar('target_depth', target_depth)

            camera = optics.SingleLensSetup(
                height_map=hm,
                wave_resolution=self.wave_resolution,
                wave_lengths=self.wave_lengths,
                sensor_distance=self.distance,
                sensor_resolution=(self.patch_size, self.patch_size),
                input_sample_interval=self.input_sample_interval,
                refractive_idcs=self.refractive_idcs,
                height_tolerance=height_map_noise,
                use_planar_incidence=False,
                depth_bins=self.depth_bins,
                upsample=False,
                psf_resolution=self.wave_resolution,
                target_distance=target_depth)

        # Per-step random sensor-noise level for robustness.
        sigma = tf.random_uniform(minval=0.001, maxval=0.02, shape=[])
        sensor_img = camera.get_sensor_img(input_img=img,
                                           noise_sigma=sigma,
                                           depth_dependent=True,
                                           depth_map=depth_map)
        estimate = tf.cast(sensor_img, tf.float32)

        # Deconvolution: symmetric-pad by half the image size to suppress
        # wrap-around artifacts, inverse-filter, then crop back.
        half = estimate.shape.as_list()[1] // 2
        estimate = tf.pad(estimate,
                          [[0, 0], [half, half], [half, half], [0, 0]],
                          mode='SYMMETRIC')
        estimate = deconv.inverse_filter(estimate, estimate,
                                         camera.target_psf,
                                         init_gamma=init_gamma)
        estimate = estimate[:, half:-half, half:-half, :]

        optics.attach_summaries('output_image', estimate,
                                image=True, log_image=False)
    return estimate
def optical_conv_layer(input_field, hm_reg_scale, r_NA, n=1.48, wavelength=532e-9,
                       activation=None, coherent=False, amplitude_mask=False,
                       zernike=False, fourier=False, binarymask=False, n_modes=1024,
                       freq_range=.5, initializer=None,
                       zernike_file='zernike_volume_256.npy', binary_mask_np=None,
                       name='optical_conv'):
    """Simulate one optical convolution layer.

    Builds a mask element (amplitude / zernike / fourier / plain height map,
    selected by the boolean flags), applies it to a plane wave to obtain the
    system's amplitude transfer function (ATF), derives the PSF from the ATF,
    and convolves the input with it (coherently via a field-space product, or
    incoherently via padded FFT convolution).

    :param input_field: input image/field tensor, shape [batch, H, W, C]
        (rank 4 assumed — dims[1]/dims[2] are used as spatial sizes).
    :param hm_reg_scale: height-map regularizer weight (currently only used
        in the commented-out regularizer line below).
    :param r_NA: aperture radius used for circular apertures / mask elements.
    :param n: refractive index of the mask material.
    :param wavelength: design wavelength in meters.
    :param activation: optional nonlinearity applied to the output image.
    :param coherent: if True, convolve in the field domain (complex128).
    :param amplitude_mask: if True, use an amplitude mask instead of a phase mask.
    :param zernike: if True, parameterize the phase mask by Zernike modes.
    :param fourier: if True, parameterize the phase mask in a Fourier basis.
    :param binarymask: if True, multiply the ATF by `binary_mask_np`.
    :param n_modes: number of Zernike modes loaded from `zernike_file`.
    :param freq_range: frequency range of the Fourier parameterization.
    :param initializer: variable initializer for the mask; defaults to a
        near-constant uniform initializer around 1e-4.
    :param zernike_file: .npy file holding the precomputed Zernike volume.
    :param binary_mask_np: numpy array for the optional binary amplitude mask.
    :param name: variable scope name.
    :return: output image tensor after the simulated optical convolution.
    """
    dims = input_field.get_shape().as_list()
    with tf.variable_scope(name):
        if initializer is None:
            # Near-uniform tiny heights: starts the mask close to flat.
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)

        if amplitude_mask:
            # Build an amplitude mask, zero-centered
            # amplitude_map_initializer=tf.random_uniform_initializer(minval=0.1e-3, maxval=.1e-2)
            mask = optics.amplitude_map_element(
                [1, dims[1], dims[2], 1], r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')
        else:
            # Build a phase mask, zero-centered
            if zernike:
                # Load precomputed Zernike modes and resample them to the
                # input's spatial resolution.
                zernike_modes = np.load(zernike_file)[:n_modes, :, :]
                zernike_modes = tf.expand_dims(zernike_modes, -1)
                zernike_modes = tf.image.resize_images(zernike_modes,
                                                       size=(dims[1], dims[2]))
                zernike_modes = tf.squeeze(zernike_modes, -1)
                mask = optics.zernike_element(zernike_modes, 'zernike_element',
                                              wavelength, n, r_NA,
                                              zernike_initializer=initializer)
            elif fourier:
                mask = optics.fourier_element(
                    [1, dims[1], dims[2], 1], 'fourier_element',
                    wave_lengths=wavelength,
                    refractive_index=n,
                    frequency_range=freq_range,
                    height_map_regularizer=None,
                    height_tolerance=None,  # Default height tolerance is 2 nm.
                )
            else:
                # height_map_initializer=None
                mask = optics.height_map_element(
                    [1, dims[1], dims[2], 1],
                    wave_lengths=wavelength,
                    height_map_initializer=initializer,
                    #height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
                    name='phase_mask_height',
                    refractive_index=n)

        # Get ATF and PSF
        # Start from a plane wave; the mask element turns it into the ATF.
        atf = tf.ones([1, dims[1], dims[2], 1])
        #zernike=True
        if zernike:
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        # apply any additional binary amplitude mask [1, dim, dim, 1]
        if binarymask:
            binary_mask = tf.convert_to_tensor(binary_mask_np, dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            # optics.attach_img('binary_mask', binary_mask)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            optics.attach_img('atf', tf.abs(binary_mask))

        # psf = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        # Complex PSF = zero-centered inverse FFT of the ATF; the Sensor
        # element converts it to (real) intensities.
        psfc = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psf = optics.Sensor(input_is_intensities=False,
                            resolution=(dims[1], dims[2]))(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        # Get the output image
        if coherent:
            input_field = tf.cast(input_field, tf.complex128)
            # Zero-centered fft2 of input field
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))
            # field = optics.circular_aperture(field, max_val = r_NA)
            # Coherent imaging: multiply by the ATF in the frequency domain.
            field = atf * field
            tf.summary.image('field', tf.log(tf.square(tf.abs(field))))
            output_img = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(dims[1], dims[2]))(output_img)
            # does this need padding as well?
            # psfc = tf.expand_dims(tf.expand_dims(tf.squeeze(psfc), -1), -1)
            # padamt = int(dims[1]/2)
            # output_img = optics.fft_conv2d(fftpad(input_field, padamt), fftpad_psf(psfc, padamt), adjoint=False)
            # output_img = fftunpad(output_img, padamt)
            # output_img = optics.Sensor(input_is_intensities=False, resolution=(dims[1],dims[2]))(output_img)
        else:
            # Incoherent imaging: pad by half the image size, FFT-convolve
            # with the intensity PSF, then unpad.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            # psf_flip = tf.reverse(psf,[0,1])
            # output_img = conv2d(input_field, psf)
            padamt = int(dims[1] / 2)
            output_img = tf.abs(
                optics.fft_conv2d(fftpad(input_field, padamt),
                                  fftpad_psf(psf, padamt), adjoint=False))
            output_img = fftunpad(output_img, padamt)

        # Apply nonlinear activation
        if activation is not None:
            output_img = activation(output_img)

        return output_img
def train(params, summary_every=100, save_every=2000, verbose=True):
    """Train a 9-class MNIST classifier (digits 1-9; zeros are filtered out).

    Builds either a single-channel conv model whose 3x3-tiled response maxima
    act as logits, or (default, since doConv=False below) a single fully
    connected layer, then runs an Adam training loop with TensorBoard
    summaries and periodic checkpointing.

    :param params: dict of options ('isNonNeg', 'addBias', 'logtrans',
        'numIters', 'activation'). NOTE(review): 'addBias', 'logtrans' and
        'numIters' are read but never used (the loop uses FLAGS.num_iters).
    :param summary_every: iterations between summary/accuracy evaluations.
    :param save_every: iterations between checkpoint saves.
    :param verbose: if True, print loss/accuracy at each summary step.
    """
    # Unpack params
    isNonNeg = params.get('isNonNeg', False)
    addBias = params.get('addBias', True)
    doLogTrans = params.get('logtrans', False)
    numIters = params.get('numIters', 1000)
    activation = params.get('activation', tf.nn.relu)

    # constraint helpers
    def nonneg(input_tensor):
        # Square to enforce nonnegativity when requested (physically
        # realizable weights); identity otherwise.
        return tf.square(input_tensor) if isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True))

    # input placeholders
    classes = 9
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.float32, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        # padamt=0 here, so padding is a no-op and dim stays 28.
        padamt = 0
        dim = 28 + 2 * padamt
        paddings = tf.constant([[0, 0, ], [padamt, padamt],
                                [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        tf.summary.image('input', x_image, 3)

    # build model
    with tf.device('/device:GPU:2'):
        # Hard-coded model selection switches.
        doOpticalConv = False
        doConv = False
        if doConv:
            if doOpticalConv:
                doAmplitudeMask = True
                hm_reg_scale = 1e-2
                r_NA = 35
                h_conv1 = optical_conv_layer(x_image, hm_reg_scale, r_NA,
                                             n=1.48, wavelength=532e-9,
                                             activation=activation,
                                             amplitude_mask=doAmplitudeMask,
                                             name='opt_conv1')
                h_conv1 = tf.cast(h_conv1, dtype=tf.float32)
            else:
                W_conv1 = weight_variable([dim, dim, 1, 1], name='W_conv1')
                W_conv1 = nonneg(W_conv1)
                W_conv1_im = tf.expand_dims(tf.expand_dims(tf.squeeze(W_conv1), 0), 3)
                optics.attach_summaries("W_conv1", W_conv1_im, image=True)
                # W_conv1 = weight_variable([12, 12, 1, 9])
                h_conv1 = activation(conv2d(x_image, (W_conv1)))

            optics.attach_summaries("h_conv1", h_conv1, image=True)
            h_conv1_drop = tf.nn.dropout(h_conv1, keep_prob)

            # h_conv1_split = tf.split(h_conv1, 9, axis=3)
            # h_conv1_tiled = tf.concat([tf.concat(h_conv1_split[:3], axis=1),
            #                            tf.concat(h_conv1_split[3:6], axis=1),
            #                            tf.concat(h_conv1_split[6:9], axis=1)], axis=2)
            # tf.summary.image("h_conv1", h_conv1_tiled, 3)

            # Split the response map into a 3x3 grid; the max of each cell
            # becomes the logit for one class.
            split_1d = tf.split(h_conv1_drop, num_or_size_splits=3, axis=1)
            h_conv1_split = tf.concat(
                [tf.split(split_1d[0], num_or_size_splits=3, axis=2),
                 tf.split(split_1d[1], num_or_size_splits=3, axis=2),
                 tf.split(split_1d[2], num_or_size_splits=3, axis=2)], 0)
            y_out = tf.transpose(tf.reduce_max(h_conv1_split, axis=[2, 3, 4]))
        else:
            with tf.name_scope('fc'):
                fcsize = dim * dim
                W_fc1 = weight_variable([fcsize, classes], name='W_fc1')
                W_fc1 = nonneg(W_fc1)

                # visualize the FC weights
                W_fc1_split = tf.reshape(tf.transpose(W_fc1), [classes, 28, 28])
                W_fc1_split = tf.split(W_fc1_split, classes, axis=0)
                W_fc1_tiled = tf.concat([tf.concat(W_fc1_split[:3], axis=2),
                                         tf.concat(W_fc1_split[3:6], axis=2),
                                         tf.concat(W_fc1_split[6:9], axis=2)], axis=1)
                tf.summary.image("W_fc1", tf.expand_dims(W_fc1_tiled, 3))

                h_conv1_flat = tf.reshape(x_image, [-1, fcsize])
                y_out = (tf.matmul(h_conv1_flat, (W_fc1)))

    tf.summary.image('output', tf.reshape(y_out, [-1, 3, 3, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_out)
        mean_loss = tf.reduce_mean(total_loss)
        tf.summary.scalar('loss', mean_loss)

    train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(mean_loss)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    # MNIST feed dict
    def get_feed(train):
        # Note: `x`/`y` here shadow the placeholders; they are local numpy
        # arrays, the placeholders are only used as feed_dict keys below.
        if train:
            x, y = mnist.train.next_batch(50)
        else:
            x = mnist.test.images
            y = mnist.test.labels
        # remove "0"s
        indices = ~np.equal(y[:, 0], 1)
        x_filt = np.squeeze(x[indices])
        y_filt = np.squeeze(y[indices, 1:])
        return x_filt, y_filt

    # QuickDraw feed dict
    # train_data = np.load('/media/data/Datasets/quickdraw/split/all_train.npy')
    # test_data = np.load('/media/data/Datasets/quickdraw/split/all_test.npy')
    # def get_feed(train, batch_size=50):
    #     if train:
    #         idcs = np.random.randint(0, np.shape(train_data)[0], batch_size)
    #         x = train_data[idcs, :]
    #         categories = idcs//4000
    #         y = np.zeros((batch_size, classes))
    #         y[np.arange(batch_size), categories] = 1
    #     else:
    #         x = test_data
    #         y = np.resize(np.equal(range(classes),0).astype(int),(100,classes))
    #         for i in range(1,classes):
    #             y = np.concatenate((y, np.resize(np.equal(range(classes),i).astype(int),(100,classes))), axis=0)
    #     return x, y

    x_test, y_test = get_feed(train=False)

    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True)
        _, loss = sess.run([train_step, mean_loss],
                           feed_dict={x: x_train, y_: y_train,
                                      keep_prob: FLAGS.dropout})
        losses.append(loss)

        if i % summary_every == 0:
            train_summary, train_accuracy = sess.run(
                [merged, accuracy],
                feed_dict={x: x_train, y_: y_train, keep_prob: FLAGS.dropout})
            train_writer.add_summary(train_summary, i)

            # Test summary is computed but intentionally not logged.
            test_summary, test_accuracy = sess.run(
                [merged, accuracy],
                feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
            # test_writer.add_summary(test_summary, i)
            if verbose:
                print('step %d: loss %g, train acc %g, test acc %g' %
                      (i, loss, train_accuracy, test_accuracy))

        if i % save_every == 0:
            print("Saving model...")
            saver.save(sess, save_path, global_step=i)

    # NOTE(review): `train_accuracy` below holds the value from the last
    # summary step, not the final iteration.
    test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    print('final step %d, train accuracy %g, test accuracy %g' %
          (i, train_accuracy, test_acc))

    #sess.close()
    train_writer.close()
    test_writer.close()
def _build_graph(self, x_train, init_gamma, hm_reg_scale, noise_sigma,
                 height_map_noise, hm_init_type='random_normal'):
    """Build an optics-simulation + inverse-filter graph (planar incidence).

    A learnable height map followed by a circular aperture and Fresnel
    propagation defines the forward optical model; the resulting noisy
    sensor image is deconvolved with the system PSF via an inverse filter.

    :param x_train: input image node.
    :param init_gamma: initial regularization weight of the inverse filter.
    :param hm_reg_scale: weight of the Laplace-L1 height-map regularizer.
    :param noise_sigma: sensor noise level; if None, drawn uniformly from
        [0.001, 0.02) each step.
    :param height_map_noise: tolerance noise added to the height map.
    :param hm_init_type: unused here.
    :return: graph node for the deconvolved output image.
    """
    img = x_train
    print("build graph", img.get_shape())

    with tf.device('/device:GPU:0'):
        def optics_model(input_field):
            # Phase plate -> aperture -> Fresnel propagation to the sensor.
            field = optics.height_map_element(
                input_field,
                wave_lengths=self.wave_lengths,
                height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
                height_map_initializer=None,
                height_tolerance=height_map_noise,
                refractive_idcs=self.refractive_idcs,
                name='height_map_optics')
            field = optics.circular_aperture(field)
            return optics.propagate_fresnel(
                field,
                distance=self.distance,
                input_sample_interval=self.input_sample_interval,
                wave_lengths=self.wave_lengths)

        camera = optics.OpticalSystem(
            optics_model,
            upsample=False,
            wave_resolution=self.wave_resolution,
            wave_lengths=self.wave_lengths,
            sensor_resolution=(self.patch_size, self.patch_size),
            psf_resolution=(self.patch_size, self.patch_size),  # Equals wave resolution
            discretization_size=self.input_sample_interval,
            use_planar_incidence=True)

        if noise_sigma is None:
            # Randomize the noise level per step for robustness.
            noise_sigma = tf.random_uniform(minval=0.001, maxval=0.02, shape=[])

        sensor_img = camera.get_sensor_img(input_img=img,
                                           noise_sigma=noise_sigma,
                                           depth_dependent=False)
        estimate = tf.cast(sensor_img, tf.float32)

        # Deconvolution: zero-pad by half the image size to suppress
        # wrap-around artifacts, inverse-filter, then crop back.
        half = estimate.shape.as_list()[1] // 2
        estimate = tf.pad(estimate,
                          [[0, 0], [half, half], [half, half], [0, 0]])
        estimate = deconv.inverse_filter(estimate, estimate,
                                         camera.psfs[0],
                                         init_gamma=init_gamma)
        estimate = estimate[:, half:-half, half:-half, :]

        optics.attach_summaries('output_image', estimate,
                                image=True, log_image=False)
    return estimate
def train(params, summary_every=100, print_every=250, save_every=1000,
          verbose=True):
    """Train the optical-correlator MNIST classifier (digits 1-9).

    Upsamples/pads MNIST to 84x84, runs it through an optical convolution
    layer initialized from a precomputed phase mask, and reads the 9 class
    logits as the maxima of a 3x3 tiling of the response map. Trains with
    Adadelta and logs summaries/checkpoints.

    :param params: dict of options ('isNonNeg', 'numIters', 'activation').
        NOTE(review): 'numIters' is read but unused (loop uses FLAGS.num_iters).
    :param summary_every: iterations between summary writes.
    :param print_every: iterations between console prints.
    :param save_every: iterations between checkpoint saves (skips step 0).
    :param verbose: if True, print loss/accuracy at print steps.
    """
    # Unpack params
    isNonNeg = params.get('isNonNeg', False)
    # addBias = params.get('addBias', True)
    numIters = params.get('numIters', 1000)
    activation = params.get('activation', tf.nn.relu)

    # constraint helpers
    def nonneg(input_tensor):
        # Square to enforce nonnegativity when requested; identity otherwise.
        return tf.square(input_tensor) if isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    classes = 9
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.float32, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        # Pad 28x28 digits to 84x84, then nearest-neighbor resize (the pad
        # already yields 84x84, so the resize keeps that size).
        x_image = tf.reshape(x, [-1, 28, 28, 1])
        padamt = 28
        dim = 84
        paddings = tf.constant([[0, 0, ], [padamt, padamt],
                                [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

    # build model
    if True:
        doOpticalConv = True
        if doOpticalConv:
            doAmplitudeMask = False
            hm_reg_scale = 1e-2
            r_NA = 35

            # initialize with optimized phase mask
            mask = np.load(
                'maskopt/opticalcorrelator_w-conv1_height-map-sqrt.npy')
            initializer = tf.constant_initializer(mask)
            # initializer=None

            h_conv1 = optical_conv_layer(x_image, hm_reg_scale, r_NA, n=1.48,
                                         wavelength=532e-9,
                                         activation=activation,
                                         amplitude_mask=doAmplitudeMask,
                                         initializer=initializer,
                                         name='opt_conv1')
            # h_conv2 = optical_conv_layer(h_conv1, hm_reg_scale, r_NA, n=1.48, wavelength=532e-9,
            #                              activation=activation, amplitude_mask=doAmplitudeMask, name='opt_conv2')
        else:
            conv1dim = dim
            W_conv1 = weight_variable([conv1dim, conv1dim, 1, 1],
                                      name='W_conv1')
            W_conv1_flip = tf.reverse(W_conv1, axis=[0, 1])
            # W_conv1 = weight_variable([12, 12, 1, 9])
            W_conv1_im = tf.expand_dims(tf.expand_dims(tf.squeeze(W_conv1), 0), 3)
            optics.attach_summaries("W_conv1", W_conv1_im, image=True)
            h_conv1 = activation(conv2d(x_image, nonneg(W_conv1_flip)))

        # h_conv1_drop = tf.nn.dropout(h_conv1, keep_prob)

        # W_conv2 = weight_variable([48, 48, 1, 1], name='W_conv2')
        # W_conv2 = weight_variable([12, 12, 9, 9])
        # h_conv2 = activation(conv2d(h_conv1_drop, nonneg(W_conv2)))

        # h_conv1_split = tf.split(h_conv1, 9, axis=3)
        # h_conv1_tiled = tf.concat([tf.concat(h_conv1_split[:3], axis=1),
        #                            tf.concat(h_conv1_split[3:6], axis=1),
        #                            tf.concat(h_conv1_split[6:9], axis=1)], axis=2)
        # tf.summary.image("h_conv1", h_conv1_tiled, 3)

        # h_conv2_split = tf.split(h_conv2, 9, axis=3)
        # h_conv2_tiled = tf.concat([tf.concat(h_conv2_split[:3], axis=1),
        #                            tf.concat(h_conv2_split[3:6], axis=1),
        #                            tf.concat(h_conv2_split[6:9], axis=1)], axis=2)
        # tf.summary.image("h_conv2", h_conv2_tiled, 3)

        optics.attach_summaries("h_conv1", h_conv1, image=True)
        #optics.attach_summaries("h_conv2", h_conv2, image=True)

        # h_conv2 = x_image
        doFC = False
        if doFC:
            with tf.name_scope('fc'):
                h_conv1 = tf.cast(h_conv1, dtype=tf.float32)
                fcsize = dim * dim
                W_fc1 = weight_variable([fcsize, classes], name='W_fc1')
                h_conv1_flat = tf.reshape(h_conv1, [-1, fcsize])
                y_out = (tf.matmul(h_conv1_flat, nonneg(W_fc1)))
                # h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
                # W_fc2 = weight_variable([hidden_dim, 10])
                # y_out = tf.matmul(h_fc1_drop, nonneg(W_fc2))
        else:
            doConv2 = False
            if doConv2:
                if doOpticalConv:
                    h_conv2 = optical_conv_layer(h_conv1, hm_reg_scale, r_NA,
                                                 n=1.48, wavelength=532e-9,
                                                 activation=activation,
                                                 name='opt_conv2')
                    h_conv2 = tf.cast(h_conv2, dtype=tf.float32)
                else:
                    W_conv2 = weight_variable([dim, dim, 1, 1])
                    W_conv2_flip = tf.reverse(W_conv2, axis=[0, 1])
                    W_conv2_im = tf.expand_dims(
                        tf.expand_dims(tf.squeeze(W_conv2), 0), 3)
                    optics.attach_summaries("W_conv2", W_conv2_im, image=True)
                    h_conv2 = activation(conv2d(h_conv1, nonneg(W_conv2_flip)))

                    W_conv3 = weight_variable([dim, dim, 1, 1])
                    W_conv3_flip = tf.reverse(W_conv3, axis=[0, 1])
                    W_conv3_im = tf.expand_dims(
                        tf.expand_dims(tf.squeeze(W_conv3), 0), 3)
                    optics.attach_summaries("W_conv3", W_conv3_im, image=True)
                    h_conv3 = activation(conv2d(h_conv2, nonneg(W_conv3_flip)))

                tf.summary.image("h_conv2", h_conv2)
                tf.summary.image("h_conv3", h_conv3)
                split_1d = tf.split(h_conv3, num_or_size_splits=3, axis=1)
            else:
                split_1d = tf.split(h_conv1, num_or_size_splits=3, axis=1)

            # Tile the response map into a 3x3 grid; the max of each cell
            # becomes the logit for one class.
            h_conv_split = tf.concat([
                tf.split(split_1d[0], num_or_size_splits=3, axis=2),
                tf.split(split_1d[1], num_or_size_splits=3, axis=2),
                tf.split(split_1d[2], num_or_size_splits=3, axis=2)
            ], 0)
            # h_conv2_split1, h_conv2_split2 = tf.split(h_conv2, num_or_size_splits=2, axis=1)
            # h_conv2_split = tf.concat([tf.split(h_conv2_split1, num_or_size_splits=5, axis=2),
            #                            tf.split(h_conv2_split2, num_or_size_splits=5, axis=2)], 0)
            y_out = tf.transpose(tf.reduce_max(h_conv_split, axis=[2, 3, 4]))

    tf.summary.image('output', tf.reshape(y_out, [-1, 3, 3, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                             logits=y_out)
        mean_loss = tf.reduce_mean(total_loss)
        tf.summary.scalar('loss', mean_loss)

    # train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(mean_loss)
    train_step = tf.train.AdadeltaOptimizer(FLAGS.learning_rate,
                                            rho=1.0).minimize(mean_loss)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    def get_feed(train):
        # Note: `x`/`y` here shadow the placeholders; they are local numpy
        # arrays, the placeholders are only used as feed_dict keys below.
        if train:
            x, y = mnist.train.next_batch(50)
        else:
            x = mnist.test.images
            y = mnist.test.labels
        # remove "0"s
        indices = ~np.equal(y[:, 0], 1)
        x_filt = np.squeeze(x[indices])
        y_filt = np.squeeze(y[indices, 1:])
        return x_filt, y_filt

    x_test, y_test = get_feed(train=False)

    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True)
        _, loss, train_accuracy, train_summary = sess.run(
            [train_step, mean_loss, accuracy, merged],
            feed_dict={
                x: x_train,
                y_: y_train,
                keep_prob: FLAGS.dropout
            })
        losses.append(loss)

        if i % summary_every == 0:
            train_writer.add_summary(train_summary, i)

        if i > 0 and i % save_every == 0:
            # print("Saving model...")
            saver.save(sess, save_path, global_step=i)
            # test_summary, test_accuracy = sess.run([merged, accuracy],
            #                                        feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
            # test_writer.add_summary(test_summary, i)
            # if verbose:
            #     print('step %d: test acc %g' % (i, test_accuracy))

        if i % print_every == 0:
            if verbose:
                print('step %d: loss %g, train acc %g' %
                      (i, loss, train_accuracy))

    # test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    # print('final step %d, train accuracy %g, test accuracy %g' %
    #       (i, train_accuracy, test_acc))

    #sess.close()
    train_writer.close()
    test_writer.close()