Example #1
0
    def test_input_placeholders(self):
        """input.input_placeholders should return two Tensors shaped
        [None, image_pixels] and [None, num_classes].
        """
        with tf.Graph().as_default():
            # Given
            image_pixels = 100
            num_classes = 10

            # When
            x, y_ = input.input_placeholders(image_pixels, num_classes)

            # Then
            self.assertIsNotNone(x)
            self.assertEqual(type(x).__name__, "Tensor")
            # Use the public shape API instead of poking at the private
            # Tensor internal x.__dict__["_shape"], which is not a stable
            # contract and breaks across TensorFlow versions.
            self.assertEqual(x.get_shape().as_list(), [None, image_pixels])
            self.assertIsNotNone(y_)
            self.assertEqual(type(y_).__name__, "Tensor")
            self.assertEqual(y_.get_shape().as_list(), [None, num_classes])
Example #2
0
def main():
    """Train a dense MNIST classifier and log TensorBoard summaries.

    Builds the graph (inference, loss, train step, accuracy), trains for
    ``FLAGS.max_steps`` batches, writes training summaries every step and
    validation summaries every 10 steps, then reports test-set accuracy.
    Uses the legacy (pre-1.0) TensorFlow summary API that this file is
    written against.
    """

    # Create an InteractiveSession so `.run()` / `.eval()` work without an
    # explicit session argument.
    sess = tf.InteractiveSession()

    # Remove tensorboard previous directory so each run starts with a
    # clean event log.
    if os.path.exists(FLAGS.summaries_dir):
        shutil.rmtree(FLAGS.summaries_dir)

    """
    Step 1 - Input data management
    """

    # MNIST data
    mnist = input.input_reader(FLAGS.data_dir)

    # Input placeholders
    x, y_ = input.input_placeholders(IMAGE_PIXELS, NUM_CLASSES)

    # Reshape flat pixel vectors to [batch, height, width, channels=1] for
    # image visualization in TensorBoard.
    x_reshaped = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
    # NOTE(review): NUM_CLASSES is passed as max_images here — presumably
    # intentional (show up to 10 images), but worth confirming.
    tf.image_summary('input', x_reshaped, NUM_CLASSES, name="y_input")

    """
    Step 2 - Building the graph
    """

    # Inference
    softmax = graph.create_inference_step(x=x, num_pixels=IMAGE_PIXELS, num_dense1=200, num_dense2=100, num_dense3=50,
                                          num_classes=NUM_CLASSES)

    # Loss
    cross_entropy = graph.add_loss_step(softmax, y_)

    # Train step
    train_step = graph.add_train_step(cross_entropy, FLAGS.learning_rate)

    """
    Step 3 - Build the evaluation step
    """

    # Model Evaluation
    accuracy = evaluation.evaluate(softmax, y_)

    """
    Step 4 - Merge all summaries for TensorBoard generation
    """

    # Merge all the summaries and write them out to /tmp/mnist_dense_logs (by default)
    merged = tf.merge_all_summaries()
    train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', sess.graph)
    validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation')

    """
    Step 5 - Train the model, and write summaries
    """

    # Initialize all variables
    tf.initialize_all_variables().run()

    # All other steps, run train_step on training data, & add training summaries
    for i in range(FLAGS.max_steps):

        # Load next batch of data
        x_batch, y_batch = mnist.train.next_batch(FLAGS.batch_size)

        # Run summaries and train_step
        summary, _ = sess.run([merged, train_step], feed_dict={x: x_batch, y_: y_batch})

        # Add summaries to train writer
        train_writer.add_summary(summary, i)

        # Every 10th step, measure validation-set accuracy, and write validation summaries
        if i % 10 == 0:
            # Run summaries and measure accuracy on validation set
            summary, acc_valid = sess.run([merged, accuracy],
                                          feed_dict={x: mnist.validation.images, y_: mnist.validation.labels})

            # Add summaries to validation writer
            validation_writer.add_summary(summary, i)

            print('Validation Accuracy at step %s: %s' % (i, acc_valid))

    # Measure accuracy on test set
    acc_test = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('Accuracy on test set: %s' % acc_test)

    # Flush and close the writers so no buffered summaries are lost, and
    # release the session's resources (the original leaked all three).
    train_writer.close()
    validation_writer.close()
    sess.close()