def main(_):
    """Evaluate a restored checkpoint on the test set and print the accuracy.

    Builds the VggNetModel graph, restores its variables from ``FLAGS.ckpt``
    and averages the per-batch accuracy over every full batch read from
    ``FLAGS.test_file``. Samples left over after the last full batch are not
    evaluated, because the number of batches is rounded down.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If the test set contains fewer samples than
            ``FLAGS.batch_size``, i.e. there is no full batch to evaluate.
    """
    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 224, 224, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = VggNetModel(num_classes=FLAGS.num_classes,
                        dropout_keep_prob=dropout_keep_prob)
    model.inference(x)

    # Accuracy of the model: fraction of samples whose arg-max score matches
    # the arg-max of the label row.
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver()
    test_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.test_file,
                                          num_classes=FLAGS.num_classes,
                                          output_size=[224, 224])

    # Plain integer floor division: avoids the int16 wrap-around that a
    # narrow NumPy integer cast would introduce for very large test sets.
    test_batches_per_epoch = len(test_preprocessor.labels) // FLAGS.batch_size
    if test_batches_per_epoch < 1:
        # Fail early with a clear message instead of dividing by zero after
        # the (expensive) checkpoint restore.
        raise ValueError(
            "Test set has {} samples, fewer than one batch of {}".format(
                len(test_preprocessor.labels), FLAGS.batch_size))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Directly restore (your model should be exactly the same with checkpoint)
        saver.restore(sess, FLAGS.ckpt)

        test_acc = 0.

        for _ in range(test_batches_per_epoch):
            batch_tx, batch_ty = test_preprocessor.next_batch(FLAGS.batch_size)
            # Keep probability 1.0 disables dropout during evaluation.
            test_acc += sess.run(accuracy,
                                 feed_dict={
                                     x: batch_tx,
                                     y: batch_ty,
                                     dropout_keep_prob: 1.
                                 })

        # Every batch has the same size, so the mean of the per-batch
        # accuracies is the accuracy over all evaluated samples.
        test_acc /= test_batches_per_epoch
        print("{} Test Accuracy = {:.4f}".format(datetime.datetime.now(),
                                                 test_acc))
def main(_):
    """Print the model scores for the single image at ``FLAGS.input_image``.

    The image is read with OpenCV, resized to 224x224, converted to float32
    and has a fixed per-channel mean subtracted before being fed through a
    VggNetModel restored from ``FLAGS.ckpt``.

    Args:
        _: Unused positional argument.

    Raises:
        IOError: If the image cannot be read by ``cv2.imread``.
    """
    # Load and validate the image first so a bad path fails fast, before the
    # graph is built and the checkpoint is restored.
    img = cv2.imread(FLAGS.input_image)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError(
            "Could not read image: {}".format(FLAGS.input_image))
    img = cv2.resize(img, (224, 224))
    img = img.astype(np.float32)

    # Subtract mean color
    # NOTE(review): cv2.imread yields channels in BGR order, so these values
    # are presumably BGR means -- confirm against the training pipeline.
    img -= np.array([132.2766, 139.6506, 146.9702])

    # Add the leading batch dimension expected by the placeholder.
    batch_x = np.expand_dims(img, axis=0)

    # Placeholders
    x = tf.placeholder(tf.float32, [1, 224, 224, 3])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = VggNetModel(num_classes=FLAGS.num_classes,
                        dropout_keep_prob=dropout_keep_prob)
    model.inference(x)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Directly restore (your model should be exactly the same with checkpoint)
        saver.restore(sess, FLAGS.ckpt)

        # Keep probability 1.0 disables dropout during inference.
        scores = sess.run(model.score,
                          feed_dict={
                              x: batch_x,
                              dropout_keep_prob: 1.
                          })
        print(scores)
# Example no. 3
# 0
def predict(path, modelpath):
    """Return the index of the highest model output for one image.

    The image at ``path`` is read with OpenCV, resized to 224x224 and fed as
    a batch of one through a VggNetModel whose variables are restored from
    ``modelpath``. The resulting index is also printed.

    Args:
        path: Path of the image file to classify.
        modelpath: Checkpoint path passed to ``tf.train.Saver.restore``.

    Returns:
        The arg-max index over the flattened model output. With a batch of
        one this is presumably the predicted class index -- verify against
        the output shape of ``VggNetModel.inference``.

    Raises:
        IOError: If the image cannot be read by ``cv2.imread``.
    """
    # Validate the input before building a graph and restoring weights.
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError("Could not read image: {}".format(path))
    img = cv2.resize(img, (224, 224))
    img = img.astype(np.float32)
    # NOTE(review): no mean-color subtraction is applied here; confirm this
    # matches the preprocessing the checkpoint was trained with.

    # Add the leading batch dimension expected by the placeholder.
    batch = np.expand_dims(img, axis=0)

    # A fresh graph per call keeps repeated calls from colliding on
    # variable names in the default graph.
    with tf.Graph().as_default():
        # Placeholders
        x = tf.placeholder(tf.float32, [1, 224, 224, 3])
        dropout_keep_prob = tf.placeholder(tf.float32)

        # Model
        model = VggNetModel(num_classes=FLAGS.num_classes,
                            dropout_keep_prob=dropout_keep_prob)
        logits = model.inference(x)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Directly restore (your model should be exactly the same with checkpoint)
            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, modelpath)
            # Keep probability 1.0 disables dropout during inference.
            prediction = sess.run(logits,
                                  feed_dict={
                                      x: batch,
                                      dropout_keep_prob: 1.
                                  })
            max_index = np.argmax(prediction)
            print(max_index)
        return max_index