示例#1
0
            run_id = log_run_start('Train', hyper_dict, VERSION_DESCRIPTION)
            print()
            epoch_train_time_list = []
            training_losses = []
            validate_accuracies = []
            for i in range(epochs):
                t1 = time()
                X_TRAIN, Y_TRAIN = shuffle(X_TRAIN, Y_TRAIN)
                for offset in range(0, num_examples, batch_size):
                    end = offset + batch_size
                    batch_x, batch_y = X_TRAIN[offset:end], Y_TRAIN[offset:end]
                    _, loss = sess.run((training_operation, loss_operation), \
                        feed_dict={x: batch_x, y: batch_y, keep_prob: 0.5})

                if TRAIN_RATIO < 1.0:
                    validation_accuracy = evaluate(X_VALID, Y_VALID, batch_size)
                else:
                    validation_accuracy = None
                print("EPOCH {} ...".format(i+1))
                print("Training Loss = {:.3f}".format(loss))
                if validation_accuracy is not None:
                    print("Validation Accuracy = {:.3f}".format(validation_accuracy))
                print()
                epoch_train_time_list.append(time() - t1)
                training_losses.append(loss)
                validate_accuracies.append(validation_accuracy)

            saver.save(sess, save_file)
            print("Model saved")
            log_run_end(run_id, validation_accuracy, loss)
        print()
示例#2
0
import numpy as np
import tensorflow as tf

from traffic_sign_model import evaluate

# Evaluate the saved traffic-sign model on a small set of extra sign images.
#
# Class label for each extra image: entry k labels file sign{k+1:02d}.png, so
# the number of images loaded and the evaluation batch size both follow from
# the length of this list.
y_extra = [14, 28, 13, 27, 17, 26]

# Load sign01.png .. signNN.png (one file per label) and stack them into a
# single batch along a new leading axis.
# NOTE(review): `mpimg` is not imported in this example; it presumably expects
# `import matplotlib.image as mpimg` -- confirm and add it to the imports.
# NOTE(review): matplotlib reads PNG files as floats in [0, 1] and may return an
# alpha channel; confirm this matches the preprocessing used during training.
image_batches = []
for idx in range(1, len(y_extra) + 1):
    img = np.expand_dims(mpimg.imread('sign{:02d}.png'.format(idx)), axis=0)
    # Keep the original log lines: the first image is reported differently.
    if idx == 1:
        print("Image shape", img.shape)
    else:
        print("image {} shape".format(idx), img.shape)
    image_batches.append(img)

# One concatenate instead of np.append inside the loop, which re-copies the
# whole accumulated array on every iteration.
images = np.concatenate(image_batches, axis=0)

# Report the shape of the assembled batch.
print('The extra data shape is:', images.shape)

# With no var_list, tf.train.Saver covers all saveable variables in the default
# graph; presumably importing traffic_sign_model builds that graph -- confirm.
saver = tf.train.Saver()

save_file = 'traffic_signs.ckpt'

# Restore the trained weights and evaluate on the extra images.
with tf.Session() as sess:
    saver.restore(sess, save_file)

    # The third argument is presumably the batch size (confirm against
    # traffic_sign_model.evaluate); evaluate all extra images as one batch.
    test_accuracy = evaluate(images, y_extra, len(y_extra))
    print("Test Accuracy = {:.3f}".format(test_accuracy))