示例#1
0
class TLDetector(object):
    """Detect the state of the next traffic light and publish where to stop.

    Stop-line positions are read from the ``/traffic_light_config`` parameter
    and snapped to their closest ``/base_waypoints`` index. Camera images are
    only classified while the vehicle is within 100 m of the next light; when
    that light is predicted red (state 0) the waypoint index of its stop line
    is published on ``/traffic_waypoint``, otherwise ``-1``.
    """

    def __init__(self):
        rospy.init_node('tl_detector')

        # ========== state
        # Everything the callbacks read is initialised BEFORE any Subscriber is
        # created: rospy may deliver messages (on its own threads) as soon as a
        # subscription exists, and a late ``= None`` here would otherwise
        # overwrite data a callback already stored.

        # waypoints and egopose
        self.pose = None
        self.waypoints_2d = None
        self.waypoints_tree = None

        # traffic lights / stop lines
        config_string = rospy.get_param("/traffic_light_config")
        # safe_load: the config is plain YAML data. yaml.load without an
        # explicit Loader is unsafe and is rejected by PyYAML >= 6.
        self.config = yaml.safe_load(config_string)
        self.lights = None
        self.idx_light_next = 0

        # image classification
        self.cv_bridge = CvBridge()
        self.img_cnt = 0
        self.img_t_last = time.time()
        self.classifier = TLClassifier(filename_pb="model.pb")
        self.img_queue = []

        # ========== ROS publisher (used from cb_image via check_lights)
        self.pub_idx_wp_to_stop = rospy.Publisher('/traffic_waypoint',
                                                  Int32,
                                                  queue_size=1)

        # ========== ROS subscribers
        rospy.Subscriber('/current_pose', PoseStamped, self.cb_pose)
        rospy.Subscriber('/base_waypoints', Lane, self.cb_waypoints)
        rospy.Subscriber('/idx_closest_waypoint', std_msgs.msg.Int32,
                         self.cb_waypoint_next)
        rospy.Subscriber('/vehicle/traffic_lights', TrafficLightArray,
                         self.cb_traffic_lights)
        rospy.Subscriber('/image_color', Image, self.cb_image)

        # start looping
        rospy.spin()

    def cb_pose(self, msg):
        """Store the latest ego pose (PoseStamped)."""
        self.pose = msg

    def cb_waypoints(self, msg_waypoints):
        """Build the waypoint KD-tree and the Light list (only on first call).

        Args:
            msg_waypoints (Lane): base waypoints of the track.
        """
        if self.waypoints_2d is None:  # only do it once!
            # collect waypoints
            self.waypoints_2d = [[
                wp.pose.pose.position.x, wp.pose.pose.position.y
            ] for wp in msg_waypoints.waypoints]
            self.waypoints_tree = scipy.spatial.KDTree(self.waypoints_2d)

            # one Light per configured stop line, snapped to its closest waypoint
            self.lights = []
            lights_pos_xy = self.config['stop_line_positions']
            for idx_light, pos_xy in enumerate(lights_pos_xy):
                light = Light()
                light.idx = idx_light
                light.x = pos_xy[0]
                light.y = pos_xy[1]
                light.idx_wp = self.waypoints_tree.query(pos_xy, 1)[1]
                self.lights.append(light)

    def cb_waypoint_next(self, msg):
        """Select the traffic light that is next ahead of the vehicle.

        Args:
            msg (std_msgs.msg.Int32): index (into the base waypoints) of the
                vehicle's next waypoint.
        """
        idx_wp_next = msg.data
        # nothing to select before the waypoints/config are loaded, or when
        # the config has no stop lines at all
        if not self.lights or not self.waypoints_2d:
            return
        num_wp = len(self.waypoints_2d)
        # Pick the light with the smallest forward waypoint distance, treating
        # the waypoint list as a closed loop. A light on the vehicle's own
        # waypoint (distance 0) still counts as ahead. The modulo handles the
        # lap wrap-around: near the end of a lap every light index is smaller
        # than idx_wp_next, which previously made the selection cycle through
        # all lights and land on an arbitrary one.
        self.idx_light_next = min(
            range(len(self.lights)),
            key=lambda i: (self.lights[i].idx_wp - idx_wp_next) % num_wp)

    def cb_traffic_lights(self, msg):
        """Copy ground-truth light states onto the matching Light objects.

        Pairs lights by position in the list; this assumes ``msg.lights`` is
        ordered like the config's ``stop_line_positions`` -- TODO confirm.
        Extra entries in the message (beyond the configured stop lines) are
        ignored instead of raising IndexError.

        Args:
            msg (TrafficLightArray): simulator ground-truth light states.
        """
        if self.lights is not None:
            for light, light_gt in zip(self.lights, msg.lights):
                light.state_true = light_gt.state

    def cb_image(
        self,
        msg_img,
        num_images_skip=3,
        dt_min_between_images=0.050,
        flag_export=False,
    ):
        """Classify (or export) camera images while the next light is close.

        Args:
            msg_img (Image): camera image message.
            num_images_skip (int): only every n-th nearby image is processed.
            dt_min_between_images (float): minimum seconds between processed
                images.
            flag_export (bool): save images to disk instead of classifying.
        """
        # only do stuff with image if next traffic light is close. Also don't process every image
        distance = self.get_distance_to_next_light()
        if distance is not None and distance < 100:
            self.img_cnt += 1
            time_now = time.time()
            if self.img_cnt % num_images_skip == 0 and time_now - self.img_t_last > dt_min_between_images:
                self.img_t_last = time_now

                # start of actual processing
                img_numpy = self.cv_bridge.imgmsg_to_cv2(msg_img, "bgr8")
                if flag_export:
                    self.export_image(img_numpy, distance)
                else:
                    state = self.predict_state(img_numpy)
                    if state is not None:
                        next_light = self.lights[self.idx_light_next]
                        next_light.state_pred = state
                        self.check_lights()

    def predict_state(
            self,
            img,
            batch_size=3,
            target_size=(320, 240),
            ratio_min_detection=0.6,
    ):
        """Queue an image and, once a batch is full, vote on the light state.

        Args:
            img: BGR image as returned by ``imgmsg_to_cv2(..., "bgr8")``.
            batch_size (int): number of queued images per prediction.
            target_size (tuple): (width, height) passed to ``cv2.resize``.
            ratio_min_detection (float): fraction of the batch that must agree.

        Returns:
            0 (red), 2 (green), or None when the batch is not full yet or the
            vote is inconclusive.
        """
        # resize image, scale to [0, 1], convert to RGB and append to queue
        img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
        img = img / 255.
        img = img[:, :, ::-1]  # converts BGR to RGB
        self.img_queue.append(img)

        # only if queue has batch_size items in it, run prediction
        if len(self.img_queue) == batch_size:
            input_tensor = np.asarray(self.img_queue)
            assert (input_tensor.ndim == 4)
            self.img_queue = []

            prob_all = self.classifier.predict(input_tensor)
            # presumably column 0 is the classifier's "red" probability -- TODO confirm
            prob_red = prob_all[:, 0]
            if np.mean(prob_red > 0.5) > ratio_min_detection:
                return 0  # meaning traffic light = red
            elif np.mean(prob_red < 0.5) > ratio_min_detection:
                return 2  # meaning traffic light = green
            else:
                return None
        else:
            return None

    def export_image(self, img_numpy, distance):
        """Write the image to /mnt/share/export/<true state>/ for training."""
        light_state = self.lights[self.idx_light_next].state_true
        # int(np.round(...)) instead of .astype(np.int): the np.int alias was
        # removed in NumPy 1.24. The rounding and resulting digits are unchanged.
        time_in_ms = int(np.round(time.time() * 1000))
        folder = os.path.join('/mnt/share/export', str(light_state))
        filename = ('img_' + str(time_in_ms) +
                    '_lightidx_' + str(self.idx_light_next) +
                    '_dist_' + str(int(np.round(distance))) + '.png')
        filepath = os.path.join(folder, filename)
        cv2.imwrite(filepath, img_numpy)
        rospy.loginfo("Written image {}".format(filename))

    def check_lights(self):
        """Publish the stop-line waypoint of the next light if predicted red.

        Publishes -1 when no light is known or the prediction is not red.
        """
        msg_out = Int32()
        msg_out.data = -1  # default value, meaning no need to stop at any waypoint
        distance = self.get_distance_to_next_light()
        if distance is not None:
            light_next = self.lights[self.idx_light_next]
            if light_next.state_pred == 0:  # 0 = red (see predict_state)
                msg_out.data = light_next.idx_wp
            # state_true is only assigned by cb_traffic_lights (ground truth),
            # so it may be missing -- don't let logging block the publish.
            state_true = getattr(light_next, 'state_true', None)
            rospy.loginfo(
                "Light idx={}: state_pred={}, state_true={}, distance={}".
                format(
                    self.idx_light_next,
                    COLOR_PER_INT[light_next.state_pred],
                    COLOR_PER_INT[state_true] if state_true is not None else None,
                    distance,
                ))
        self.pub_idx_wp_to_stop.publish(msg_out)

    def get_distance_to_next_light(self):
        """Euclidean distance from the ego position to the next stop line.

        Returns None while pose or lights are not available.
        """
        if self.pose and self.lights and self.idx_light_next is not None:
            car_xy = np.asarray(
                [self.pose.pose.position.x, self.pose.pose.position.y])
            light_next = self.lights[self.idx_light_next]
            light_next_xy = np.asarray([light_next.x, light_next.y])
            return self.calc_distance(car_xy, light_next_xy)
        else:
            return None

    @staticmethod
    def calc_distance(pt1, pt2):
        """Return the Euclidean norm of ``pt1 - pt2`` (numpy arrays)."""
        return np.linalg.norm(pt1 - pt2)
示例#2
0
class TLDetector(object):
    """ detect and classify traffic light

        @subscribed /base_waypoints:         the complete list of waypoints the car will be following
        @subscribed /current_pose:           the vehicle's current position
        @subscribed /image_color:            the image stream from the car's camera
        @subscribed /vehicle/traffic_lights: the exact location and status of all traffic lights in simulator

        @published  /traffic_waypoint:       the index of the waypoint for nearest upcoming red light's stop line
    """
    CAMERA_IMAGE_CLASSIFICATION_WPS = 53
    CAMERA_IMAGE_COLLECTION_AFTER_LINE_COUNT = 27
    # directory searched for trained classifier weights
    MODEL_DIR = './light_classification/models'
    # model files are named "<integer timestamp>-model-params.h5"
    _MODEL_FILENAME_PATTERN = re.compile(r'(\d+)-model-params\.h5$')

    def __init__(self):
        rospy.init_node('tl_detector')

        # load config params:
        config_string = rospy.get_param("/traffic_light_config")
        # safe_load: the config is plain YAML data; yaml.load without an
        # explicit Loader is unsafe and raises TypeError on PyYAML >= 6.
        self.config = yaml.safe_load(config_string)

        # state variables:
        self.pose = None
        self.waypoints = None
        self._waypoints_location = None
        self._waypoints_size = None
        self._waypoints_index = None
        self.has_image = False
        self.camera_image = None
        self.lights = []

        self.state = TrafficLight.UNKNOWN
        self.last_state = TrafficLight.UNKNOWN
        self.last_wp = -1
        self.state_count = 0

        # image collector:
        self.after_stop_line_count = TLDetector.CAMERA_IMAGE_COLLECTION_AFTER_LINE_COUNT

        # classifier--subscriber:
        self.listener = tf.TransformListener()
        # classifier--format convertor:
        self.bridge = CvBridge()
        # classifier--pre-trained model (None when no model file exists; camera
        # classification then reports UNKNOWN instead of crashing):
        self.light_classifier = None
        model_path = self._find_latest_model(TLDetector.MODEL_DIR)
        if model_path is None:
            rospy.logwarn(
                "No traffic light model found in %s, camera state will be UNKNOWN",
                TLDetector.MODEL_DIR)
        else:
            self.light_classifier = TLClassifier()
            self.light_classifier.load(model_path)

        # publish (created before subscribing because image_cb publishes and
        # rospy may invoke it as soon as the subscription exists):
        self.upcoming_red_light_pub = rospy.Publisher('/traffic_waypoint',
                                                      Int32,
                                                      queue_size=1)

        # subscribe:
        rospy.Subscriber('/base_waypoints', Lane, self.waypoints_cb)
        rospy.Subscriber('/current_pose', PoseStamped, self.pose_cb)
        rospy.Subscriber('/image_color', Image, self.image_cb)
        # /vehicle/traffic_lights provides the location of the traffic lights
        # in 3D map space plus their current color state in the simulator
        # (ground truth for the classifier). On the real vehicle the color
        # state is not available; only the light position and camera image can
        # be used.
        rospy.Subscriber('/vehicle/traffic_lights', TrafficLightArray,
                         self.traffic_cb)

        rospy.spin()

    @staticmethod
    def _find_latest_model(model_dir):
        """ find the newest classifier weights file in model_dir

        Files not named "<timestamp>-model-params.h5" (e.g. .gitkeep, README)
        are ignored.

        Args:
            model_dir (str): directory to search (must exist)

        Returns:
            str or None: path of the file with the largest timestamp, or None
            when no file matches
        """
        candidates = []
        for filename in os.listdir(model_dir):
            match = TLDetector._MODEL_FILENAME_PATTERN.match(filename)
            if match:
                candidates.append((int(match.group(1)), filename))
        if not candidates:
            return None
        _, latest_model_filename = max(candidates, key=lambda t: t[0])
        return os.path.join(model_dir, latest_model_filename)

    def waypoints_cb(self, waypoints):
        """ load base waypoints from system (only the first message is used)

        Args:
            waypoints (Lane): reference trajectory as waypoints
        """
        if not self.waypoints:
            # load waypoints:
            self.waypoints = waypoints
            # build index upon waypoints:
            self._waypoints_location = np.array([[
                waypoint.pose.pose.position.x, waypoint.pose.pose.position.y
            ] for waypoint in waypoints.waypoints])
            self._waypoints_size, _ = self._waypoints_location.shape
            self._waypoints_index = KDTree(self._waypoints_location)

    def pose_cb(self, msg):
        """ parse ego vehicle pose

        Args:
            msg (PoseStamped): ego vehicle pose
        """
        self.pose = msg

    def image_cb(self, msg):
        """ identify red lights in the incoming camera image and publish the index
            of the waypoint closest to the red light's stop line to /traffic_waypoint

        Publishes at camera frequency. A newly predicted state has to be
        observed more than STATE_COUNT_THRESHOLD times in a row before it is
        used; until then the previous stable waypoint (last_wp) is republished.

        Args:
            msg (Image): image from car-mounted camera
        """
        # parse input:
        self.has_image = True
        self.camera_image = msg
        # process traffic lights:
        stop_line_waypoint_index, state = self.process_traffic_lights()

        if self.state != state:
            self.state_count = 0
            self.state = state
        elif self.state_count >= STATE_COUNT_THRESHOLD:
            self.last_state = self.state
            # only red/yellow lights require a stop
            if state not in (TrafficLight.RED, TrafficLight.YELLOW):
                stop_line_waypoint_index = -1
            self.last_wp = stop_line_waypoint_index
            self.upcoming_red_light_pub.publish(
                Int32(stop_line_waypoint_index))
        else:
            self.upcoming_red_light_pub.publish(Int32(self.last_wp))
        self.state_count += 1

    def traffic_cb(self, msg):
        """ parse traffic light status from telegram

        Args:
            msg (TrafficLightArray): list of all traffic light telegrams
        """
        self.lights = msg.lights

    def get_next_waypoint_index(self):
        """ get index of the first waypoint ahead of the ego vehicle

        Returns:
            int: index into the base waypoints (wraps around to 0 at the end)
        """
        # ego vehicle location:
        ego_vehicle_location = np.array(
            [self.pose.pose.position.x, self.pose.pose.position.y])

        # closest waypoint
        _, closest_waypoint_index = self._waypoints_index.query(
            ego_vehicle_location)

        closest_waypoint_location = self._waypoints_location[
            closest_waypoint_index]
        previous_waypoint_location = self._waypoints_location[
            closest_waypoint_index - 1]

        # the closest waypoint is ahead iff the ego vehicle lies between it and
        # its predecessor (vectors to both point in opposite directions):
        is_next_waypoint = (np.dot(
            closest_waypoint_location - ego_vehicle_location,
            previous_waypoint_location - ego_vehicle_location) < 0.0)

        # next waypoint index:
        next_waypoint_index = closest_waypoint_index
        if not is_next_waypoint:
            next_waypoint_index = (closest_waypoint_index +
                                   1) % self._waypoints_size

        return next_waypoint_index

    def get_closest_waypoint(self, position):
        """ get index of the waypoint closest to a 2D position

        Args:
            position (sequence of 2 floats): (x, y), e.g. a stop line position
                from the traffic light config

        Returns:
            int: index into the base waypoints
        """
        location = np.array(position)
        _, index = self._waypoints_index.query(location)

        return index

    def get_light_state_from_camera(self):
        """ Determines the closest traffic light state from image analysis

        Returns TrafficLight.UNKNOWN when no image has arrived yet or no
        classifier model was loaded.
        """
        if not self.has_image:
            self.prev_light_loc = None
            return TrafficLight.UNKNOWN
        if self.light_classifier is None:
            return TrafficLight.UNKNOWN

        # format as OpenCV:
        cv_image = self.bridge.imgmsg_to_cv2(self.camera_image, "bgr8")

        # preprocess:
        preprocessed_image = self.light_classifier.preprocess(cv_image)

        # predict on a batch of one:
        return self.light_classifier.predict(preprocessed_image[np.newaxis])

    def save_traffic_light_image(self, index, order, distance, state):
        """ Save traffic light image for offline training

        Args:
            index (Int): traffic light index
            order (str): 'before' or 'after'
            distance (Int): distance to incoming stop line
            state (TrafficLight.state): traffic light state
        """
        # format image:
        traffic_light_image = self.bridge.imgmsg_to_cv2(
            self.camera_image, "bgr8")
        preprocessed = self.light_classifier.preprocess(traffic_light_image)
        filename = "light_classification/traffic_light_images/{}--{}-{}--{}=={}-preprocessed.jpg".format(
            rospy.Time.now().to_nsec(), order, index, distance, state)
        cv2.imwrite(filename, preprocessed)

    def get_light_state_from_telegram(self, light):
        """ Determines the closest traffic light state from telegram broadcast

        Args:
            light (TrafficLight): traffic light status
        """
        return light.state

    def process_traffic_lights(self):
        """Finds closest visible traffic light, if one exists, and determines its
            location and color

        Returns:
            int: index of waypoint closest to the upcoming stop line for a traffic light (-1 if none exists)
            int: ID of traffic light color (specified in styx_msgs/TrafficLight)
        """
        closest_distance = self._waypoints_size
        closest_stop_line_index = None
        closest_stop_line_waypoint_index = None

        # list of positions that correspond to the line to stop in front of for a given intersection
        stop_line_positions = self.config['stop_line_positions']
        if self.pose and self.waypoints:
            # ego vehicle position:
            ego_vehicle_waypoint_index = self.get_next_waypoint_index()

            # identify closest stop line ahead (distance in waypoints):
            for i, stop_line_position in enumerate(stop_line_positions):
                stop_line_waypoint_index = self.get_closest_waypoint(
                    stop_line_position)

                distance = stop_line_waypoint_index - ego_vehicle_waypoint_index
                if 0 < distance < closest_distance:
                    closest_distance = distance
                    closest_stop_line_index = i
                    closest_stop_line_waypoint_index = stop_line_waypoint_index

        # if there is incoming stop line (`is not None`: waypoint 0 is a valid index):
        if (closest_stop_line_waypoint_index is not None and
                (closest_distance <= TLDetector.CAMERA_IMAGE_CLASSIFICATION_WPS
                 or self.after_stop_line_count > 0)):
            order = "before"
            # ego vehicle just passed stop line:
            if closest_distance > TLDetector.CAMERA_IMAGE_CLASSIFICATION_WPS:
                if self.after_stop_line_count > 0:
                    self.after_stop_line_count -= 1
                    order = "after"
                    # the light just passed (index -1 wraps to the last light)
                    closest_stop_line_index -= 1
                    closest_distance = self.after_stop_line_count
            # ego vehicle is about to cross stop line:
            elif closest_distance <= 3 and self.after_stop_line_count <= 0:
                self.after_stop_line_count = TLDetector.CAMERA_IMAGE_COLLECTION_AFTER_LINE_COUNT

            # method 02: image analysis
            state_camera = self.get_light_state_from_camera()

            # method 01: telegram -- only available while /vehicle/traffic_lights
            # has delivered data (simulator); skip the comparison otherwise
            # instead of raising IndexError on an empty list.
            if self.lights:
                state_telegram = self.get_light_state_from_telegram(
                    self.lights[closest_stop_line_index])

                # image collection:
                if state_telegram != state_camera and state_camera != TrafficLight.UNKNOWN:
                    # save for hard negative mining:
                    self.save_traffic_light_image(closest_stop_line_index,
                                                  order, closest_distance,
                                                  state_telegram)
                    # prompt:
                    rospy.logwarn(
                        "[Discrepancy between Telegram and Camera]: %d--%d @ %d, Camera Image Saved",
                        state_telegram, state_camera, closest_stop_line_index)

            return closest_stop_line_waypoint_index, state_camera

        return -1, TrafficLight.UNKNOWN