def test_add_train_step(self):
    """graph.add_train_step should return a TensorFlow Operation.

    The "loss" here is itself a float32 variable; the test only checks the
    type of the returned training op, not its numerical effect.
    """
    with tf.Graph().as_default():
        # Given
        loss = tf.Variable([1.0], dtype=tf.float32)
        # When
        train_op = graph.add_train_step(loss, 0.1)
        # Then
        self.assertIsNotNone(train_op)
        # isinstance is sturdier than comparing the class name as a string,
        # which breaks if the type is aliased, renamed or subclassed.
        self.assertIsInstance(train_op, tf.Operation)
def test_run_train(self):
    """One training step should update variables as plain gradient descent.

    cost = 5 * var0 + 3 * var1, so the gradients are 5 and 3 elementwise.
    With a learning rate of 3.0 the expected values after a single step are
    var0 - 15 and var1 - 9.
    """
    learning_rate = 3.0
    with tf.Graph().as_default(), self.test_session():
        # Given: a linear cost over two vector variables
        first = tf.Variable([1.0, 2.0], dtype=tf.float32)
        second = tf.Variable([3.0, 4.0], dtype=tf.float32)
        cost = 5 * first + 3 * second

        # When: the training op is built and all variables are initialised
        train_op = graph.add_train_step(cost, learning_rate)
        tf.initialize_all_variables().run()

        # Then: the variables start at their declared values...
        initial = (([1.0, 2.0], first), ([3.0, 4.0], second))
        for expected, variable in initial:
            self.assertAllClose(expected, variable.eval())

        # ...and a single step moves them by -learning_rate * gradient.
        train_op.run()
        updated = (([-14., -13.], first), ([-6., -5.], second))
        for expected, variable in updated:
            self.assertAllClose(expected, variable.eval())
def main():
    """Train the dense MNIST model and write TensorBoard summaries.

    Steps:
      1. Read MNIST from FLAGS.data_dir and create the input placeholders.
      2. Build the inference, loss and training ops.
      3. Build the accuracy op.
      4. Merge all summaries and open train/validation summary writers under
         FLAGS.summaries_dir. Any previous contents of that directory are
         deleted first.
      5. Run FLAGS.max_steps training steps. Log validation accuracy every
         10th step (starting at step 0), then print test-set accuracy.

    The summary writers and the session are closed in a ``finally`` block,
    so buffered event data is flushed even if training is interrupted.
    """
    # Create an InteractiveSession; the bare ``.run()`` on the variable
    # initializer below relies on it being the default session.
    sess = tf.InteractiveSession()

    # Remove any previous TensorBoard output so runs do not mix.
    if os.path.exists(FLAGS.summaries_dir):
        shutil.rmtree(FLAGS.summaries_dir)

    # Step 1 - Input data management
    mnist = input.input_reader(FLAGS.data_dir)
    x, y_ = input.input_placeholders(IMAGE_PIXELS, NUM_CLASSES)
    # Reshape flat pixel rows into [batch, IMAGE_SIZE, IMAGE_SIZE, 1] images
    # for the image summary.
    x_reshaped = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
    # NOTE(review): the third positional argument of tf.image_summary is
    # max_images, so NUM_CLASSES is acting as the image cap here. Confirm
    # that this is intended.
    tf.image_summary('input', x_reshaped, NUM_CLASSES, name="y_input")

    # Step 2 - Building the graph
    softmax = graph.create_inference_step(x=x, num_pixels=IMAGE_PIXELS,
                                          num_classes=NUM_CLASSES)
    cross_entropy = graph.add_loss_step(softmax, y_)
    train_step = graph.add_train_step(cross_entropy, FLAGS.learning_rate)

    # Step 3 - Build the evaluation step
    accuracy = evaluation.evaluate(softmax, y_)

    # Step 4 - Merge all summaries for TensorBoard generation
    merged = tf.merge_all_summaries()
    train_writer = tf.train.SummaryWriter(
        os.path.join(FLAGS.summaries_dir, 'train'), sess.graph)
    validation_writer = tf.train.SummaryWriter(
        os.path.join(FLAGS.summaries_dir, 'validation'))

    try:
        # Step 5 - Train the model, and write summaries
        tf.initialize_all_variables().run()

        for i in range(FLAGS.max_steps):
            # Train on the next mini-batch and record training summaries.
            x_batch, y_batch = mnist.train.next_batch(FLAGS.batch_size)
            summary, _ = sess.run([merged, train_step],
                                  feed_dict={x: x_batch, y_: y_batch})
            train_writer.add_summary(summary, i)

            # Every 10th step, measure validation-set accuracy and record
            # validation summaries.
            if i % 10 == 0:
                summary, acc_valid = sess.run(
                    [merged, accuracy],
                    feed_dict={x: mnist.validation.images,
                               y_: mnist.validation.labels})
                validation_writer.add_summary(summary, i)
                print('Validation Accuracy at step %s: %s' % (i, acc_valid))

        # Measure accuracy on the held-out test set.
        acc_test = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                 y_: mnist.test.labels})
        print('Accuracy on test set: %s' % acc_test)
    finally:
        # Closing the writers flushes any buffered events to disk;
        # otherwise the last summaries of the run can be lost.
        train_writer.close()
        validation_writer.close()
        sess.close()