Exemplo n.º 1
0
# Load Data
# NOTE(review): `input_data` is presumably TF1's MNIST tutorial loader; with
# reshape=False the images are expected to stay 4-D (N, H, W, C), which the
# 4-axis padding below relies on — TODO confirm against the import.
mnist = input_data.read_data_sets('mnist/', reshape=False)
x_test, y_test = mnist.test.images, mnist.test.labels
print('----------%d testing samples' % (x_test.shape[0]))
print('----------image size: {}'.format(x_test[0].shape))

# Padding: add 2 zero pixels on both sides of the two spatial axes (1 and 2),
# leaving the batch and channel axes untouched. Presumably this turns 28x28
# MNIST digits into the 32x32 input the placeholder below declares.
x_test = np.pad(x_test, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')
print('----------new image size: {}'.format(x_test[0].shape))

# Hyper-parameters
b_size = 128  # batch size handed to func.evaluate

x = tf.placeholder(tf.float32, (None, 32, 32, 1))
# `(None)` is just `None` (shape fully unknown); `(None,)` declares the
# intended rank-1 vector of integer class labels of any length.
y = tf.placeholder(tf.int32, (None,))
y_one_hot = tf.one_hot(y, 10)
out = model.LeNet_2(x)

# Accuracy: fraction of samples whose arg-max output matches the label.
is_corr = tf.equal(tf.math.argmax(out, 1), tf.math.argmax(y_one_hot, 1))
acc_opr = tf.reduce_mean(tf.cast(is_corr, tf.float32))
saver = tf.train.Saver()

with tf.Session() as sess:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # report that clearly instead of passing None into restore().
    ckpt_path = tf.train.latest_checkpoint('.')
    if ckpt_path is None:
        raise ValueError('no TensorFlow checkpoint found in the current directory')
    saver.restore(sess, ckpt_path)

    # perf_counter is monotonic and high-resolution, suited to timing intervals.
    start = time.perf_counter()
    # NOTE(review): func.evaluate presumably runs acc_opr over x_test/y_test in
    # b_size batches and returns the mean accuracy — verify against func.
    test_acc = func.evaluate(x_test, y_test, b_size, acc_opr, x, y)
    test_time = time.perf_counter() - start

    print('----------test_acc = {:.3f}, test_time = {:.3f} s'.format(
        test_acc, test_time))
# NOTE(review): these three definitions duplicate the accuracy op and Saver
# declared earlier in this script (this training section looks like it was
# pasted from a separate training script). Re-declaring them only adds
# redundant graph nodes and a second Saver; results are unchanged.
is_corr = tf.equal(tf.math.argmax(out, 1), tf.math.argmax(y_one_hot, 1))
acc_opr = tf.reduce_mean(tf.cast(is_corr, tf.float32))
saver = tf.train.Saver()

# Accumulated time spent in the mini-batch training loops (evaluation excluded).
total_time = 0
# Best validation accuracy seen so far. Starting at -inf (rather than 0)
# guarantees the first epoch always writes a checkpoint, even if its
# validation accuracy is 0.0, so a later restore has something to load.
acc_temp = float('-inf')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(epochs):
        start = time.perf_counter()

        # Reshuffle every epoch so batches differ between epochs.
        # NOTE(review): `shuffle` is presumably sklearn.utils.shuffle (it is
        # not imported in this view) — confirm.
        x_train, y_train = shuffle(x_train, y_train)
        for j in range(0, len(x_train), b_size):
            b_x, b_y = x_train[j:j + b_size], y_train[j:j + b_size]
            sess.run(train_opr, feed_dict={x: b_x, y: b_y})
        train_time = time.perf_counter() - start
        total_time += train_time

        train_acc = func.evaluate(x_train, y_train, b_size, acc_opr, x, y)
        val_acc = func.evaluate(x_val, y_val, b_size, acc_opr, x, y)
        # Keep only the best model so far (by validation accuracy) on disk.
        if val_acc > acc_temp:
            acc_temp = val_acc
            saver.save(sess, './model')

        print(
            '----------epoch {}/{}: train_acc = {:.3f}, val_acc = {:.3f}, train_time = {:.3f} s'
            .format(i + 1, epochs, train_acc, val_acc, train_time))

    # Guard the average: with zero epochs there is nothing to report and the
    # division would raise ZeroDivisionError.
    if epochs > 0:
        print('----------mean epoch time = {:.3f} s'.format(total_time / epochs))