Exemplo n.º 1
0
def predict(image, sess, softmax_tensor, num_top_predictions=None, node_lookup=None):
    """
    Function used by classification workers to get prediction on image.

    This method was adapted based on run_inference_on_image() method from classify_image.py found in TensorFlow
    official tutorial.

    :param image: filename of the image to be classified; its basename without extension is used as the image ID
    :param sess: TensorFlow session
    :param softmax_tensor: tensor used for computing the predictions
    :param num_top_predictions: how many of the most probable labels to return; defaults to
        FLAGS.num_top_predictions when omitted. Values <= 0 yield an empty result.
    :param node_lookup: optional pre-built NodeLookup to reuse across calls (avoids rebuilding the label
        lookup for every image); a new NodeLookup is created when omitted
    :return: (img_id, result) with img_id being Instagram Image ID and the result being dictionary with the
        num_top_predictions most probable objects depicted in the image as keys and corresponding prediction
        confidences (Python floats) as values. If several node IDs map to the same label, the highest
        confidence for that label is kept.
    """
    img_id = os.path.splitext(os.path.basename(image))[0]
    if num_top_predictions is None:
        num_top_predictions = FLAGS.num_top_predictions

    # Read the raw bytes and close the file handle deterministically.
    with tf.gfile.FastGFile(image, 'rb') as image_file:
        image_data = image_file.read()

    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    if node_lookup is None:
        # Creates node ID --> English string lookup.
        node_lookup = NodeLookup()

    # Sort descending, then take a prefix. Slicing with `[-k:]` would return
    # every class when k == 0, because -0 == 0.
    top_k = predictions.argsort()[::-1][:max(num_top_predictions, 0)]
    result = {}
    for node_id in top_k:
        human_string = node_lookup.id_to_string(node_id)
        # top_k is in descending score order, so setdefault keeps the highest
        # confidence when two node IDs share a label instead of overwriting it.
        result.setdefault(human_string, float(predictions[node_id]))

    return img_id, result
Exemplo n.º 2
0
    def run_inference_on_image(self, imageList, num_top_predictions=5):
        """Runs inference on a list of images and prints the top predictions.

        The graph is created, the session opened and the label lookup built
        once for the whole list instead of once per image.

        Args:
          imageList: Iterable of image file names whose contents are fed to
            'DecodeJpeg/contents:0'.
          num_top_predictions: Number of most probable labels printed per
            image. Values <= 0 print nothing for each image.

        Returns:
          Nothing. For every image, prints one '<label> (score = <score>)'
          line per top prediction.
        """
        # Creates graph from saved GraphDef. Done once for the whole list:
        # presumably this imports the GraphDef into the default graph, so
        # calling it per image would keep adding duplicate ops (TODO confirm
        # against __create_graph).
        self.__create_graph()

        # Creates node ID --> English string lookup; it does not depend on the image.
        node_lookup = NodeLookup()

        with tf.Session() as sess:
            # Some useful tensors:
            # 'softmax:0': A tensor containing the normalized prediction across
            #   1000 labels.
            # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
            #   float description of the image.
            # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
            #   encoding of the image.
            softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')

            for image in imageList:
                if not tf.gfile.Exists(image):
                    # tf.logging.fatal only logs; the read below then raises.
                    tf.logging.fatal('File does not exist %s', image)
                with tf.gfile.FastGFile(image, 'rb') as image_file:
                    image_data = image_file.read()

                # Runs the softmax tensor by feeding the image_data as input to the graph.
                predictions = sess.run(softmax_tensor,
                                       {'DecodeJpeg/contents:0': image_data})
                predictions = np.squeeze(predictions)

                # Sort descending, then take a prefix. Slicing with `[-k:]`
                # would select every class when k == 0, because -0 == 0.
                top_k = predictions.argsort()[::-1][:max(num_top_predictions, 0)]
                for node_id in top_k:
                    human_string = node_lookup.id_to_string(node_id)
                    score = predictions[node_id]
                    print('%s (score = %.5f)' % (human_string, score))