def _open_result_csv(path):
    """Open *path* for writing in the mode the ``csv`` module expects.

    Python 2's ``csv.writer`` needs a binary-mode file, while Python 3's needs
    a text-mode file opened with ``newline=''``; opening with ``'wb'`` on
    Python 3 makes the first ``writerow`` call raise ``TypeError``.
    """
    import sys

    if sys.version_info[0] >= 3:
        return open(path, 'w', newline='')
    return open(path, 'wb')


def test():
    """Predict labels for the test images and write them to a CSV file.

    Builds the inference graph, restores variables from ``FLAGS.save_name``
    when that file exists, then runs ``argmax(logits, 1)`` over every
    ``(start, end)`` range returned by ``loader.load_mnist_test``. Each
    prediction is written to ``FLAGS.test_result`` as an ``ImageId,Label`` row
    (``ImageId`` is a running counter starting at 1) and echoed to stdout.

    The session and the result file are both released even if a step raises.
    """
    # build graph
    inputs, _labels, dropout_keep_prob, _learning_rate = model.input_placeholder(
        FLAGS.image_size, FLAGS.image_channel, FLAGS.label_cnt)
    # NOTE(review): train() passes label_cnt as a third argument to
    # model.inference; here it is omitted -- confirm the default matches.
    logits = model.inference(inputs, dropout_keep_prob)
    predict = tf.argmax(logits, 1)

    init = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init)

        # tf saver: variables keep their freshly initialised values when no
        # checkpoint file is present.
        saver = tf.train.Saver()
        if os.path.isfile(FLAGS.save_name):
            saver.restore(sess, FLAGS.save_name)

        # load test data
        test_images, test_ranges = loader.load_mnist_test(FLAGS.batch_size)

        # ready for result file
        with _open_result_csv(FLAGS.test_result) as test_result_file:
            csv_writer = csv.writer(test_result_file)
            csv_writer.writerow(['ImageId', 'Label'])

            total_start_time = time.time()
            i = 1
            for file_start, file_end in test_ranges:
                test_x = test_images[file_start:file_end]
                # Dropout keep probability is fed as 1.0 for prediction.
                predict_label = sess.run(
                    predict,
                    feed_dict={inputs: test_x, dropout_keep_prob: 1.0})

                for cur_predict in predict_label:
                    csv_writer.writerow([i, cur_predict])
                    print('[Result %s: %s]' % (i, cur_predict))
                    i += 1

            print("[%s][total exec %s seconds" % (
                time.strftime("%Y-%m-%d %H:%M:%S"),
                (time.time() - total_start_time)))
def train():
    """Train the model on the MNIST training split with periodic validation.

    Builds the graph (inference, accuracy, loss, RMSProp step), restores
    variables from ``FLAGS.save_name`` when that file exists, then runs
    ``FLAGS.training_epoch`` epochs over the ``(start, end)`` ranges returned
    by ``loader.load_mnist_train``:

    * the learning rate starts at ``FLAGS.learning_rate`` and is divided by 10
      at every epoch that is a positive multiple of 10;
    * every 20th step also fetches the merged summaries and writes them to
      ``./summary/train``;
    * every ``FLAGS.validation_interval``-th step (except step 0) evaluates
      accuracy and loss on a shuffled validation batch and writes the summary
      to ``./summary/validation``;
    * the variables are saved to ``FLAGS.save_name`` at the end of each epoch.

    The session and both summary writers are released even if a step raises.
    """
    # build graph
    inputs, labels, dropout_keep_prob, learning_rate = model.input_placeholder(
        image_size, image_channel, label_cnt)
    logits = model.inference(inputs, dropout_keep_prob, label_cnt)
    accuracy = model.accuracy(logits, labels)
    loss = model.loss(logits, labels)
    # Named train_op so the op does not shadow this function's own name.
    train_op = tf.train.RMSPropOptimizer(
        learning_rate, FLAGS.rms_decay).minimize(loss)

    init = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init)

        # ready for summary
        merged = tf.merge_all_summaries()
        train_writer = None
        validation_writer = None
        try:
            train_writer = tf.train.SummaryWriter('./summary/train', sess.graph)
            validation_writer = tf.train.SummaryWriter('./summary/validation')

            # tf saver
            saver = tf.train.Saver()
            if os.path.isfile(FLAGS.save_name):
                saver.restore(sess, FLAGS.save_name)

            total_start_time = time.time()

            # load mnist data
            (train_images, train_labels, train_range,
             validation_images, validation_labels,
             validation_indices) = loader.load_mnist_train(
                FLAGS.validation_size, FLAGS.batch_size)
            total_train_len = len(train_images)

            i = 0  # global step counter across all epochs
            cur_learning_rate = FLAGS.learning_rate

            for epoch in range(FLAGS.training_epoch):
                # Step decay: divide the learning rate by 10 every 10 epochs.
                if epoch % 10 == 0 and epoch > 0:
                    cur_learning_rate /= 10

                epoch_start_time = time.time()
                for start, end in train_range:
                    batch_start_time = time.time()
                    train_x = train_images[start:end]
                    train_y = train_labels[start:end]
                    train_feed = {
                        inputs: train_x,
                        labels: train_y,
                        dropout_keep_prob: FLAGS.dropout_keep_prob,
                        learning_rate: cur_learning_rate,
                    }

                    if i % 20 == 0:
                        # Every 20th step also records the merged summaries.
                        summary, _, loss_result = sess.run(
                            [merged, train_op, loss], feed_dict=train_feed)
                        train_writer.add_summary(summary, i)
                    else:
                        _, loss_result = sess.run(
                            [train_op, loss], feed_dict=train_feed)

                    print('[%s][training][epoch %d, step %d exec %.2f seconds] [file: %5d ~ %5d / %5d] loss : %3.10f' % (
                        time.strftime("%Y-%m-%d %H:%M:%S"), epoch, i,
                        (time.time() - batch_start_time),
                        start, end, total_train_len, loss_result))

                    if i % FLAGS.validation_interval == 0 and i > 0:
                        validation_start_time = time.time()
                        shuffle_indices = loader.shuffle_validation(
                            validation_indices, FLAGS.batch_size)
                        validation_x = validation_images[shuffle_indices]
                        validation_y = validation_labels[shuffle_indices]
                        # No train op is fetched here and dropout keep
                        # probability is fed as 1.0.
                        summary, accuracy_result, loss_result = sess.run(
                            [merged, accuracy, loss],
                            feed_dict={inputs: validation_x,
                                       labels: validation_y,
                                       dropout_keep_prob: 1.0})
                        validation_writer.add_summary(summary, i)
                        print('[%s][validation][epoch %d, step %d exec %.2f seconds] accuracy : %1.3f, loss : %3.10f' % (
                            time.strftime("%Y-%m-%d %H:%M:%S"), epoch, i,
                            (time.time() - validation_start_time),
                            accuracy_result, loss_result))

                    i += 1

                print("[%s][epoch exec %s seconds] epoch : %d" % (
                    time.strftime("%Y-%m-%d %H:%M:%S"),
                    (time.time() - epoch_start_time), epoch))
                saver.save(sess, FLAGS.save_name)

            print("[%s][total exec %s seconds" % (
                time.strftime("%Y-%m-%d %H:%M:%S"),
                (time.time() - total_start_time)))
        finally:
            # Close whichever writers were created, on success or failure.
            for writer in (train_writer, validation_writer):
                if writer is not None:
                    writer.close()
min_after_dequeue=80) #img_batch, label_batch, img_class_batch = tf.train.shuffle_batch([img, label, img_class], # batch_size=40, capacity=1000 + 3 * 40, # min_after_dequeue=1000) #img_batch, label_batch, img_class_batch = tf.train.shuffle_batch([img, label, img_class], # batch_size=4, capacity= 70, # min_after_dequeue=10) ### OK with v2 #img_batch, label_batch, img_class_batch = tf.train.shuffle_batch([img, label, img_class], # batch_size=5, capacity=40, # min_after_dequeue=30) inputs, labels, dropout_keep_prob, learning_rate = model.input_placeholder( image_size, image_channel, label_cnt) logits = model.inference(inputs, dropout_keep_prob, label_cnt) accuracy = model.accuracy(logits, labels) loss = model.loss(logits, labels) train = tf.train.RMSPropOptimizer(learning_rate, 0.9).minimize(loss) #初始化所有的op init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动队列 try: