示例#1
0
def main(_):
  """Classify one image and save a Grad-CAM overlay for each top-5 class.

  Loads the image at ``FLAGS.input``, builds ``FLAGS.model_name`` in inference
  mode (``is_training=False``), restores weights from ``FLAGS.checkpoint_path``
  and prints the (at most) five highest-scoring classes.  For each of them a
  Grad-CAM map is computed for ``FLAGS.layer_name`` (or the model's default
  layer from ``_layer_names``), colour-mapped, superimposed on the resized
  input image, displayed and saved.

  The top-1 overlay is written to ``FLAGS.output``; overlays for lower ranks
  are written beside it as ``<stem>_top<k><ext>`` (k = 2..5) so that later
  classes do not overwrite earlier ones.

  Args:
    _: Unused positional argument (presumably argv from ``tf.app.run``).

  Raises:
    ValueError: If no checkpoint can be resolved from
      ``FLAGS.checkpoint_path``.
  """
  import os

  img = load_image(FLAGS.input)

  labels = load_labels_from_file(FLAGS.dataset_dir)
  num_classes = len(labels)

  network_fn = nets_factory.get_network_fn(
    FLAGS.model_name,
    num_classes=num_classes,
    is_training=False)

  eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

  print("\nLoading Model")
  imgs0 = tf.placeholder(tf.uint8, [None, None, 3])
  imgs = preprocess_image(imgs0, eval_image_size)
  imgs = tf.expand_dims(imgs, 0)

  _, end_points = network_fn(imgs)

  # Accept either a checkpoint directory or an explicit checkpoint prefix.
  # latest_checkpoint() returns None when it finds nothing, which would
  # otherwise surface as an obscure failure inside assign_from_checkpoint_fn.
  checkpoint_path = FLAGS.checkpoint_path
  if tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
  else:
    checkpoint_file = checkpoint_path
  if not checkpoint_file:
    raise ValueError('No checkpoint found in %r' % checkpoint_path)

  init_fn = slim.assign_from_checkpoint_fn(
    checkpoint_file, slim.get_variables_to_restore())

  # Loop-invariant settings: the layer Grad-CAM inspects and the end point
  # holding the class scores.
  layer_name = FLAGS.layer_name or _layer_names[FLAGS.model_name][0]
  pred_layer_name = _layer_names[FLAGS.model_name][1]

  # Display copy of the input, resized to the network resolution and scaled
  # to [0, 1].  It is kept separate from `img`, which must remain the raw
  # image fed to the uint8 `imgs0` placeholder on every Grad-CAM pass.
  display_img = cv2.resize(img, (eval_image_size, eval_image_size))
  display_img = display_img.astype(float)
  peak = display_img.max()
  if peak > 0:
    display_img /= peak

  output_stem, output_ext = os.path.splitext(FLAGS.output)

  print("\nFeedforwarding")

  with tf.Session() as sess:
    init_fn(sess)

    ep = sess.run(end_points, feed_dict={imgs0: img})
    probs = ep[pred_layer_name][0]

    # Indices of the (at most) five highest-scoring classes, best first.
    preds = np.argsort(probs)[::-1][:5]
    print(preds)
    print('\nTop 5 classes are')
    for p in preds:
      print(p, labels[p], probs[p])

    for rank, predicted_class in enumerate(preds, start=1):
      print(predicted_class)
      cam3 = grad_cam(img, imgs0, end_points, sess, predicted_class,
                      layer_name, num_classes, eval_image_size)

      cam3 = cv2.applyColorMap(np.uint8(255 * cam3), cv2.COLORMAP_JET)
      cam3 = cv2.cvtColor(cam3, cv2.COLOR_BGR2RGB)

      # Superimpose the heat map on the image.  cam3 is uint8 in 0..255 while
      # display_img is in [0, 1]; alpha scales the map to at most ~0.64
      # before the sum is renormalised to a peak of 1.
      alpha = 0.0025
      new_img = display_img + alpha * cam3
      new_img /= new_img.max()

      # Top-1 keeps the requested output path; other ranks get their own file.
      if rank == 1:
        output_path = FLAGS.output
      else:
        output_path = '%s_top%d%s' % (output_stem, rank, output_ext)

      # Display and save
      io.imshow(new_img)
      plt.axis('off')
      plt.savefig(output_path, bbox_inches='tight')
      plt.show()
示例#2
0
 def get_network_fn(model_name, num_classes):
     """Return the ``nets_factory`` network function for inference.

     Thin wrapper around ``nets_factory.get_network_fn`` that always passes
     ``is_training=False``.

     Args:
       model_name: Name of the network passed through to the factory.
       num_classes: Number of output classes passed through to the factory.

     Returns:
       Whatever ``nets_factory.get_network_fn`` returns for these arguments.
     """
     factory_kwargs = {'num_classes': num_classes, 'is_training': False}
     return nets_factory.get_network_fn(model_name, **factory_kwargs)