def test_input_placeholders(self):
    """Check that input_placeholders returns two tensors with the expected shapes.

    The image placeholder must have static shape [None, image_pixels] and
    the label placeholder [None, num_classes].
    """
    # Build inside a fresh graph so the placeholders do not leak into the
    # default graph.
    with tf.Graph().as_default():
        # Given
        image_pixels = 100
        num_classes = 10

        # When
        x, y_ = input.input_placeholders(image_pixels, num_classes)

        # Then
        # Use the public Tensor API (isinstance / get_shape) rather than
        # the private ``_shape`` attribute, which is an implementation detail.
        self.assertIsNotNone(x)
        self.assertIsInstance(x, tf.Tensor)
        self.assertEqual(x.get_shape().as_list(), [None, image_pixels])

        self.assertIsNotNone(y_)
        self.assertIsInstance(y_, tf.Tensor)
        self.assertEqual(y_.get_shape().as_list(), [None, num_classes])
def main():
    """Train the MNIST model, log TensorBoard summaries and print accuracies.

    Builds the graph from the project helpers (input placeholders, inference,
    loss, train step, evaluation), runs ``FLAGS.max_steps`` training steps,
    prints the validation accuracy every 10th step and finally prints the
    accuracy on the test set. Summaries are written under
    ``FLAGS.summaries_dir`` (``/train`` and ``/validation`` sub-directories);
    any existing directory at that path is deleted first.

    The session and both summary writers are closed on exit, including when
    an exception is raised.
    """
    # Create an InteractiveSession (it installs itself as the default session,
    # which the ``.run()`` call on the init op below relies on).
    sess = tf.InteractiveSession()
    train_writer = None
    validation_writer = None

    try:
        # Remove tensorboard previous directory
        if os.path.exists(FLAGS.summaries_dir):
            shutil.rmtree(FLAGS.summaries_dir)

        # Step 1 - Input data management

        # MNIST data
        mnist = input.input_reader(FLAGS.data_dir)

        # Input placeholders
        x, y_ = input.input_placeholders(IMAGE_PIXELS, NUM_CLASSES)

        # Reshape images for visualization
        x_reshaped = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
        tf.image_summary('input', x_reshaped, NUM_CLASSES, name="y_input")

        # Step 2 - Building the graph

        # Inference
        softmax = graph.create_inference_step(
            x=x,
            num_pixels=IMAGE_PIXELS,
            num_dense1=200,
            num_dense2=100,
            num_dense3=50,
            num_classes=NUM_CLASSES,
        )

        # Loss
        cross_entropy = graph.add_loss_step(softmax, y_)

        # Train step
        train_step = graph.add_train_step(cross_entropy, FLAGS.learning_rate)

        # Step 3 - Build the evaluation step

        # Model Evaluation
        accuracy = evaluation.evaluate(softmax, y_)

        # Step 4 - Merge all summaries for TensorBoard generation

        # Merge all the summaries and write them out under FLAGS.summaries_dir
        merged = tf.merge_all_summaries()
        train_writer = tf.train.SummaryWriter(
            FLAGS.summaries_dir + '/train', sess.graph)
        validation_writer = tf.train.SummaryWriter(
            FLAGS.summaries_dir + '/validation')

        # Step 5 - Train the model, and write summaries

        # Initialize all variables
        tf.initialize_all_variables().run()

        # Run train_step on training data and add training summaries
        for i in range(FLAGS.max_steps):
            # Load next batch of data
            x_batch, y_batch = mnist.train.next_batch(FLAGS.batch_size)

            # Run summaries and train_step
            summary, _ = sess.run([merged, train_step],
                                  feed_dict={x: x_batch, y_: y_batch})

            # Add summaries to train writer
            train_writer.add_summary(summary, i)

            # Every 10th step, measure validation-set accuracy, and write
            # validation summaries
            if i % 10 == 0:
                # Run summaries and measure accuracy on validation set
                summary, acc_valid = sess.run(
                    [merged, accuracy],
                    feed_dict={x: mnist.validation.images,
                               y_: mnist.validation.labels})

                # Add summaries to validation writer
                validation_writer.add_summary(summary, i)

                print('Validation Accuracy at step %s: %s' % (i, acc_valid))

        # Measure accuracy on test set
        acc_test = sess.run(accuracy,
                            feed_dict={x: mnist.test.images,
                                       y_: mnist.test.labels})
        print('Accuracy on test set: %s' % acc_test)
    finally:
        # Close the writers so buffered summary events reach disk, then
        # release the session even if closing a writer fails.
        try:
            for writer in (train_writer, validation_writer):
                if writer is not None:
                    writer.close()
        finally:
            sess.close()