예제 #1
0
class TLClassifier:
    """Traffic light colour classifier backed by a SqueezeNet Keras model.

    The network emits one score per entry of ``self.states``; the index of
    the highest score selects the predicted ``TrafficLight`` state.
    """

    def __init__(self, is_site):
        """Build the SqueezeNet model and load its pretrained weights.

        Args:
            is_site (bool): True when running on the real-world site. Only
                the simulator classifier exists so far, so this must be False.

        Raises:
            NotImplementedError: if ``is_site`` is truthy.
        """
        # TODO: load a classifier trained for the real-world site.
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if is_site:
            raise NotImplementedError(
                'No site (real-world) traffic light classifier is available')
        weights_file = r'light_classification/models/squeezenet_weights.h5'  # Replace with real world classifier

        # Network input shape, unpacked into SqueezeNet below; presumably
        # (rows, cols, channels) — TODO confirm against SqueezeNet's signature.
        self.image_shape = (224, 224, 3)

        # Model output index -> traffic light state (argmax indexes this tuple),
        # so the order must match the model's output layout.
        self.states = (TrafficLight.RED, TrafficLight.YELLOW,
                       TrafficLight.GREEN, TrafficLight.UNKNOWN)

        print('Loading model..')
        self.model = SqueezeNet(len(self.states), *self.image_shape)
        self.model.load_weights(weights_file, by_name=True)
        # Build the predict function eagerly; presumably so predict() works when
        # get_classification runs on another thread (common TF 1.x/Keras workaround).
        self.model._make_predict_function()
        print('Loaded weights: %s' % weights_file)

    def get_classification(self, image):
        """Determine the colour of the traffic light in the image.

        Args:
            image (numpy.ndarray): channels-last image containing the traffic
                light; presumably 8-bit BGR as produced by OpenCV — the channel
                axis is reversed to RGB before inference.

        Returns:
            int: ID of traffic light color (specified in styx_msgs/TrafficLight)
        """
        rows, cols = self.image_shape[:2]
        # cv2.resize expects dsize as (width, height). ``interpolation`` must be
        # given by keyword: the third positional parameter is ``dst``, so passing
        # cv2.INTER_AREA positionally never selects area interpolation.
        resized = cv2.resize(image, (cols, rows), interpolation=cv2.INTER_AREA)
        # Add a batch axis, reverse the channel order and divide by 255
        # (maps 8-bit pixel values into [0, 1]).
        mini_batch = resized.astype('float')[np.newaxis, ..., ::-1] / 255.
        scores = self.model.predict(mini_batch)
        return self.states[np.argmax(scores)]
예제 #2
0
import tensorflow as tf
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

from squeezenet import SqueezeNet

# Reply payloads embedded in the greeting message; the "record" branch of
# Greeter.SayHello uses responses[0] (responses[1] is presumably for
# "replay" requests — confirm).
responses = ["record_response", "replay_response"]

# TF 1.x session limited to a single thread within an op (intra_op) and
# across independent ops (inter_op).
# NOTE(review): nothing visible here registers this session with Keras
# (e.g. tf.keras.backend.set_session(sess)) or makes it the default session,
# so the models built below presumably run in Keras's own session and this
# thread limit may not apply to them — confirm against the rest of the file.
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                              inter_op_parallelism_threads=1)
sess = tf.Session(config=session_conf)

# Load the input images and build the models once, at import time, so request
# handlers (e.g. Greeter.SayHello, which reads `img`) can reuse them.
# NOTE(review): '/image.jpg' and '/image2.jpg' are absolute paths at the
# filesystem root (presumably inside a container image) — confirm.
# 227x227 is presumably the input size this SqueezeNet build expects.
img = image.load_img('/image.jpg', target_size=(227, 227))
model = SqueezeNet(weights='imagenet')
# Build the Keras predict function eagerly; presumably so predict() is safe to
# call later from gRPC worker threads (a common TF 1.x/Keras workaround).
model._make_predict_function()
print('Model is ready')

# A second image and an independent model instance with the same 'imagenet'
# weights; presumably serving a separate request path — confirm.
img2 = image.load_img('/image2.jpg', target_size=(227, 227))
model2 = SqueezeNet(weights='imagenet')
model2._make_predict_function()
print('Model2 is ready')


class Greeter(helloworld_pb2_grpc.GreeterServicer):
    def SayHello(self, request, context):
        #res = decode_predictions(preds) # requires access to the Internet
        if request.name == "record":
            msg = 'Hello, %s!' % responses[0]
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)