Example #1
 def process_frame(self):
     """ Main function to process the frame and run the infererence """
     # Deque the next image msg
     current_msg = self.msg_queue.get()
     current_image = None
     # Convert the image message to OpenCV format
     try:
         current_image = self._bridge.imgmsg_to_cv2(current_msg, "bgr8")
         rospy.logdebug("[trt_yolo_ros] image converted for processing")
     except CvBridgeError as e:
         rospy.logdebug("Failed to convert image %s", str(e))
     # Initialize detection results
     if current_image is not None:
         rospy.logdebug("[trt_yolo_ros] processing frame")
         boxes, classes, scores, visualization = self.model(current_image)
         detection_results = BoundingBoxes()
         detection_results.header = current_msg.header
         detection_results.image_header = current_msg.header
         # construct message
         self._write_message(detection_results, boxes, scores, classes)
         # send message
         try:
             rospy.logdebug("[trt_yolo_ros] publishing")
             self._pub.publish(detection_results)
             if self.publish_image:
                 self._pub_viz.publish(
                     self._bridge.cv2_to_imgmsg(visualization, "bgr8")
                 )
         except CvBridgeError as e:
             rospy.logdebug("[trt_yolo_ros] Failed to convert image %s", str(e))
Example #2
    def camera_image_callback(self, image, camera):
        """Gets images from camera to generate detections on

        Computes where the bounding boxes should be in the image, and
        fakes YOLO bounding boxes output as well as publishing a debug
        image showing where the detections are if turned on

        Parameters
        ----------
        image : sensor_msgs.Image
            The image from the camera to create detections for
        camera : Camera
            Holds camera parameters for projecting points

        """

        cv_image = self.bridge.imgmsg_to_cv2(image, desired_encoding="rgb8")

        if camera.is_ready():

            # Fake darknet yolo detections message
            bounding_boxes = BoundingBoxes()

            bounding_boxes.header = image.header
            bounding_boxes.image_header = image.header
            bounding_boxes.image = image
            bounding_boxes.bounding_boxes = []

            for _, obj in self.objects.items():
                detections = self.get_detections(obj, camera, image.header.stamp)

                if detections is not None:
                    if camera.debug_image_pub:
                        self.render_debug_context(cv_image, detections, camera)

                    for det in detections:
                        bounding_box = BoundingBox()
                        bounding_box.Class = det[2]
                        bounding_box.probability = 1.0
                        bounding_box.xmin = int(det[0][0])
                        bounding_box.xmax = int(det[0][1])
                        bounding_box.ymin = int(det[0][2])
                        bounding_box.ymax = int(det[0][3])
                        bounding_boxes.bounding_boxes.append(bounding_box)

            # Only publish detection if there are boxes in it
            if bounding_boxes.bounding_boxes:
                self.darknet_detection_pub.publish(bounding_boxes)

        if camera.debug_image_pub:
            image_message = self.bridge.cv2_to_imgmsg(cv_image, encoding="rgb8")
            camera.debug_image_pub.publish(image_message)
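Because camera_image_callback takes a second camera argument, it cannot be registered with rospy.Subscriber directly. A minimal wiring sketch, where the topic name and the Camera construction are placeholders (rospy forwards callback_args as the callback's second positional argument):

    def setup_camera_subscriber(self):
        # Hypothetical subscription setup (not part of the original code);
        # Image is sensor_msgs.msg.Image and the topic name is a placeholder.
        camera = Camera()  # assumed: the example only relies on is_ready() and debug_image_pub
        rospy.Subscriber("/camera/image_raw", Image, self.camera_image_callback,
                         callback_args=camera, queue_size=1)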
Example #3
    def image_callback(self, msg):
        # print("Received an image!")
        try:
            # Convert the ROS Image message to a NumPy array
            cv2_img = np.frombuffer(msg.data, dtype=np.uint8).reshape(
                msg.height, msg.width, -1)
            image_ori = cv2.resize(cv2_img, (512, 512))
            # remove the alpha channel if it exists
            if image_ori.shape[-1] == 4:
                image_ori = image_ori[..., :3]
            # normalize the image
            image = self.normalize(image_ori)

            with torch.no_grad():
                image = torch.Tensor(image)
                if torch.cuda.is_available():
                    image = image.cuda()
                boxes = self.model_.get_boxes(
                    image.permute(2, 0, 1).unsqueeze(0))

            Boxes_msg = BoundingBoxes()
            Boxes_msg.image_header = msg.header
            Boxes_msg.header.stamp = rospy.Time.now()
            for box in boxes[0]:
                # print(box.confidence)
                confidence = float(box.confidence)
                box = (box.box * torch.Tensor([512] * 4)).int().tolist()
                if confidence > 0.35:
                    cv2.rectangle(image_ori, (box[0], box[1]),
                                  (box[2], box[3]), (255, 0, 0), 2)
                    cv2.putText(image_ori,
                                str(confidence)[:4], (box[0] - 2, box[1] - 2),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                    # msg_frame = self.bridge.cv2_to_imgmsg(image_ori)
                    detection_box = BoundingBox()
                    # detection_box.Class=str(box.class_id)
                    detection_box.xmin = box[0]
                    detection_box.ymin = box[1]
                    detection_box.xmax = box[2]
                    detection_box.ymax = box[3]
                    detection_box.probability = confidence
                    Boxes_msg.bounding_boxes.append(detection_box)

            msg_frame = self.bridge.cv2_to_imgmsg(image_ori)
            self.img_pub.publish(msg_frame)
            self.ret_pub.publish(Boxes_msg)

        except CvBridgeError as e:
            print("CvBridge conversion failed:", e)
        if self.savefigure:
            cv2.imwrite('new_image.jpeg', cv2_img)
            print('saved picture to new_image.jpeg')
            # self.detected_msg.data="Take a photo"
            self.savefigure = False
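The self.normalize helper used above is not shown. A minimal sketch, assuming an HxWx3 uint8 RGB input and ImageNet mean/std statistics (the real model may expect different values):

    def normalize(self, image):
        # Hypothetical normalization helper (not part of the original code):
        # scale to [0, 1], then apply per-channel mean/std normalization.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        return (image.astype(np.float32) / 255.0 - mean) / std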
Example #4
 def process_frame(self):
     """ Main function to process the frame and run the infererence """
     # Deque the next image msg
     current_msg = self.msg_queue.get()
     current_image = None
     # Convert the image message to OpenCV format
     try:
         current_image = self._bridge.imgmsg_to_cv2(current_msg, "bgr8")
         rospy.logdebug("[trt_yolo_ros] image converted for processing")
     except CvBridgeError as e:
         rospy.logdebug("Failed to convert image %s" , str(e))
     # Initialize detection results
     if current_image is not None:
         boxes, classes, scores, visualization = self.model(current_image)
         detection_results = BoundingBoxes()
         detection_results.header = current_msg.header
         detection_results.image_header = current_msg.header
         rospy.logdebug("[trt_yolo_ros] processing frame")
         # construct message
         self._write_message(detection_results, boxes, scores, classes)
         # send message
         try:
             rospy.logdebug("[trt_yolo_ros] publishing")
             self._pub.publish(detection_results)
             if self.publish_image and boxes is not None:
                 # Publish the visualization as a compressed JPEG image
                 compressed_image = self._bridge.cv2_to_compressed_imgmsg(visualization, dst_format='jpg')
                 compressed_image.header.stamp = current_msg.header.stamp
                 self._pub_viz.publish(compressed_image)
         except CvBridgeError as e:
             rospy.logdebug("Failed to convert image %s" , str(e))
Example #5
    def process_frames(self):
        print("\nMining Frames for Flicker Detection ......\o/ \o/ \o/ \o/....... \n")
        # Initialize extraction loop state
        msg_count = 0

        # Iterate through bags
        rosbag_files = sorted(glob.glob(os.path.join(self.bag_dir, "*.bag")))
        for bag_file in tqdm(rosbag_files, unit='bag'):
            self.frame_imgs, self.trk_table, self.frame_count = self.load_cache(bag_file)
            self.bag_keypath = '_'.join(os.path.normpath(bag_file).split(os.sep))

            if not self.frame_imgs or not self.trk_table:
                # Open bag file. If corrupted, skip the file
                try:
                    with rosbag.Bag(bag_file, 'r') as bag:
                        # Check if desired topics exists in bags
                        recorded_topics = bag.get_type_and_topic_info()[1]
                        if not all(topic in recorded_topics for topic in (DETECTION_TOPIC, CAR_STATE_TOPIC, IMAGE_STREAM_TOPIC)):
                            print("ERROR: Specified topics not in bag file %s. Skipping bag!" % bag_file)
                            continue

                        gps = ""
                        v_ego = 0.0
                        self.cv_image = None
                        # Get Detections
                        time_nsecs = []
                        self.trk_table = OrderedDict()
                        self.frame_imgs = OrderedDict()
                        self.frame_trks = OrderedDict()
                        self.flickering_frame_imgs = OrderedDict()
                        self.flickering_frame_trks = OrderedDict()
                        for topic, msg, t in bag.read_messages():
                            if topic == CAR_STATE_TOPIC:
                                self.gps = msg.GPS
                                self.v_ego = msg.v_ego
                            if topic == IMAGE_STREAM_TOPIC:
                                if msg_count % self.skip_rate == 0:

                                    self.cv_image = self.bridge.compressed_imgmsg_to_cv2(msg, "bgr8")
                                    self.img_msg_header = msg.header
                                    image_name = "frame%010d_%s.jpg" % (self.frame_count, str(msg.header.stamp.to_nsec()))
                                    img_path = os.path.join(self.frames_dir, image_name)
                                    cv2.imwrite(img_path, self.cv_image)

                                    uri = img_path
                                    img_key = int(msg.header.stamp.to_nsec())
                                    fname = self.path_leaf(uri)
                                    vid_name = ''
                                    # format_from_nanos() returns a human-readable timestamp; assume the
                                    # time-of-day is the last whitespace-separated token and parse the hour
                                    readable_time = self.format_from_nanos(img_key).split(' ')[-1]
                                    if ':' in readable_time:
                                        hour = int(readable_time.split(':')[0])
                                    else:
                                        hour = 12

                                    if (hour > 4 and hour < 6) or (hour > 17 and hour < 19):
                                        timeofday = 'dawn/dusk'
                                    elif hour > 6 and hour < 17:
                                        timeofday = 'daytime'
                                    else:
                                        timeofday = 'night'


                                    scene = 'highway'
                                    timestamp = img_key
                                    dataset_path = img_path
                                    if self.gps:
                                        gps = pynmea2.parse(self.gps)
                                        lat = gps.lat
                                        long = gps.lon
                                    else:
                                        lat = None
                                        long = None
                                    height, width, depth = self.cv_image.shape

                                    self.frame_imgs[img_key] = {'url': img_path,
                                                             'name': img_path,
                                                             'width': width,
                                                             'height': height,
                                                             'index': self.frame_count,
                                                             'timestamp': timestamp,
                                                             'videoName':vid_name,
                                                             'attributes': {'weather': 'clear', 'scene': scene, 'timeofday': timeofday},
                                                             'labels': []
                                                            }
                                    self.frame_trks[img_key] = []

                                    msg_count += 1
                                    self.frame_count += 1

                        ## Get Tracking Data ##
                        if self.use_bag_detections:
                            for topic, msg, t in bag.read_messages():
                                if topic == DETECTION_TOPIC:
                                    # Find corresponding frame for detections message
                                    found_frame = False
                                    for frame in sorted(self.frame_imgs.keys()):
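                                        # match a frame whose timestamp lies within ±3.333e7 ns (~33 ms, one frame period at 30 FPS) of the detection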
                                        if int(msg.header.stamp.to_nsec()) > frame-3.333e7 and int(msg.header.stamp.to_nsec()) < frame+3.333e7:
                                            found_frame = True

                                            # Collect tracker_msgs
                                            detections, tracks = self.tracking_callback(msg)
                                            # Append to frame annotations table
                                            self.append_frame_trks(frame, tracks, detections)
                                            # Append to track annotations table
                                            self.append_trk_table(frame, tracks)
                                            # Debug prints to monitor data extraction
                                            print("FRAME TIMESTAMP:", frame)
                                            print("DETECTION TIMESTAMP:", int(msg.header.stamp.to_nsec()))
                                            print('IMG PATH:', self.frame_imgs[frame]['url'])
                                            break

                                    if not found_frame: # Try a wider time window and try to shuffle detection to appropriate frame
                                        frame_keys = sorted(self.frame_imgs.keys())
                                        for i, frame in enumerate(frame_keys):
                                            if int(msg.header.stamp.to_nsec()) > frame-3.783e7 and int(msg.header.stamp.to_nsec()) < frame+3.783e7:
                                                found_frame = True
                                                # Collect tracks
                                                detections, tracks = self.tracking_callback(msg)
                                                # Append to buffer
                                                if len(self.frame_imgs[frame]) < 2: # Check if already assigned detections to this frame
                                                    idx = frame
                                                elif i > 0 and len(self.frame_imgs[frame_keys[i-1]]) < 2: # Assign detections to the previous frame if empty
                                                    idx = frame_keys[i-1]
                                                elif i < len(frame_keys)-1 and len(self.frame_imgs[frame_keys[i+1]]) < 2: # Assign detections to the next frame if empty
                                                    idx = frame_keys[i+1]
                                                else:
                                                    idx = frame
                                                # Append to frame annotations table
                                                self.append_frame_trks(idx, tracks, detections)
                                                # Append to track annotations table
                                                self.append_trk_table(idx, tracks)
                                                # Debug prints to monitor data extraction
                                                print("FRAME TIMESTAMP:", idx)
                                                print("DETECTION TIMESTAMP:", int(msg.header.stamp.to_nsec()))
                                                print('IMG PATH:', self.frame_imgs[idx]['url'])
                                                break
                        else:
                            anns = darknet_annotator.annotate(
                                os.path.abspath(self.frames_dir),
                                os.path.abspath(self.current_model_cfg_path),
                                os.path.abspath(self.inference_model),
                                os.path.abspath(self.current_data_cfg_path))
                            img_stamps = OrderedDict()

                            for uri, img_detections in sorted(anns):
                                img_timestamp = int(uri.split('_')[-1].replace('.jpg', '').replace('.png', '')) # Get from fpath
                                img_stamps[img_timestamp] = (uri, img_detections)


                            ann_idx = 0
                            for img_timestamp in sorted(img_stamps.keys()):
                                dets = []
                                # Get corresponding image
                                uri = img_stamps[img_timestamp][0]
                                bboxes = BoundingBoxes()
                                bboxes.header.stamp = datetime.now()
                                bboxes.image_header = self.img_msg_header

                                for detection in img_stamps[img_timestamp][1]:
                                    # Build Detections
                                    bbox = BoundingBox()
                                    bbox.xmin, bbox.ymin = detection['box2d']['x1'], detection['box2d']['y1']
                                    bbox.xmax, bbox.ymax = detection['box2d']['x2'], detection['box2d']['y2']
                                    bbox.Class = detection['category']
                                    bbox.class_id = 0
                                    bbox.probability = 0.99
                                    bboxes.bounding_boxes.append(bbox)

                                # Collect tracker_msgs
                                self.cv_image = cv2.imread(os.path.abspath(os.path.join(self.frames_dir, uri)))
                                detections, tracks = self.tracking_callback(bboxes)
                                # Append to frame annotations table
                                self.append_frame_trks(img_timestamp, tracks, detections)
                                # Append to track annotations table
                                self.append_trk_table(img_timestamp, tracks)
                                # Debug prints to monitor data extraction
                                # print("FRAME TIMESTAMP:",img_timestamp)
                                # print("DETECTION TIMESTAMP:",int(msg.header.stamp.to_nsec()))
                                # print('IMG PATH:',self.frame_imgs[img_timestamp]['url'])

                        self.save_cache(bag_file)
                    msg_count = 0
                except ROSBagException as e:
                    print("\n", bag_file, "failed to open, skipping!")
                    print(str(e), '\n')
                    continue

            ## Generate JSON Object in BDD Format ##
            self.generate_flickering_visualizations()

            ## Generate Ground Truth Annotations in BDD Format ##
            self.generate_annotations()

            # Export to Scalabel
            self.export(bag_file)

            # Print Summary
            print("\nFrames Extracted:", len(self.flickering_frame_imgs.keys()))
            print("=================================================================\n\n")