コード例 #1
0
# Initialise the graph's variables before training starts.
# NOTE(review): tf.initialize_all_variables() is deprecated in favour of
# tf.global_variables_initializer(); it is kept unchanged here because the
# TensorFlow version this script targets is not visible from this snippet.
tf.initialize_all_variables().run()

# A sample counts as correct when the arg-max along axis 1 of the model
# output `y` matches the arg-max of the target `y_`; `accuracy` is the mean
# of those matches after casting them to float32.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver()
# Bookkeeping for the best model seen so far. The -1 starting values are
# presumably "nothing recorded yet" sentinels -- confirm against the code
# that updates them (not visible in this snippet).
best_model_accuracy = -1.0
best_model_epoch = -1
no_improvement_count = 0

print("Training for " + str(EPOCHS) + " epochs ...")

# Build the identity matrix once: selecting its rows with the batch labels
# yields the one-hot encoding, and the matrix does not depend on the batch,
# so there is no reason to rebuild it for every minibatch.
one_hot_lookup = np.eye(nclasses)

for epoch in range(EPOCHS):
    # Per-epoch accumulators, reset at the start of every epoch.
    train_accuracies = []
    train_losses = []
    val_accuracies = []
    train_minibatchgen.shuffle()
    for i in range(train_minibatchgen.nbatches()):
        # The third element returned by batch() is not needed here.
        batch_data, batch_labels, _ = train_minibatchgen.batch(i)
        batch_labels_one_hot = one_hot_lookup[batch_labels]

        # The optimisation step and the metric evaluation use the same inputs.
        batch_feed = {x: batch_data, y_: batch_labels_one_hot}

        sess.run(train_step, feed_dict=batch_feed)

        # Separate run on the same batch, issued after the training step.
        train_accuracy, train_loss = sess.run([accuracy, total_loss],
                                              feed_dict=batch_feed)
コード例 #2
0
def _format_ids(ids, limit=10):
    """Return the first `limit` entries of `ids` as a comma-separated string.

    Each element is converted with str() and joined explicitly, so the
    result does not depend on how the container prints itself. (Stripping
    brackets and spaces from str(container) would run the IDs together for
    any container whose repr separates elements with spaces only.)
    """
    return ",".join(str(sample_id) for sample_id in ids[:limit])


def _report_batch(generator, batch_no, show_shape=False):
    """Fetch minibatch `batch_no` from `generator` and print a short summary.

    Prints the number of samples, optionally the shape of the data, and the
    first ten sample IDs. Returns the (data, labels, indexes) triple exactly
    as produced by generator.batch(batch_no).
    """
    data, labels, indexes = generator.batch(batch_no)
    print("Minibatch #" + str(batch_no) + " has " + str(len(data)) + " samples")
    if show_shape:
        print(" Data shape: " + str(data.shape))
    print(" First 10 sample IDs: " + _format_ids(indexes))
    return data, labels, indexes


print("Dataset has " + str(train.size()) + " samples")
print("Batch generator has " + str(minibatchgen.nbatches())
      + " minibatches, minibatch size: " + str(bs))
print()

# Inspect two minibatches before shuffling.
# NOTE(review): batch index 66 is hard-coded; this assumes the generator
# provides at least 67 minibatches -- confirm for the dataset in use.
data, labels, indexes = _report_batch(minibatchgen, 0, show_shape=True)
data, labels, indexes = _report_batch(minibatchgen, 66)
print()

print("Shuffling samples")
print()
minibatchgen.shuffle()

# Same two minibatches after shuffling, so the sample IDs can be compared.
data, labels, indexes = _report_batch(minibatchgen, 0)
data, labels, indexes = _report_batch(minibatchgen, 66)
print()

print("=== Testing with ImageVectorizer ===")
print()

# Rebuild the pipeline with the dataset wrapped in an ImageVectorizer.
# NOTE(review): the minibatch size is the literal 60 here, whereas the
# message above reports `bs` -- confirm the two are meant to agree.
train = TinyCifar10Dataset(cifar10batchesdir, 'train')
train_vectorized = ImageVectorizer(train)
minibatchgen = MiniBatchGenerator(train_vectorized, 60)