Пример #1
0
class ObjectDetectionNode(DTROS):
    """Perception node that publishes whether the detector sees duckies.

    Each frame received on ``~image/compressed`` is decoded, color-balanced
    with the latest ``~thresholds`` message (once one has arrived), resized
    to 224x224 and passed to the model wrapper. The verdict of
    :py:meth:`det2bool` is published as a ``BoolStamped`` on
    ``~duckie_detected``.
    """

    def __init__(self, node_name):
        """Create the node state, its publisher/subscribers and the model.

        Args:
            node_name (:obj:`str`): name handed to the DTROS parent class.
        """
        # Initialize the DTROS parent class
        super(ObjectDetectionNode,
              self).__init__(node_name=node_name,
                             node_type=NodeType.PERCEPTION)

        # State read by the callbacks. It is created BEFORE the subscribers
        # because rospy may invoke a callback as soon as its subscriber
        # exists; previously an early message raised AttributeError.
        self.initialized = False
        self.ai_thresholds_received = False
        self.anti_instagram_thresholds = dict()
        self.ai = AntiInstagram()
        self.bridge = CvBridge()

        # Construct publishers
        self.pub_obj_dets = rospy.Publisher("~duckie_detected",
                                            BoolStamped,
                                            queue_size=1,
                                            dt_topic_type=TopicType.PERCEPTION)

        # Construct subscribers
        self.sub_image = rospy.Subscriber("~image/compressed",
                                          CompressedImage,
                                          self.image_cb,
                                          buff_size=10000000,
                                          queue_size=1)

        self.sub_thresholds = rospy.Subscriber("~thresholds",
                                               AntiInstagramThresholds,
                                               self.thresholds_cb,
                                               queue_size=1)

        # The model path is the package path with the '~model_file'
        # parameter appended verbatim (no separator is inserted).
        model_file = rospy.get_param('~model_file', '.')
        rospack = rospkg.RosPack()
        model_file_absolute = rospack.get_path('object_detection') + model_file
        self.model_wrapper = Wrapper(model_file_absolute)

        # Only now may image_cb start processing frames.
        self.initialized = True
        self.log("Initialized!")

    def thresholds_cb(self, thresh_msg):
        """Store the latest color-balance thresholds.

        Args:
            thresh_msg (:obj:`AntiInstagramThresholds`): message whose
                ``low`` and ``high`` fields are kept for color correction.
        """
        self.anti_instagram_thresholds["lower"] = thresh_msg.low
        self.anti_instagram_thresholds["higher"] = thresh_msg.high
        self.ai_thresholds_received = True

    def image_cb(self, image_msg):
        """Run the detector on one frame and publish the verdict.

        Args:
            image_msg (:obj:`CompressedImage`): incoming camera frame.
        """
        if not self.initialized:
            return

        # TODO to get better hz, you might want to only call your wrapper's
        # predict function once every 4-5 images. That way the model is not
        # called again for two practically identical images. Experiment to
        # find a good number of skipped images.

        # Decode from compressed image with OpenCV
        try:
            image = self.bridge.compressed_imgmsg_to_cv2(image_msg)
        except ValueError as e:
            self.logerr('Could not decode image: %s' % e)
            return

        # Perform color correction
        if self.ai_thresholds_received:
            image = self.ai.apply_color_balance(
                self.anti_instagram_thresholds["lower"],
                self.anti_instagram_thresholds["higher"], image)

        image = cv2.resize(image, (224, 224))
        bboxes, classes, scores = self.model_wrapper.predict(image)

        msg = BoolStamped()
        msg.header = image_msg.header
        # [0] because our batch size given to the wrapper is 1
        msg.data = self.det2bool(bboxes[0], classes[0])

        self.pub_obj_dets.publish(msg)

    def det2bool(self, bboxes, classes):
        """Reduce the detections of one frame to a single boolean.

        This is still the placeholder rule: it only counts the boxes and
        ignores ``classes``.

        Args:
            bboxes: sequence of ``(x1, y1, x2, y2)`` boxes, or ``None``.
            classes: class label per box (currently unused).

        Returns:
            :obj:`bool`: ``True`` when more than one box was detected,
            ``False`` otherwise (including when ``bboxes`` is ``None``).
        """
        # TODO filter the predictions: the environment here is a bit
        # different from the data collection environment, and the model
        # might output a bit of noise, e.g. predictions with x1=223.4 and
        # x2=224, which make no sense. Those should be removed.

        # TODO also filter detections which are outside of the road, or too
        # far away from the bot. Only return True when there is a duckie in
        # front of the bot that it has to avoid. A good heuristic would be
        # "if the centroid of the bounding box is in the center of the
        # image, assume the duckie is in the road" and "if the bounding
        # box's area is more than X pixels, assume the duckie is close".

        # No detections at all must not crash the image callback.
        if bboxes is None:
            return False

        # This is a dummy solution, to be replaced by the filters above.
        return len(bboxes) > 1
Пример #2
0
class ObjectDetectionNode(DTROS):
    """Object-detection node that flags obstacles in the own and left lane.

    Detections of class ``1`` are projected onto the ground plane using the
    camera calibration and the latest lane pose. The node keeps two flags,
    ``obstacle_right_lane`` and ``obstacle_left_lane``, and publishes the
    former as a ``BoolStamped`` on ``~duckie_detected`` for every frame.
    """

    def __init__(self, node_name):
        """Create the node state, its publisher/subscribers and the model.

        Args:
            node_name (:obj:`str`): name handed to the DTROS parent class.
        """
        # Initialize the DTROS parent class
        super(ObjectDetectionNode,
              self).__init__(node_name=node_name,
                             node_type=NodeType.PERCEPTION)

        # State read by the callbacks. It is created BEFORE the subscribers
        # because rospy may invoke a callback as soon as its subscriber
        # exists; previously an early message raised AttributeError.
        self.initialized = False
        self.ai_thresholds_received = False
        self.anti_instagram_thresholds = dict()
        self.ai = AntiInstagram()
        self.bridge = CvBridge()

        self.ground_projector = None
        self.rectifier = None
        self.camera_info_received = False
        self.pose_msg = None  # set by cbLanePoses once a lane pose arrives
        self.image_count = 0
        self.obstacle_left_lane = False
        self.obstacle_right_lane = False

        self.homography = self.load_extrinsics()
        self.log(str(self.homography))
        self.lane_width = rospy.get_param('~lanewidth', None)
        self.safe_distance = rospy.get_param('~safe_distance', None)

        # Construct publishers
        self.pub_obj_dets = rospy.Publisher("~duckie_detected",
                                            BoolStamped,
                                            queue_size=1,
                                            dt_topic_type=TopicType.PERCEPTION)

        # Construct subscribers
        self.sub_image = rospy.Subscriber("~image/compressed",
                                          CompressedImage,
                                          self.image_cb,
                                          buff_size=10000000,
                                          queue_size=1)

        self.sub_thresholds = rospy.Subscriber("~thresholds",
                                               AntiInstagramThresholds,
                                               self.thresholds_cb,
                                               queue_size=1)

        self.sub_camera_info = rospy.Subscriber(
            f"/{os.environ['VEHICLE_NAME']}/camera_node/camera_info",
            CameraInfo,
            self.cb_camera_info,
            queue_size=1)

        self.sub_lane_reading = rospy.Subscriber(
            f"/{os.environ['VEHICLE_NAME']}/lane_filter_node/lane_pose",
            LanePose,
            self.cbLanePoses,
            queue_size=1)

        # The model path is the package path with the '~model_file'
        # parameter appended verbatim (no separator is inserted).
        model_file = rospy.get_param('~model_file', '.')
        rospack = rospkg.RosPack()
        model_file_absolute = rospack.get_path('object_detection') + model_file
        self.model_wrapper = Wrapper(model_file_absolute)

        # Only now may image_cb start processing frames.
        self.initialized = True
        self.log("Initialized!")

    def thresholds_cb(self, thresh_msg):
        """Store the latest color-balance thresholds.

        Args:
            thresh_msg (:obj:`AntiInstagramThresholds`): message whose
                ``low`` and ``high`` fields are kept for color correction.
        """
        self.anti_instagram_thresholds["lower"] = thresh_msg.low
        self.anti_instagram_thresholds["higher"] = thresh_msg.high
        self.ai_thresholds_received = True

    def cb_camera_info(self, msg):
        """
        Initializes a :py:class:`image_processing.GroundProjectionGeometry` object and a
        :py:class:`image_processing.Rectify` object for image rectification

        Only the first message is used; later ones are ignored.

        Args:
            msg (:obj:`sensor_msgs.msg.CameraInfo`): Intrinsic properties of the camera.

        """
        if self.camera_info_received:
            return

        self.rectifier = Rectify(msg)
        self.ground_projector = GroundProjectionGeometry(
            im_width=msg.width,
            im_height=msg.height,
            homography=np.array(self.homography).reshape((3, 3)))
        self.im_width = msg.width
        self.im_height = msg.height
        self.camera_info_received = True

    def cbLanePoses(self, input_pose_msg):
        """Store the latest lane pose for use by :py:meth:`det2bool`.

        Args:
            input_pose_msg (:obj:`LanePose`): Message containing information about the current lane pose.
        """
        self.pose_msg = input_pose_msg

    def image_cb(self, image_msg):
        """Process one camera frame and publish the obstacle flag.

        The model runs on one frame out of three; on the other frames the
        flags computed for the last processed frame are published again.

        Args:
            image_msg (:obj:`CompressedImage`): incoming camera frame.
        """
        if not self.initialized:
            return

        # Decode from compressed image with OpenCV
        try:
            image = self.bridge.compressed_imgmsg_to_cv2(image_msg)
        except ValueError as e:
            self.logerr('Could not decode image: %s' % e)
            return

        # Perform color correction
        if self.ai_thresholds_received:
            image = self.ai.apply_color_balance(
                self.anti_instagram_thresholds["lower"],
                self.anti_instagram_thresholds["higher"], image)

        image = cv2.resize(image, (224, 224))

        # Advance the counter on EVERY frame. It used to be advanced only
        # when it was already non-zero, so it stayed at 0 forever and the
        # model ran on every single frame.
        run_model = self.image_count == 0
        self.image_count = (self.image_count + 1) % 3

        if run_model:
            bboxes, classes, scores = self.model_wrapper.predict(image)
            im_boxed = self.plotWithBoundingBoxes(image, bboxes[0], classes[0],
                                                  scores[0])
            # NOTE(review): cv2.imshow presumably needs a display on the
            # machine running this node - confirm for headless robots.
            cv2.imshow('detected objects', im_boxed)
            cv2.waitKey(1)
            # [0] because our batch size given to the wrapper is 1
            self.det2bool(bboxes[0], classes[0])

        msg = BoolStamped()
        msg.header = image_msg.header
        # Only an obstacle in our own lane is reported. Overtaking (own
        # lane blocked, left lane free) is not implemented yet; that is
        # where obstacle_left_lane would be consulted.
        msg.data = bool(self.obstacle_right_lane)

        self.pub_obj_dets.publish(msg)

    def det2bool(self, bboxes, classes):
        """Update the obstacle flags from the detections of one frame.

        Both flags are reset first. For every box of class ``1`` that is at
        least 2 pixels wide and high, the bottom-center of the box is
        rectified, projected onto the ground and converted to a lateral
        offset using the latest lane pose. Boxes within half a lane width
        and ``safe_distance`` set ``obstacle_right_lane``; boxes further
        sideways and within ``1.5 * safe_distance`` set
        ``obstacle_left_lane``.

        Nothing is flagged while the camera info, the lane pose or the
        ``~lanewidth`` / ``~safe_distance`` parameters are unavailable.

        Args:
            bboxes: sequence of ``(x1, y1, x2, y2)`` boxes, or ``None``.
            classes: class label per box, or ``None``.
        """
        self.obstacle_right_lane = False
        self.obstacle_left_lane = False

        if bboxes is None or classes is None:
            return

        # These are filled asynchronously (or may be unset parameters);
        # without them the projection below would dereference None.
        prerequisites = (self.rectifier, self.ground_projector, self.pose_msg,
                         self.lane_width, self.safe_distance)
        if any(item is None for item in prerequisites):
            return

        half_lane = self.lane_width / 2
        phi = self.pose_msg.phi
        lateral_offset = self.pose_msg.d

        for (x1, y1, x2, y2), label in zip(bboxes, classes):
            if label != 1:
                continue
            # Drop degenerate boxes (less than 2 px wide or high).
            if not ((x2 - x1 >= 2) and (y2 - y1 >= 2)):
                continue

            # NOTE(review): the box coordinates refer to the 224x224 model
            # input, while the projector was built with the camera
            # resolution (msg.width x msg.height) - confirm whether they
            # need rescaling before rectification.
            low_center = Point((x1 + x2) / 2, y2)
            rect_pixel = self.rectifier.rectify_point(low_center)
            ground_point = self.ground_projector.pixel2ground(rect_pixel)

            duckie_lane_pose = (
                np.cos(phi) * (ground_point.y + lateral_offset) +
                np.sin(phi) * ground_point.x)
            dist = np.sqrt(ground_point.x**2 + ground_point.y**2)

            if np.abs(duckie_lane_pose) <= half_lane:  # in our lane
                if dist <= self.safe_distance:
                    self.obstacle_right_lane = True
            elif np.abs(duckie_lane_pose) > half_lane:  # in left lane
                if dist <= self.safe_distance * 1.5:
                    self.obstacle_left_lane = True

    def plotWithBoundingBoxes(self, seg_im, boxes, labels, scores):
        """Draw boxes with a ``label : score`` caption onto ``seg_im``.

        The image is modified in place.

        Args:
            seg_im: image to draw on.
            boxes: sequence of ``(x1, y1, x2, y2)`` boxes, or ``None``.
            labels: class label per box, or ``None``.
            scores: score per box.

        Returns:
            The image that was passed in, with the drawings applied.
        """
        if boxes is None or labels is None:
            return seg_im

        white = (255, 255, 255)
        for box, label, score in zip(boxes, labels, scores):
            # The drawing calls get plain ints; the boxes come straight
            # from the model and may hold floats.
            x1, y1, x2, y2 = (int(value) for value in box[:4])
            cv2.rectangle(seg_im, (x1, y1), (x2, y2), white, 1)
            # Small filled patch behind the caption.
            cv2.rectangle(seg_im, (x1, y1), (x1 + 20, y1 - 6), white,
                          cv2.FILLED)
            cv2.putText(seg_im, f"{label} : {score}", (x1, y1),
                        cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
        return seg_im

    def load_extrinsics(self):
        """
        Loads the homography matrix from the extrinsic calibration file.

        Returns:
            :obj:`numpy array`: the loaded homography matrix

        Raises:
            FileNotFoundError: neither the robot-specific nor the default
                calibration file exists.
            yaml.YAMLError: the calibration file cannot be parsed.

        """
        # load extrinsic calibration
        cali_file_folder = '/data/config/calibrations/camera_extrinsic/'
        cali_file = cali_file_folder + rospy.get_namespace().strip(
            "/") + ".yaml"

        # Locate calibration yaml file or use the default otherwise
        if not os.path.isfile(cali_file):
            self.log(
                "Can't find calibration file: %s.\n Using default calibration instead."
                % cali_file, 'warn')
            cali_file = (cali_file_folder + "default.yaml")

        # Shutdown if no calibration file is found
        if not os.path.isfile(cali_file):
            msg = 'Found no calibration file ... aborting'
            self.log(msg, 'err')
            rospy.signal_shutdown(msg)
            # signal_shutdown does not stop this function; fail explicitly
            # (open() below would raise the same exception type anyway).
            raise FileNotFoundError(msg)

        try:
            with open(cali_file, 'r') as stream:
                # SECURITY: this used to be yaml.load(stream) without a
                # Loader, which can construct arbitrary objects and is
                # rejected by recent PyYAML. safe_load is used because the
                # file is expected to hold plain data only.
                calib_data = yaml.safe_load(stream)
        except yaml.YAMLError:
            msg = 'Error in parsing calibration file %s ... aborting' % cali_file
            self.log(msg, 'err')
            rospy.signal_shutdown(msg)
            # Re-raise instead of falling through to an unbound calib_data.
            raise

        return calib_data['homography']
class ObjectDetectionNode(DTROS):
    """Perception node that publishes whether a relevant object is detected.

    Frames from ``~image/compressed`` are color-balanced (once thresholds
    have arrived), resized to 416x416 and passed to the model wrapper. The
    verdict of :py:meth:`det2bool` is published as a ``BoolStamped`` on
    ``~duckie_detected``; when the ``~debug`` parameter is set, an annotated
    image is also published on ``~object_detections_img``.
    """

    def __init__(self, node_name):
        """Create the node state, its publishers/subscribers and the model.

        Args:
            node_name (:obj:`str`): name handed to the DTROS parent class.
        """
        # Initialize the DTROS parent class
        super(ObjectDetectionNode, self).__init__(
            node_name=node_name,
            node_type=NodeType.PERCEPTION
        )
        self.initialized = False
        self.log("Initializing!")

        # State read by the callbacks. It is created BEFORE the subscribers
        # because rospy may invoke a callback as soon as its subscriber
        # exists; previously an early thresholds message raised
        # AttributeError.
        self.ai_thresholds_received = False
        self.anti_instagram_thresholds = dict()
        self.ai = AntiInstagram()
        self.bridge = CvBridge()
        self.frame_id = 0
        self._debug = rospy.get_param("~debug", False)

        # Construct publishers
        self.pub_obj_dets = rospy.Publisher(
            "~duckie_detected",
            BoolStamped,
            queue_size=1,
            dt_topic_type=TopicType.PERCEPTION
        )

        self.pub_detections_image = rospy.Publisher(
            "~object_detections_img",
            Image,
            queue_size=1,
            dt_topic_type=TopicType.DEBUG
        )

        # Construct subscribers
        self.sub_image = rospy.Subscriber(
            "~image/compressed",
            CompressedImage,
            self.image_cb,
            buff_size=10000000,
            queue_size=1
        )

        self.sub_thresholds = rospy.Subscriber(
            "~thresholds",
            AntiInstagramThresholds,
            self.thresholds_cb,
            queue_size=1
        )

        self.veh = rospy.get_namespace().strip("/")
        aido_eval = rospy.get_param("~AIDO_eval", False)
        self.log(f"AIDO EVAL VAR: {aido_eval}")
        self.log("Starting model loading!")
        self.model_wrapper = Wrapper(aido_eval)
        self.log("Finished model loading!")

        # Only now may image_cb start processing frames.
        self.initialized = True
        self.log("Initialized!")

    def thresholds_cb(self, thresh_msg):
        """Store the latest color-balance thresholds.

        Args:
            thresh_msg (:obj:`AntiInstagramThresholds`): message whose
                ``low`` and ``high`` fields are kept for color correction.
        """
        self.anti_instagram_thresholds["lower"] = thresh_msg.low
        self.anti_instagram_thresholds["higher"] = thresh_msg.high
        self.ai_thresholds_received = True

    def image_cb(self, image_msg):
        """Run the detector on one frame and publish the verdict.

        One frame is processed, then ``NUMBER_FRAMES_SKIPPED()`` frames are
        dropped, and so on.

        Args:
            image_msg (:obj:`CompressedImage`): incoming camera frame.
        """
        if not self.initialized:
            return

        # Advance the counter on EVERY frame, then decide. The counter used
        # to be advanced only after the early return, so once it became
        # non-zero every later frame was dropped forever.
        position = self.frame_id
        self.frame_id = (self.frame_id + 1) % (1 + NUMBER_FRAMES_SKIPPED())
        if position != 0:
            return

        # Decode from compressed image with OpenCV
        try:
            image = self.bridge.compressed_imgmsg_to_cv2(image_msg)
        except ValueError as e:
            self.logerr('Could not decode image: %s' % e)
            return

        # Perform color correction
        if self.ai_thresholds_received:
            image = self.ai.apply_color_balance(
                self.anti_instagram_thresholds["lower"],
                self.anti_instagram_thresholds["higher"],
                image
            )

        image = cv2.resize(image, (416, 416))
        bboxes, classes, scores = self.model_wrapper.predict(image)

        msg = BoolStamped()
        msg.header = image_msg.header
        msg.data = self.det2bool(bboxes, classes, scores)

        self.pub_obj_dets.publish(msg)

        if self._debug:
            self._publish_detections_image(image, bboxes, classes)

    def _publish_detections_image(self, image, bboxes, classes):
        """Draw every detection on ``image`` and publish it for debugging."""
        colors = {0: (0, 255, 255), 1: (0, 165, 255), 2: (0, 250, 0), 3: (0, 0, 255)}
        names = {0: "duckie", 1: "cone", 2: "truck", 3: "bus"}
        font = cv2.FONT_HERSHEY_SIMPLEX
        for clas, box in zip(classes, bboxes):
            # Plain Python ints: some OpenCV builds reject numpy integer
            # types as point coordinates.
            pt1 = (int(box[0]), int(box[1]))
            pt2 = (int(box[2]), int(box[3]))
            color = colors[clas.item()]
            name = names[clas.item()]
            image = cv2.rectangle(image, pt1, pt2, color, 2)
            # Caption 20 px below the box corner, clamped to the 416 px image.
            text_location = (pt1[0], min(416, pt1[1] + 20))
            image = cv2.putText(image, name, text_location, font, 1, color, thickness=3)
        obj_det_img = self.bridge.cv2_to_imgmsg(image, encoding="bgr8")
        self.pub_detections_image.publish(obj_det_img)

    def det2bool(self, bboxes, classes, scores):
        """Tell whether any detection passes all three filters.

        A detection counts when ``filter_by_bboxes``, ``filter_by_classes``
        and ``filter_by_scores`` all accept its box, class and score.

        Args:
            bboxes: bounding box per detection.
            classes: class label per detection.
            scores: confidence score per detection.

        Returns:
            :obj:`bool`: ``True`` if at least one detection passes every
            filter, ``False`` otherwise.
        """
        # NOTE(review): the filters are assumed to return plain truth
        # values - confirm against their definitions.
        return any(
            filter_by_bboxes(box)
            and filter_by_classes(label)
            and filter_by_scores(score)
            for box, label, score in zip(bboxes, classes, scores)
        )
class ObjectDetectionNode(DTROS):
    """Perception node that publishes whether a duckie blocks the way.

    Frames from ``~image/compressed`` are color-balanced (once thresholds
    have arrived), resized to 224x224, converted to RGB and passed to the
    model wrapper. An annotated image is published on
    ``~debug/segmented_image/compressed`` and the verdict of
    :py:meth:`det2bool` as a ``BoolStamped`` on ``~duckie_detected``.
    """

    def __init__(self, node_name):
        """Create the node state, its publishers/subscribers and the model.

        Args:
            node_name (:obj:`str`): name handed to the DTROS parent class.
        """
        # Initialize the DTROS parent class
        super(ObjectDetectionNode,
              self).__init__(node_name=node_name,
                             node_type=NodeType.PERCEPTION)

        # State read by the callbacks. It is created BEFORE the subscribers
        # because rospy may invoke a callback as soon as its subscriber
        # exists; previously an early message raised AttributeError.
        self.initialized = False
        self.ai_thresholds_received = False
        self.anti_instagram_thresholds = dict()
        self.ai = AntiInstagram()
        self.bridge = CvBridge()

        # Construct publishers
        self.pub_obj_dets = rospy.Publisher("~duckie_detected",
                                            BoolStamped,
                                            queue_size=1,
                                            dt_topic_type=TopicType.PERCEPTION)

        self.pub_segmented_img = rospy.Publisher(
            "~debug/segmented_image/compressed",
            CompressedImage,
            queue_size=1,
            dt_topic_type=TopicType.DEBUG)

        # Construct subscribers
        self.sub_image = rospy.Subscriber("~image/compressed",
                                          CompressedImage,
                                          self.image_cb_det,
                                          buff_size=10000000,
                                          queue_size=1)

        self.sub_thresholds = rospy.Subscriber("~thresholds",
                                               AntiInstagramThresholds,
                                               self.thresholds_cb,
                                               queue_size=1)

        # The model path is the package path with the '~model_file'
        # parameter appended verbatim (no separator is inserted).
        model_file = rospy.get_param('~model_file', '.')
        rospack = rospkg.RosPack()
        model_file_absolute = rospack.get_path('object_detection') + model_file
        self.model_wrapper = Wrapper(model_file_absolute)

        self.homography = self.load_extrinsics()
        homography = np.array(self.homography).reshape((3, 3))
        self.gpg = GroundProjectionGeometry(160, 120, homography)

        # Only now may image_cb_det start processing frames.
        self.initialized = True
        self.log("Initialized!")

    def thresholds_cb(self, thresh_msg):
        """Store the latest color-balance thresholds.

        Args:
            thresh_msg (:obj:`AntiInstagramThresholds`): message whose
                ``low`` and ``high`` fields are kept for color correction.
        """
        self.anti_instagram_thresholds["lower"] = thresh_msg.low
        self.anti_instagram_thresholds["higher"] = thresh_msg.high
        self.ai_thresholds_received = True

    def image_cb_det(self, image_msg):
        """Run the detector on one frame; publish debug image and verdict.

        Args:
            image_msg (:obj:`CompressedImage`): incoming camera frame.
        """
        if not self.initialized:
            return

        # Decode from compressed image with OpenCV
        try:
            image = self.bridge.compressed_imgmsg_to_cv2(image_msg)
        except ValueError as e:
            self.logerr('Could not decode image: %s' % e)
            return

        # Perform color correction
        if self.ai_thresholds_received:
            image = self.ai.apply_color_balance(
                self.anti_instagram_thresholds["lower"],
                self.anti_instagram_thresholds["higher"], image)

        img_reg = cv2.resize(image, (224, 224))
        img_rgb = cv2.cvtColor(img_reg, cv2.COLOR_BGR2RGB)
        boxes, classes, scores = self.model_wrapper.predict(img_rgb)
        # [0] because our batch size given to the wrapper is 1
        boxes, classes, scores = boxes[0], classes[0], scores[0]

        # boxes is None when the model found nothing to draw.
        if boxes is not None:
            img_w_boxes = add_boxes(img_reg, boxes, classes, scores)
        else:
            img_w_boxes = img_reg
        detection_img = self.bridge.cv2_to_compressed_imgmsg(img_w_boxes)
        detection_img.header.stamp = image_msg.header.stamp
        self.pub_segmented_img.publish(detection_img)

        msg = BoolStamped()
        msg.header = image_msg.header
        msg.data = self.det2bool(boxes, classes)

        self.pub_obj_dets.publish(msg)

    def det2bool(self, boxes, classes, image_size=224, min_area=700):
        """Tell whether a duckie (class ``1``) is close and in front.

        A box counts when its centroid lies in the bottom half of the
        image, in the horizontally central half of the image, and its area
        is at least ``min_area`` pixels.

        The previous implementation tested the x coordinate against the
        "bottom half" bounds and the y coordinate against the "horizontal
        middle" bounds, i.e. with the axes swapped relative to its own
        comments.

        Args:
            boxes: sequence of ``(x1, y1, x2, y2)`` boxes, or ``None``.
            classes: class label per box, or ``None``.
            image_size: side length in pixels of the square model input.
            min_area: minimum box area in pixels to count as close.

        Returns:
            :obj:`bool`: ``True`` if at least one box qualifies.
        """
        if boxes is None or classes is None:
            return False

        for (x1, y1, x2, y2), label in zip(boxes, classes):
            # everything except duckie is not important for now
            if label != 1:
                continue

            centroid_x = 0.5 * (x1 + x2)
            centroid_y = 0.5 * (y1 + y2)

            # Image y grows downwards, so the bottom half is the upper
            # half of the y range.
            in_bottom_half = 0.5 * image_size <= centroid_y <= image_size
            in_center_band = (0.25 * image_size <= centroid_x <=
                              0.75 * image_size)
            large_enough = abs((x2 - x1) * (y2 - y1)) >= min_area

            if in_bottom_half and in_center_band and large_enough:
                print("duckie detected")
                return True

        return False

    def load_extrinsics(self):
        """
        Loads the homography matrix from the extrinsic calibration file.

        Returns:
            :obj:`numpy array`: the loaded homography matrix

        Raises:
            FileNotFoundError: neither the robot-specific nor the default
                calibration file exists.
            yaml.YAMLError: the calibration file cannot be parsed.
        """
        # load extrinsic calibration
        cali_file_folder = '/data/config/calibrations/camera_extrinsic/'
        cali_file = cali_file_folder + rospy.get_namespace().strip(
            "/") + ".yaml"

        # Locate calibration yaml file or use the default otherwise
        if not os.path.isfile(cali_file):
            self.log(
                "Can't find calibration file: %s.\n Using default calibration instead."
                % cali_file, 'warn')
            cali_file = (cali_file_folder + "default.yaml")

        # Shutdown if no calibration file is found
        if not os.path.isfile(cali_file):
            msg = 'Found no calibration file ... aborting'
            self.log(msg, 'err')
            rospy.signal_shutdown(msg)
            # signal_shutdown does not stop this function; fail explicitly
            # (open() below would raise the same exception type anyway).
            raise FileNotFoundError(msg)

        try:
            with open(cali_file, 'r') as stream:
                # SECURITY: this used to be yaml.load(stream) without a
                # Loader, which can construct arbitrary objects and is
                # rejected by recent PyYAML. safe_load is used because the
                # file is expected to hold plain data only.
                calib_data = yaml.safe_load(stream)
        except yaml.YAMLError:
            msg = 'Error in parsing calibration file %s ... aborting' % cali_file
            self.log(msg, 'err')
            rospy.signal_shutdown(msg)
            # Re-raise instead of falling through to an unbound calib_data.
            raise

        return calib_data['homography']
class ObjectDetectionNode(DTROS):
    """Perception node that publishes whether a relevant object is detected.

    Frames from ``~image/compressed`` are color-balanced (once thresholds
    have arrived), resized to 416x416 and passed to the model wrapper. The
    verdict of :py:meth:`det2bool` is published as a ``BoolStamped`` on
    ``~duckie_detected``.
    """

    def __init__(self, node_name):
        """Create the node state, its publisher/subscribers and the model.

        Args:
            node_name (:obj:`str`): name handed to the DTROS parent class.
        """
        # Initialize the DTROS parent class
        super(ObjectDetectionNode,
              self).__init__(node_name=node_name,
                             node_type=NodeType.PERCEPTION)

        # State read by the callbacks. It is created BEFORE the subscribers
        # because rospy may invoke a callback as soon as its subscriber
        # exists; previously an early thresholds message raised
        # AttributeError.
        self.initialized = False
        self.ai_thresholds_received = False
        self.anti_instagram_thresholds = dict()
        self.ai = AntiInstagram()
        self.bridge = CvBridge()
        self.frame_id = 0

        # Construct publishers
        self.pub_obj_dets = rospy.Publisher("~duckie_detected",
                                            BoolStamped,
                                            queue_size=1,
                                            dt_topic_type=TopicType.PERCEPTION)

        # Construct subscribers
        self.sub_image = rospy.Subscriber("~image/compressed",
                                          CompressedImage,
                                          self.image_cb,
                                          buff_size=10000000,
                                          queue_size=1)

        self.sub_thresholds = rospy.Subscriber("~thresholds",
                                               AntiInstagramThresholds,
                                               self.thresholds_cb,
                                               queue_size=1)

        # The model path is the package path with the '~model_file'
        # parameter appended verbatim (no separator is inserted).
        model_file = rospy.get_param('~model_file', '.')
        rospack = rospkg.RosPack()
        model_file_absolute = rospack.get_path('object_detection') + model_file
        self.model_wrapper = Wrapper(model_file_absolute)

        # Only now may image_cb start processing frames.
        self.initialized = True
        self.log("Initialized!")

    def thresholds_cb(self, thresh_msg):
        """Store the latest color-balance thresholds.

        Args:
            thresh_msg (:obj:`AntiInstagramThresholds`): message whose
                ``low`` and ``high`` fields are kept for color correction.
        """
        self.anti_instagram_thresholds["lower"] = thresh_msg.low
        self.anti_instagram_thresholds["higher"] = thresh_msg.high
        self.ai_thresholds_received = True

    def image_cb(self, image_msg):
        """Run the detector on one frame and publish the verdict.

        One frame is processed, then ``NUMBER_FRAMES_SKIPPED()`` frames are
        dropped, and so on.

        Args:
            image_msg (:obj:`CompressedImage`): incoming camera frame.
        """
        if not self.initialized:
            return

        from integration import NUMBER_FRAMES_SKIPPED

        # Advance the counter on EVERY frame, then decide. The counter used
        # to be advanced only after the early return, so once it became
        # non-zero every later frame was dropped forever.
        position = self.frame_id
        self.frame_id = (self.frame_id + 1) % (1 + NUMBER_FRAMES_SKIPPED())
        if position != 0:
            return

        # Decode from compressed image with OpenCV
        try:
            image = self.bridge.compressed_imgmsg_to_cv2(image_msg)
        except ValueError as e:
            self.logerr('Could not decode image: %s' % e)
            return

        # Perform color correction
        if self.ai_thresholds_received:
            image = self.ai.apply_color_balance(
                self.anti_instagram_thresholds["lower"],
                self.anti_instagram_thresholds["higher"], image)

        image = cv2.resize(image, (416, 416))
        bboxes, classes, scores = self.model_wrapper.predict(image)

        msg = BoolStamped()
        msg.header = image_msg.header
        msg.data = self.det2bool(bboxes, classes, scores)

        self.pub_obj_dets.publish(msg)

    def det2bool(self, bboxes, classes, scores):
        """Tell whether any detection passes all three filters.

        A detection counts when ``filter_by_bboxes``, ``filter_by_classes``
        and ``filter_by_scores`` all accept its box, class and score.

        The previous implementation built sets from the tuple returned by
        ``ndarray.nonzero()``, which raises ``TypeError`` (arrays are not
        hashable), and returned ``None`` when nothing matched.

        Args:
            bboxes: bounding box per detection.
            classes: class label per detection.
            scores: confidence score per detection.

        Returns:
            :obj:`bool`: ``True`` if at least one detection passes every
            filter, ``False`` otherwise.
        """
        from integration import filter_by_bboxes
        from integration import filter_by_classes
        from integration import filter_by_scores

        print(f"Before filtering: {len(bboxes)} detections")

        # Indices of the detections accepted by every filter.
        # NOTE(review): the filters are assumed to return plain truth
        # values - confirm against the integration module.
        kept_ids = [
            idx
            for idx, (box, label, score) in enumerate(
                zip(bboxes, classes, scores))
            if filter_by_bboxes(box)
            and filter_by_classes(label)
            and filter_by_scores(score)
        ]

        print(f"After filtering: {len(kept_ids)} detections")

        return len(kept_ids) > 0
Пример #6
0
class ObjectDetectionNode(DTROS):
    """Perception node that publishes whether a duckie is in front.

    Each frame received on ``~image/compressed`` is decoded, color-balanced
    with the latest ``~thresholds`` message (once one has arrived), resized
    to 224x224 and passed to the model wrapper. The verdict of
    :py:meth:`det2bool` is published as a ``BoolStamped`` on
    ``~duckie_detected``.
    """

    def __init__(self, node_name):
        """Create the node state, its publisher/subscribers and the model.

        Args:
            node_name (:obj:`str`): name handed to the DTROS parent class.
        """
        # Initialize the DTROS parent class
        super(ObjectDetectionNode,
              self).__init__(node_name=node_name,
                             node_type=NodeType.PERCEPTION)

        # State read by the callbacks. It is created BEFORE the subscribers
        # because rospy may invoke a callback as soon as its subscriber
        # exists; previously an early message raised AttributeError.
        self.initialized = False
        self.ai_thresholds_received = False
        self.anti_instagram_thresholds = dict()
        self.ai = AntiInstagram()
        self.bridge = CvBridge()

        # Construct publishers
        self.pub_obj_dets = rospy.Publisher("~duckie_detected",
                                            BoolStamped,
                                            queue_size=1,
                                            dt_topic_type=TopicType.PERCEPTION)

        # Construct subscribers
        self.sub_image = rospy.Subscriber("~image/compressed",
                                          CompressedImage,
                                          self.image_cb,
                                          buff_size=10000000,
                                          queue_size=1)

        self.sub_thresholds = rospy.Subscriber("~thresholds",
                                               AntiInstagramThresholds,
                                               self.thresholds_cb,
                                               queue_size=1)

        # The model path is the package path with the '~model_file'
        # parameter appended verbatim (no separator is inserted).
        model_file = rospy.get_param('~model_file', '.')
        rospack = rospkg.RosPack()
        model_file_absolute = rospack.get_path('object_detection') + model_file
        self.model_wrapper = Wrapper(model_file_absolute)

        # Only now may image_cb start processing frames.
        self.initialized = True
        self.log("Initialized!")

    def thresholds_cb(self, thresh_msg):
        """Store the latest color-balance thresholds.

        Args:
            thresh_msg (:obj:`AntiInstagramThresholds`): message whose
                ``low`` and ``high`` fields are kept for color correction.
        """
        self.anti_instagram_thresholds["lower"] = thresh_msg.low
        self.anti_instagram_thresholds["higher"] = thresh_msg.high
        self.ai_thresholds_received = True

    def image_cb(self, image_msg):
        """Run the detector on one frame and publish the verdict.

        Args:
            image_msg (:obj:`CompressedImage`): incoming camera frame.
        """
        if not self.initialized:
            return

        # TODO to get better hz, you might want to only call your wrapper's
        # predict function once every 4-5 images. That way the model is not
        # called again for two practically identical images. Experiment to
        # find a good number of skipped images.

        # Decode from compressed image with OpenCV
        try:
            image = self.bridge.compressed_imgmsg_to_cv2(image_msg)
        except ValueError as e:
            self.logerr('Could not decode image: %s' % e)
            return

        # Perform color correction
        if self.ai_thresholds_received:
            image = self.ai.apply_color_balance(
                self.anti_instagram_thresholds["lower"],
                self.anti_instagram_thresholds["higher"], image)

        image = cv2.resize(image, (224, 224))
        bboxes, classes, scores = self.model_wrapper.predict(image)

        msg = BoolStamped()
        msg.header = image_msg.header
        # [0] because our batch size given to the wrapper is 1
        msg.data = self.det2bool(bboxes[0], classes[0])

        self.pub_obj_dets.publish(msg)

    def midpoint(self, p1, p2):
        """Return the point halfway between ``p1`` and ``p2``.

        Args:
            p1: object with ``x`` and ``y`` attributes.
            p2: object with ``x`` and ``y`` attributes.

        Returns:
            :obj:`Point`: the midpoint of the two inputs.
        """
        return Point((p1.x + p2.x) / 2, (p1.y + p2.y) / 2)

    def det2bool(self, bboxes, classes):
        """Tell whether a duckie (class ``1``) is in front of the bot.

        A box counts when it is at least 2 pixels wide and high and its
        center lies strictly inside the window x in (80, 160), y in
        (100, 224) of the 224x224 model input.

        The previous implementation returned ``True`` as soon as the x
        test passed, because the ``return`` was indented outside the y
        test.

        Args:
            bboxes: sequence of ``(x1, y1, x2, y2)`` boxes, or ``None``.
            classes: class label per box, or ``None``.

        Returns:
            :obj:`bool`: ``True`` if at least one box qualifies.
        """
        middle_bounds_x = (80, 160)
        middle_bounds_y = (100, 224)

        if bboxes is None or classes is None:
            return False

        for (x1, y1, x2, y2), label in zip(bboxes, classes):
            # Degenerate boxes (under 2 px in either direction) are noise.
            if abs(x1 - x2) < 2 or abs(y1 - y2) < 2:
                continue
            if label != 1:
                continue

            # Center of the box, computed directly from the coordinates.
            middle_x = (x1 + x2) / 2
            middle_y = (y1 + y2) / 2

            inside_x = middle_bounds_x[0] < middle_x < middle_bounds_x[1]
            inside_y = middle_bounds_y[0] < middle_y < middle_bounds_y[1]
            if inside_x and inside_y:
                return True

        return False