def main(_):
    """Evaluate the latest checkpoint in ``model/`` on one batch of test data.

    Builds the image lists, fetches up to ``FLAGS.test_batch_size`` cached
    'testing' samples via ``get_random_cached_bottlenecks``, builds the VGG16
    graph directly on that batch, restores the newest checkpoint recorded in
    the ``model`` directory and prints how many predictions were correct.

    NOTE(review): this file defines ``main`` more than once; the definition
    that appears last is the one bound at import time. Confirm which entry
    point is intended.

    Args:
        _: Unused (the argv list handed over by ``tf.app.run``).

    Returns:
        -1 when there are fewer than two classes, no testing samples, or no
        checkpoint to restore; otherwise None.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    image_lists = dog_cat.create_image_lists(
        FLAGS.image_dir, FLAGS.testing_percentage, FLAGS.validation_percentage)
    class_count = len(image_lists.keys())
    if class_count == 0:
        tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
        return -1
    if class_count == 1:
        tf.logging.error('Only one valid folder of images found at ' +
                         FLAGS.image_dir +
                         ' - multiple classes are needed for classification.')
        return -1

    with tf.Session() as sess:
        (test_cached_tensor, test_ground_truth, _) = get_random_cached_bottlenecks(
            sess, image_lists, FLAGS.test_batch_size, 'testing',
            FLAGS.cache_dir, FLAGS.image_dir)
        # len() rather than truthiness: the ground truth may be a numpy array.
        test_count = len(test_ground_truth)
        # Bail out before building the graph: an empty batch cannot be fed
        # through the network and would make the accuracy divide by zero.
        if test_count == 0:
            tf.logging.error('No testing samples found at ' + FLAGS.image_dir)
            return -1

        # The test batch is embedded in the graph as a constant tensor, so the
        # network is evaluated on exactly these samples.
        logits = VGG16.vgg16_net(tf.convert_to_tensor(test_cached_tensor),
                                 class_count)
        correct = num_correct_prediction(logits, test_ground_truth)
        saver = tf.train.Saver(tf.global_variables())

        print("Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state('model')
        if not (ckpt and ckpt.model_checkpoint_path):
            # Report failure with -1 like the other error paths (presumably
            # surfaced as the process exit status by tf.app.run).
            tf.logging.error('No checkpoint file found')
            return -1
        # A checkpoint saved with global_step ends in "-<step>"; take the text
        # after the last '-' of the file name.
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Loading success, global_step is %s' % global_step)

        print('\nEvaluating......')
        batch_correct = sess.run(correct)
        print('Total testing samples: %d' % test_count)
        print('Total correct predictions: %d' % batch_correct)
        # float() keeps this a true division under Python 2 / numpy ints too.
        print('Average accuracy: %.2f%%' %
              (100.0 * float(batch_correct) / test_count))
def main(_):
    """Fine-tune VGG16 on the labelled image folders in ``FLAGS.image_dir``.

    Builds a VGG16 classifier over ``[None, 224, 224, 3]`` float inputs with
    ``class_count`` outputs. It initialises all variables and then calls
    ``load_with_skip('vgg16.npy', ...)`` with fc6/fc7/fc8 listed as layers to
    skip (presumably so the pretrained conv weights are reused and the
    classifier head is trained fresh). It then caches preprocessed inputs via
    ``cache_bottleneck`` and trains for ``FLAGS.how_many_training_steps`` steps.

    Every ``FLAGS.eval_step_interval`` steps, and on the last step, one
    validation batch is evaluated and reported. Every 2000 steps, and on the
    last step, a checkpoint is written to ``model/model.ckpt-<step>``.

    Args:
        _: Unused (the argv list handed over by ``tf.app.run``).

    Returns:
        -1 when fewer than two class folders are found; otherwise None.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                     FLAGS.validation_percentage)
    class_count = len(image_lists.keys())
    if class_count == 0:
        tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
        return -1
    if class_count == 1:
        tf.logging.error('Only one valid folder of images found at ' +
                         FLAGS.image_dir +
                         ' - multiple classes are needed for classification.')
        return -1

    # Graph inputs: a batch of 224x224x3 float images and a class_count-wide
    # label vector per sample (presumably one-hot; confirm in caculateloss).
    x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    y_ = tf.placeholder(tf.float32, shape=[None, class_count])
    logits = VGG16.vgg16_net(x, class_count)
    loss = caculateloss(logits, y_)
    acc = accuracy(logits, y_)
    my_global_step = tf.Variable(0, name='global_step', trainable=False)
    train_step = optimize(loss, FLAGS.learning_rate, my_global_step)

    checkpoint_dir = 'model'
    checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')
    checkpoint_every = 2000

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Runs after initialisation so the loaded values are not overwritten.
        load_with_skip('vgg16.npy', sess, ['fc6', 'fc7', 'fc8'])
        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding()
        cache_bottleneck(sess, image_lists, FLAGS.image_dir, FLAGS.cache_dir,
                         jpeg_data_tensor, decoded_image_tensor, x)

        saver = tf.train.Saver(tf.global_variables())
        # tf.train.Saver.save() refuses to write into a missing parent
        # directory; create it now rather than crash on the first save after
        # the (potentially slow) caching pass above.
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                             sess.graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation')
        try:
            for i in range(FLAGS.how_many_training_steps):
                (train_cached_tensor, train_ground_truth, _) = \
                    get_random_cached_bottlenecks(
                        sess, image_lists, FLAGS.train_batch_size, 'training',
                        FLAGS.cache_dir, FLAGS.image_dir,
                        jpeg_data_tensor, decoded_image_tensor, x)
                _, tra_loss, tra_acc = sess.run(
                    [train_step, loss, acc],
                    feed_dict={x: train_cached_tensor, y_: train_ground_truth})
                print('Step: %d, loss: %.4f, accuracy: %.4f%%' %
                      (i, tra_loss, tra_acc))

                is_last_step = (i + 1 == FLAGS.how_many_training_steps)
                # Every eval_step_interval steps, and after the final step,
                # report loss and accuracy on one validation batch.
                if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
                    (validation_cached_tensor, validation_ground_truth, _) = \
                        get_random_cached_bottlenecks(
                            sess, image_lists, FLAGS.validation_batch_size,
                            'validation', FLAGS.cache_dir, FLAGS.image_dir,
                            jpeg_data_tensor, decoded_image_tensor, x)
                    val_loss, val_acc = sess.run(
                        [loss, acc],
                        feed_dict={x: validation_cached_tensor,
                                   y_: validation_ground_truth})
                    print('** Step %d, val loss = %.2f, val accuracy = %.2f%% **' %
                          (i, val_loss, val_acc))

                if i % checkpoint_every == 0 or is_last_step:
                    saver.save(sess, checkpoint_path, global_step=i)
        finally:
            # Flush pending events and release the event-file handles even if
            # training is interrupted.
            train_writer.close()
            validation_writer.close()