def test_add_train_step(self):
        """Check that ``graph.add_train_step`` returns a TensorFlow operation."""
        with tf.Graph().as_default():
            # Given: a one-element float variable used as the loss
            loss = tf.Variable([1.0], dtype=tf.float32)

            # When: a train step is added with a step size of 0.1
            train_op = graph.add_train_step(loss, 0.1)

            # Then
            self.assertIsNotNone(train_op)
            # Check the actual type rather than the class-name string, so an
            # unrelated class that merely happens to be called "Operation"
            # cannot satisfy the assertion.
            self.assertIsInstance(train_op, tf.Operation)
    def test_run_train(self):
        """Run one training step and verify that both variables are updated.

        The cost is ``5 * first_var + 3 * second_var`` and the value 3.0 is
        passed to ``graph.add_train_step``; after a single run of the returned
        op the assertions expect every element of the first variable to have
        dropped by 15 and every element of the second by 9.
        """
        first_initial = [1.0, 2.0]
        second_initial = [3.0, 4.0]

        # Fresh graph plus a test session bound to it for the whole test
        with tf.Graph().as_default(), self.test_session():
            # Given
            first_var = tf.Variable(first_initial, dtype=tf.float32)
            second_var = tf.Variable(second_initial, dtype=tf.float32)
            objective = 5 * first_var + 3 * second_var

            # When
            step_op = graph.add_train_step(objective, 3.0)
            tf.initialize_all_variables().run()

            # Then
            # Before training the variables still hold their initial values
            self.assertAllClose(first_initial, first_var.eval())
            self.assertAllClose(second_initial, second_var.eval())

            # Apply a single training step
            step_op.run()

            # After the step both variables must have moved
            self.assertAllClose([-14., -13.], first_var.eval())
            self.assertAllClose([-6., -5.], second_var.eval())
示例#3
0
def main():
    """Train and evaluate the MNIST model, writing TensorBoard summaries.

    Builds the inference, loss, training and evaluation ops through the
    ``graph`` and ``evaluation`` helpers, runs ``FLAGS.max_steps`` training
    steps, prints the validation accuracy every 10th step and finally prints
    the accuracy on the test set.

    The session and both summary writers are closed on exit, even when
    training raises, so buffered summary events are flushed to disk and the
    session's resources are released.
    """

    # Create an InteractiveSession; the bare ``.run()`` call further down
    # relies on it being the default session.
    sess = tf.InteractiveSession()

    # Created inside the try block; kept as None until then so the cleanup
    # code knows which writers actually exist.
    train_writer = None
    validation_writer = None

    try:
        # Remove tensorboard previous directory
        if os.path.exists(FLAGS.summaries_dir):
            shutil.rmtree(FLAGS.summaries_dir)

        # ---- Step 1 - Input data management ----

        # MNIST data
        mnist = input.input_reader(FLAGS.data_dir)

        # Input placeholders
        x, y_ = input.input_placeholders(IMAGE_PIXELS, NUM_CLASSES)

        # Reshape images for visualization
        x_reshaped = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
        tf.image_summary('input', x_reshaped, NUM_CLASSES, name="y_input")

        # ---- Step 2 - Building the graph ----

        # Inference
        softmax = graph.create_inference_step(x=x, num_pixels=IMAGE_PIXELS, num_classes=NUM_CLASSES)

        # Loss
        cross_entropy = graph.add_loss_step(softmax, y_)

        # Train step
        train_step = graph.add_train_step(cross_entropy, FLAGS.learning_rate)

        # ---- Step 3 - Build the evaluation step ----

        # Model Evaluation
        accuracy = evaluation.evaluate(softmax, y_)

        # ---- Step 4 - Merge all summaries for TensorBoard generation ----

        # Merge all the summaries and write them out under FLAGS.summaries_dir
        merged = tf.merge_all_summaries()
        train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', sess.graph)
        validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation')

        # ---- Step 5 - Train the model, and write summaries ----

        # Initialize all variables
        tf.initialize_all_variables().run()

        # Run train_step on training data and add training summaries
        for i in range(FLAGS.max_steps):

            # Load next batch of data
            x_batch, y_batch = mnist.train.next_batch(FLAGS.batch_size)

            # Run summaries and train_step
            summary, _ = sess.run([merged, train_step], feed_dict={x: x_batch, y_: y_batch})

            # Add summaries to train writer
            train_writer.add_summary(summary, i)

            # Every 10th step, measure validation-set accuracy, and write validation summaries
            if i % 10 == 0:
                # Run summaries and measure accuracy on validation set
                summary, acc_valid = sess.run([merged, accuracy],
                                              feed_dict={x: mnist.validation.images, y_: mnist.validation.labels})

                # Add summaries to validation writer
                validation_writer.add_summary(summary, i)

                print('Validation Accuracy at step %s: %s' % (i, acc_valid))

        # Measure accuracy on test set
        acc_test = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print('Accuracy on test set: %s' % acc_test)
    finally:
        # Close the writers first so pending summary events reach disk, then
        # release the session.
        for writer in (train_writer, validation_writer):
            if writer is not None:
                writer.close()
        sess.close()