Example #1
0
def main(_):
    """Evaluate the latest checkpoint in ./model on one batch of test data.

    Builds the per-class image lists from FLAGS.image_dir, draws one batch of
    FLAGS.test_batch_size cached test examples, builds the VGG16 graph on top
    of that batch, restores the most recent checkpoint recorded in the 'model'
    directory and prints the sample count, the number of correct predictions
    and the accuracy (in percent) for that batch.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).

    Returns:
        -1 when fewer than two class folders exist, when no test samples are
        available, or when no checkpoint can be found; otherwise None.
    """
    import os  # Local import: only needed to parse the checkpoint file name.

    tf.logging.set_verbosity(tf.logging.INFO)
    image_lists = dog_cat.create_image_lists(FLAGS.image_dir,
                                             FLAGS.testing_percentage,
                                             FLAGS.validation_percentage)
    class_count = len(image_lists.keys())
    if class_count == 0:
        tf.logging.error('No valid folders of images found at ' +
                         FLAGS.image_dir)
        return -1
    if class_count == 1:
        tf.logging.error('Only one valid folder of images found at ' +
                         FLAGS.image_dir +
                         ' - multiple classes are needed for classification.')
        return -1
    with tf.Session() as sess:
        (test_cached_tensor, test_ground_truth,
         _) = get_random_cached_bottlenecks(sess, image_lists,
                                            FLAGS.test_batch_size, 'testing',
                                            FLAGS.cache_dir, FLAGS.image_dir)
        num_test_samples = len(test_ground_truth)
        if num_test_samples == 0:
            # With no test data, the graph below cannot be built from the
            # batch, and the accuracy computation would divide by zero.
            # Fail early with a clear message instead.
            tf.logging.error('No testing images found under ' +
                             FLAGS.image_dir)
            return -1
        logits = VGG16.vgg16_net(tf.convert_to_tensor(test_cached_tensor),
                                 class_count)
        correct = num_correct_prediction(logits, test_ground_truth)
        # Created after the graph is built, so the Saver covers all of its
        # variables.
        saver = tf.train.Saver(tf.global_variables())
        print("Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state('model')
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint file found')
            # Use the same failure code as the other error paths, so callers
            # can tell that nothing was evaluated.
            return -1
        # The checkpoint file name is expected to end in '-<step>'.
        # basename() handles OS-native path separators.
        global_step = os.path.basename(
            ckpt.model_checkpoint_path).split('-')[-1]
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Loading success, global_step is %s' % global_step)
        print('\nEvaluating......')
        batch_correct = sess.run(correct)
        print('Total testing samples: %d' % num_test_samples)
        print('Total correct predictions: %d' % batch_correct)
        # 100.0 forces true division, even under Python 2 integer semantics.
        print('Average accuracy: %.2f%%' %
              (100.0 * batch_correct / num_test_samples))
def main(_):
    """Train a VGG16 classifier on the class folders under FLAGS.image_dir.

    Builds per-class image lists via create_image_lists (using
    FLAGS.testing_percentage / FLAGS.validation_percentage) and builds a VGG16
    graph fed through placeholders. It then initializes all variables, calls
    load_with_skip on 'vgg16.npy' (skipping fc6/fc7/fc8) and calls
    cache_bottleneck, before running FLAGS.how_many_training_steps steps.

    Reporting and checkpoints:
    - Training loss/accuracy are printed every step.
    - Validation loss/accuracy are printed at step 0, every
      FLAGS.eval_step_interval steps, and on the last step.
    - A checkpoint is saved as 'model/model.ckpt-<step>' at step 0, every 2000
      steps, and on the last step.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).

    Returns:
        -1 if fewer than two class folders are found. No other return value
        is produced in the code visible here.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
    class_count = len(image_lists.keys())
    # Classification needs at least two classes; bail out otherwise.
    if class_count == 0:
        tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
        return -1
    if class_count == 1:
        tf.logging.error('Only one valid folder of images found at ' +
                         FLAGS.image_dir +
                         ' - multiple classes are needed for classification.')
        return -1

    # Graph inputs: a batch of 224x224x3 float inputs and one row of
    # class_count label values per example (presumably one-hot -- confirm
    # against what get_random_cached_bottlenecks returns).
    x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    y_ = tf.placeholder(tf.float32, shape=[None, class_count])
    logits = VGG16.vgg16_net(x,class_count)
    # caculateloss (sic), accuracy and optimize are project helpers defined
    # elsewhere.
    loss = caculateloss(logits, y_)
    acc = accuracy(logits, y_)
    # Non-trainable step counter handed to optimize().
    # NOTE(review): the checkpoints below are numbered with the loop index i,
    # not with this variable.
    my_global_step = tf.Variable(0, name='global_step', trainable=False)
    train_step = optimize(loss, FLAGS.learning_rate,my_global_step)

    # Commented-out alternative graph setup (not executed):
    #with graph.as_default():
    #     (train_step, cross_entropy,bottleneck_input,input_groud_truth) = train(output_tensor)
    with tf.Session() as sess:
        # Initialize every variable first. load_with_skip runs afterwards, so
        # the initializer cannot overwrite what it loads. It presumably loads
        # pretrained weights from vgg16.npy for every layer except those
        # listed -- confirm in load_with_skip.
        init = tf.global_variables_initializer()
        sess.run(init)
        #print(tf.trainable_variables())
        load_with_skip('vgg16.npy', sess, ['fc6', 'fc7', 'fc8'])
        # NOTE(review): these decoding ops are added to the graph after the
        # initializer ran. That is only safe if add_jpeg_decoding creates no
        # variables -- confirm.
        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding()
        # Presumably decodes the images and stores the results under
        # FLAGS.cache_dir for the training loop to read back -- confirm in
        # cache_bottleneck.
        cache_bottleneck(sess, image_lists, FLAGS.image_dir,
                           FLAGS.cache_dir, jpeg_data_tensor,
                           decoded_image_tensor, x)


        # Evaluate prediction accuracy.
        # (Disabled; accuracy comes from the `acc` op built above instead.)
        #evaluation_step, _ = add_evaluation_step(output_tensor, ground_truth_input)

        # Saver over all global variables, used for the periodic checkpoints
        # below.
        saver = tf.train.Saver(tf.global_variables())
        # NOTE(review): merged, train_writer, validation_writer and
        # train_saver are not used anywhere in the code visible here.
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                             sess.graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation')
        train_saver = tf.train.Saver()

        for i in range(FLAGS.how_many_training_steps):
            # One optimization step per iteration. The batch comes from
            # get_random_cached_bottlenecks (presumably randomly sampled).
            (train_cached_tensor,
             train_ground_truth, _) = get_random_cached_bottlenecks(
                sess, image_lists, FLAGS.train_batch_size, 'training',
                FLAGS.cache_dir, FLAGS.image_dir, jpeg_data_tensor,
                decoded_image_tensor, x)
            _, tra_loss, tra_acc = sess.run([train_step,loss,acc],feed_dict={x:train_cached_tensor, y_:train_ground_truth})
            # NOTE(review): a '%' sign is printed after the accuracy. Whether
            # acc is a fraction or a percentage depends on accuracy() --
            # confirm.
            print('Step: %d, loss: %.4f, accuracy: %.4f%%' % (i, tra_loss, tra_acc))
            is_last_step = (i + 1 == FLAGS.how_many_training_steps)
            # When training finishes, or after every eval_step_interval
            # training batches (including step 0), evaluate one validation
            # batch and print its loss and accuracy.
            if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
                (validation_cached_tensor,
                 validation_ground_truth, _) = get_random_cached_bottlenecks(
                    sess, image_lists, FLAGS.validation_batch_size, 'validation',
                    FLAGS.cache_dir, FLAGS.image_dir, jpeg_data_tensor,
                    decoded_image_tensor, x)
                # train_step is not run here, so validation does not update
                # any weights.
                val_loss, val_acc = sess.run([loss, acc],
                                             feed_dict={x: validation_cached_tensor, y_: validation_ground_truth})
                print('**  Step %d, val loss = %.2f, val accuracy = %.2f%%  **' % (i, val_loss, val_acc))
            # Checkpoint at step 0, every 2000 steps, and on the last step.
            # The files are saved under the prefix model/model.ckpt-<i>.
            # NOTE(review): the 'model' directory is assumed to already exist
            # -- confirm it is created elsewhere.
            if i % 2000 == 0 or is_last_step:
                checkpoint_path = os.path.join('model', 'model.ckpt')
                saver.save(sess, checkpoint_path,global_step=i)