示例#1
0
class TestPlugin(Plugin):
    def __init__(self, context):
        """
        TestPlugin class to evaluate the image_recognition_msgs interfaces
        :param context: QT context, aka parent
        """
        super(TestPlugin, self).__init__(context)

        # Widget setup
        self.setObjectName('Test Plugin')

        self._widget = QWidget()
        context.add_widget(self._widget)

        # Layout and attach to widget
        layout = QVBoxLayout()
        self._widget.setLayout(layout)

        # Drawing an ROI on this widget invokes image_roi_callback with the ROI image
        self._image_widget = ImageWidget(self._widget,
                                         self.image_roi_callback,
                                         clear_on_click=True)
        layout.addWidget(self._image_widget)

        # Input field
        grid_layout = QGridLayout()
        layout.addLayout(grid_layout)

        self._info = QLineEdit()
        self._info.setDisabled(True)
        self._info.setText(
            "Draw a rectangle on the screen to perform recognition of that ROI"
        )
        layout.addWidget(self._info)

        # Bridge for opencv conversion
        self.bridge = CvBridge()

        # Set subscriber and service to None
        self._sub = None
        self._srv = None

    def recognize_srv_call(self, roi_image):
        """
        Method that calls the Recognize.srv, draws the most probable label of every
        returned recognition in the image widget and lists all probabilities in a dialog
        :param roi_image: Selected roi_image by the user
        """
        try:
            result = self._srv(
                image=self.bridge.cv2_to_imgmsg(roi_image, "bgr8"))
        except Exception as e:
            warning_dialog("Service Exception", str(e))
            return

        # Debug output via rospy instead of a (Python-2-only) print statement
        rospy.logdebug("Recognize result: %s", result)

        for r in result.recognitions:
            distribution = r.categorical_distribution
            # Start from the 'unknown' hypothesis; a labelled category only wins
            # when its probability is strictly higher
            best = CategoryProbability(
                label="unknown",
                probability=distribution.unknown_probability)
            text_array = []
            for p in distribution.probabilities:
                text_array.append("%s: %.2f" % (p.label, p.probability))
                if p.probability > best.probability:
                    best = p

            self._image_widget.add_detection(r.roi.x_offset, r.roi.y_offset,
                                             r.roi.width, r.roi.height,
                                             best.label)

            if text_array:
                option_dialog(
                    "Classification results (Unknown probability=%.2f)" %
                    distribution.unknown_probability,
                    text_array)  # Show all results in a dropdown

    def get_face_properties_srv_call(self, roi_image):
        """
        Method that calls the GetFaceProperties.srv and shows the returned properties
        :param roi_image: Selected roi_image by the user
        """
        try:
            result = self._srv(face_image_array=[
                self.bridge.cv2_to_imgmsg(roi_image, "bgr8")
            ])
        except Exception as e:
            warning_dialog("Service Exception", str(e))
            return

        lines = []
        for properties in result.properties_array:
            lines.append(
                "- FaceProperties(gender=%s, age=%s, glasses=%s, mood=%s)" %
                ("male" if properties.gender == FaceProperties.MALE else "female",
                 properties.age,
                 "false" if properties.glasses == 0 else "true",
                 properties.mood))

        # One line per face (entries used to be glued together without a separator)
        info_dialog("Face Properties array", "\n".join(lines))

    def image_roi_callback(self, roi_image):
        """
        Callback triggered when the user has drawn an ROI on the image; dispatches
        to the handler matching the class of the selected service
        :param roi_image: The opencv image in the ROI
        """
        if self._srv is None:
            warning_dialog(
                "No service specified!",
                "Please first specify a service via the options button (top-right gear wheel)"
            )
            return

        if self._srv.service_class == Recognize:
            self.recognize_srv_call(roi_image)
        elif self._srv.service_class == GetFaceProperties:
            self.get_face_properties_srv_call(roi_image)
        else:
            warning_dialog("Unknown service class", "Service class is unknown!")

    def _image_callback(self, msg):
        """
        Sensor_msgs/Image callback
        :param msg: The image message
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
        except CvBridgeError as e:
            rospy.logerr(e)
            return

        self._image_widget.set_image(cv_image)

    def trigger_configuration(self):
        """
        Callback when the configuration button is clicked: lets the user pick an
        image topic and a service of one of the _SUPPORTED_SERVICES types
        """
        topic_name, ok = QInputDialog.getItem(
            self._widget, "Select topic name", "Topic name",
            rostopic.find_by_type('sensor_msgs/Image'))
        if ok:
            self._create_subscriber(topic_name)

        available_rosservices = []
        for s in rosservice.get_service_list():
            try:
                if rosservice.get_service_type(s) in _SUPPORTED_SERVICES:
                    available_rosservices.append(s)
            except Exception as e:
                # Best effort: a service that cannot be queried is just not offered
                rospy.logdebug("Could not query type of service '%s': %s", s, e)

        srv_name, ok = QInputDialog.getItem(self._widget,
                                            "Select service name",
                                            "Service name",
                                            available_rosservices)
        if ok:
            self._create_service_client(srv_name)

    def _create_subscriber(self, topic_name):
        """
        Method that creates a subscriber to a sensor_msgs/Image topic, replacing any
        previous subscriber
        :param topic_name: The topic_name
        """
        if self._sub:
            self._sub.unregister()
        self._sub = rospy.Subscriber(topic_name, Image, self._image_callback)
        rospy.loginfo("Listening to %s -- spinning .." % self._sub.name)
        self._widget.setWindowTitle("Test plugin, listening to (%s)" %
                                    self._sub.name)

    def _create_service_client(self, srv_name):
        """
        Method that creates a client service proxy to call either the GetFaceProperties.srv or the Recognize.srv
        :param srv_name: Name of the service; if it is not currently advertised, no proxy is kept
        """
        if self._srv:
            self._srv.close()
            # Do not keep a reference to a closed proxy when srv_name turns out to be unavailable
            self._srv = None

        if srv_name in rosservice.get_service_list():
            rospy.loginfo("Creating proxy for service '%s'" % srv_name)
            self._srv = rospy.ServiceProxy(
                srv_name, rosservice.get_service_class_by_name(srv_name))
        else:
            rospy.logwarn("Service '%s' is not available", srv_name)

    def shutdown_plugin(self):
        """
        Callback function when shutdown is requested: unregisters the image
        subscriber and closes the service proxy. The references are kept so that
        their names remain readable by save_settings.
        """
        if self._sub:
            self._sub.unregister()
        if self._srv:
            self._srv.close()

    def save_settings(self, plugin_settings, instance_settings):
        """
        Callback function on shutdown to store the local plugin variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        if self._sub:
            instance_settings.set_value("topic_name", self._sub.name)
        if self._srv:
            # restore_settings reads this key back
            instance_settings.set_value("service_name", self._srv.resolved_name)

    def restore_settings(self, plugin_settings, instance_settings):
        """
        Callback function fired on load of the plugin that allows to restore saved variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        self._create_subscriber(
            str(instance_settings.value("topic_name", "/usb_cam/image_raw")))
        self._create_service_client(
            str(
                instance_settings.value("service_name",
                                        "/image_recognition/my_service")))
示例#2
0
class ManualPlugin(Plugin):

    def __init__(self, context):
        """
        ManualPlugin class that performs a manual recognition based on a request
        :param context: QT context, aka parent
        """
        super(ManualPlugin, self).__init__(context)

        # Widget setup
        self.setObjectName('Manual Plugin')

        self._widget = QWidget()
        context.add_widget(self._widget)

        # Layout and attach to widget
        layout = QVBoxLayout()
        self._widget.setLayout(layout)

        self._image_widget = ImageWidget(self._widget, self.image_roi_callback)
        layout.addWidget(self._image_widget)

        # Input field
        grid_layout = QGridLayout()
        layout.addLayout(grid_layout)

        self._labels_edit = QLineEdit()
        self._labels_edit.setDisabled(True)
        grid_layout.addWidget(self._labels_edit, 2, 2)

        self._edit_labels_button = QPushButton("Edit labels")
        self._edit_labels_button.clicked.connect(self._get_labels)
        grid_layout.addWidget(self._edit_labels_button, 2, 1)

        self._done_recognizing_button = QPushButton("Done recognizing..")
        self._done_recognizing_button.clicked.connect(self._done_recognizing)
        self._done_recognizing_button.setDisabled(True)
        grid_layout.addWidget(self._done_recognizing_button, 3, 2)

        # Bridge for opencv conversion
        self.bridge = CvBridge()

        # Set service to None
        self._srv = None
        self._srv_name = None

        # Selectable labels; initialised here so that _get_labels, image_roi_callback
        # and save_settings work even before restore_settings has run
        self.labels = []

        # Response that is filled by _stage_recognition while a request is pending
        self._response = RecognizeResponse()
        # True while recognize_srv_callback waits for the user
        self._recognizing = False

    def _get_labels(self):
        """
        Asks the user for a semicolon separated list of labels and sets them
        """
        text, ok = QInputDialog.getText(self._widget, 'Text Input Dialog',
                                        'Type labels semicolon separated, e.g. banana;apple:',
                                        QLineEdit.Normal, ";".join(self.labels))
        if ok:
            # Sanitize to alphanumeric, exclude labels that end up empty
            sanitized = (_sanitize(label) for label in str(text).split(";"))
            labels = set(label for label in sanitized if label)
            self._set_labels(labels)

    def _set_labels(self, labels):
        """
        Sets the labels and shows them in the (read-only) labels field
        :param labels: label string array; a falsy value results in an empty list
        """
        if not labels:
            labels = []

        self.labels = labels
        self._labels_edit.setText("%s" % labels)

    def _done_recognizing(self):
        """
        Triggered by the 'Done recognizing..' button: clears the image and releases
        the waiting recognize_srv_callback
        """
        self._image_widget.clear()
        self._recognizing = False

    def recognize_srv_callback(self, req):
        """
        Method callback for the Recognize.srv. Shows the request image and blocks until
        the user presses 'Done recognizing..' (or the plugin shuts down), then returns
        the recognitions the user staged in the meantime.
        :param req: The service request
        :return: The RecognizeResponse with the staged recognitions
        :raises rospy.ServiceException: when the image cannot be converted or the timeout is exceeded
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(req.image, "bgr8")
        except CvBridgeError as e:
            rospy.logerr(e)
            # Continuing here would use an unbound cv_image
            raise rospy.ServiceException("Could not convert request image: %s" % e)

        self._response.recognitions = []
        self._recognizing = True

        self._image_widget.set_image(cv_image)
        self._done_recognizing_button.setDisabled(False)

        timeout = 60.0  # Maximum of 60 seconds
        future = rospy.Time.now() + rospy.Duration(timeout)
        rospy.loginfo("Waiting for manual recognition, maximum of %d seconds", timeout)
        try:
            while not rospy.is_shutdown() and self._recognizing:
                if rospy.Time.now() > future:
                    raise rospy.ServiceException("Timeout of %d seconds exceeded .." % timeout)
                rospy.sleep(rospy.Duration(0.1))
        finally:
            # Also reset the UI state when the timeout exception propagates
            self._recognizing = False
            self._done_recognizing_button.setDisabled(True)

        return self._response

    def image_roi_callback(self, roi_image):
        """
        Callback triggered when the user has drawn an ROI on the image
        :param roi_image: The opencv image in the ROI
        """
        if not self.labels:
            warning_dialog("No labels specified!", "Please first specify some labels using the 'Edit labels' button")
            return

        height, width = roi_image.shape[:2]

        option = option_dialog("Label", self.labels)
        if option:
            self._image_widget.add_detection(0, 0, width, height, option)
            self._stage_recognition(self._image_widget.get_roi(), option)

    def _stage_recognition(self, roi, label):
        """
        Stage a manual recognition (probability 1.0 for the label) in the pending response
        :param roi: ROI as (x, y, width, height)
        :param label: The label
        """
        x, y, width, height = roi
        r = Recognition(roi=RegionOfInterest(x_offset=x, y_offset=y, width=width, height=height))
        r.categorical_distribution.probabilities = [CategoryProbability(label=label, probability=1.0)]
        r.categorical_distribution.unknown_probability = 0.0

        self._response.recognitions.append(r)

    def trigger_configuration(self):
        """
        Callback when the configuration button is clicked
        """

        srv_name, ok = QInputDialog.getText(self._widget, "Select service name", "Service name")
        if ok:
            self._create_service_server(srv_name)

    def _create_service_server(self, srv_name):
        """
        Method that creates a service server for a Recognize.srv, replacing any previous one
        :param srv_name: Name of the service to advertise; an empty name only shuts down the old one
        """
        if self._srv:
            self._srv.shutdown()
            # Do not keep a reference to a shut-down service when srv_name is empty
            self._srv = None
            self._srv_name = None

        if srv_name:
            rospy.loginfo("Creating service '%s'" % srv_name)
            self._srv_name = srv_name
            self._srv = rospy.Service(srv_name, Recognize, self.recognize_srv_callback)

    def shutdown_plugin(self):
        """
        Callback function when shutdown is requested: releases a pending request and
        shuts down the service server. The reference is kept so save_settings can still
        store the service name.
        """
        # Lets a blocked recognize_srv_callback return instead of waiting for its timeout
        self._recognizing = False
        if self._srv:
            self._srv.shutdown()

    def save_settings(self, plugin_settings, instance_settings):
        """
        Callback function on shutdown to store the local plugin variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        instance_settings.set_value("labels", self.labels)
        if self._srv:
            instance_settings.set_value("srv_name", self._srv_name)

    def restore_settings(self, plugin_settings, instance_settings):
        """
        Callback function fired on load of the plugin that allows to restore saved variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        labels = None
        try:
            labels = instance_settings.value("labels")
        except Exception as e:
            # Best effort: fall back to no labels
            rospy.logdebug("Could not restore labels: %s", e)
        self._set_labels(labels)
        self._create_service_server(str(instance_settings.value("srv_name", "/my_recognition_service")))
示例#3
0
class LabelPlugin(Plugin):
    def __init__(self, context):
        """
        Label plugin that stores user-labelled ROIs of an image stream to disk
        :param context: QT context, aka parent
        """
        super(LabelPlugin, self).__init__(context)

        # Widget setup
        self.setObjectName('Label Plugin')

        self._widget = QWidget()
        context.add_widget(self._widget)

        # Layout and attach to widget
        layout = QVBoxLayout()
        self._widget.setLayout(layout)

        self._image_widget = ImageWidget(self._widget, self.image_roi_callback)
        layout.addWidget(self._image_widget)

        # Input field
        grid_layout = QGridLayout()
        layout.addLayout(grid_layout)

        self._edit_path_button = QPushButton("Edit path")
        self._edit_path_button.clicked.connect(self._get_output_directory)
        grid_layout.addWidget(self._edit_path_button, 1, 1)

        self._output_path_edit = QLineEdit()
        self._output_path_edit.setDisabled(True)
        grid_layout.addWidget(self._output_path_edit, 1, 2)

        self._labels_edit = QLineEdit()
        self._labels_edit.setDisabled(True)
        grid_layout.addWidget(self._labels_edit, 2, 2)

        self._edit_labels_button = QPushButton("Edit labels")
        self._edit_labels_button.clicked.connect(self._get_labels)
        grid_layout.addWidget(self._edit_labels_button, 2, 1)

        self._save_button = QPushButton("Save another one")
        self._save_button.clicked.connect(self.store_image)
        grid_layout.addWidget(self._save_button, 2, 3)

        # Bridge for opencv conversion
        self.bridge = CvBridge()

        # Set subscriber to None
        self._sub = None

        self.labels = []
        self.roi_image = None  # Last selected ROI (opencv image)
        self.label = ""  # Last chosen label
        self.output_directory = ""

    def image_roi_callback(self, roi_image):
        """
        Callback triggered when the user has drawn an ROI on the image; asks for a
        label and stores the ROI under that label
        :param roi_image: The opencv image in the ROI
        """
        if not self.labels:
            warning_dialog(
                "No labels specified!",
                "Please first specify some labels using the 'Edit labels' button"
            )
            return

        self.roi_image = roi_image

        option = option_dialog("Label", self.labels)
        if option:
            self.label = option
            self._image_widget.set_text(option)
            # Only store when a label was chosen; a cancelled dialog previously
            # stored the ROI under the previously chosen label
            self.store_image()

    def store_image(self):
        """
        Writes the last selected ROI to the output directory under the current label,
        if an ROI, a label and an output directory are all available
        """
        # Explicit checks: `None in [self.roi_image, ...]` compares a numpy array
        # with None, which is elementwise and cannot be used as a boolean
        if self.roi_image is not None and self.label and self.output_directory:
            _write_image_to_file(self.output_directory, self.roi_image,
                                 self.label)

    def _get_output_directory(self):
        """
        Gets and sets the output directory via a QFileDialog
        """
        self._set_output_directory(
            QFileDialog.getExistingDirectory(self._widget,
                                             "Select output directory"))

    def _set_output_directory(self, path):
        """
        Sets the output directory
        :param path: The path of the directory; falls back to /tmp when empty
        """
        if not path:
            path = "/tmp"

        self.output_directory = path
        self._output_path_edit.setText("Saving images to %s" % path)

    def _get_labels(self):
        """
        Asks the user for a semicolon separated list of labels and sets them
        """
        text, ok = QInputDialog.getText(
            self._widget, 'Text Input Dialog',
            'Type labels semicolon separated, e.g. banana;apple:',
            QLineEdit.Normal, ";".join(self.labels))
        if ok:
            # Sanitize to alphanumeric, exclude labels that end up empty
            sanitized = (_sanitize(label) for label in str(text).split(";"))
            labels = set(label for label in sanitized if label)
            self._set_labels(labels)

    def _set_labels(self, labels):
        """
        Sets the labels and shows them in the (read-only) labels field
        :param labels: label string array; a falsy value results in an empty list
        """
        if not labels:
            labels = []

        self.labels = labels
        self._labels_edit.setText("%s" % labels)

    def _image_callback(self, msg):
        """
        Called when a new sensor_msgs/Image is coming in
        :param msg: The image message
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
        except CvBridgeError as e:
            rospy.logerr(e)
            # Continuing here would use an unbound cv_image
            return

        self._image_widget.set_image(cv_image)

    def trigger_configuration(self):
        """
        Callback when the configuration button is clicked
        """
        topic_name, ok = QInputDialog.getItem(
            self._widget, "Select topic name", "Topic name",
            rostopic.find_by_type('sensor_msgs/Image'))
        if ok:
            self._create_subscriber(topic_name)

    def _create_subscriber(self, topic_name):
        """
        Method that creates a subscriber to a sensor_msgs/Image topic, replacing any
        previous subscriber
        :param topic_name: The topic_name
        """
        if self._sub:
            self._sub.unregister()
        self._sub = rospy.Subscriber(topic_name, Image, self._image_callback)
        rospy.loginfo("Listening to %s -- spinning .." % self._sub.name)
        self._widget.setWindowTitle("Label plugin, listening to (%s)" %
                                    self._sub.name)

    def shutdown_plugin(self):
        """
        Callback function when shutdown is requested: unregisters the image
        subscriber (the reference is kept so save_settings can still read its name)
        """
        if self._sub:
            self._sub.unregister()

    def save_settings(self, plugin_settings, instance_settings):
        """
        Callback function on shutdown to store the local plugin variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        instance_settings.set_value("output_directory", self.output_directory)
        instance_settings.set_value("labels", self.labels)
        if self._sub:
            instance_settings.set_value("topic_name", self._sub.name)

    def restore_settings(self, plugin_settings, instance_settings):
        """
        Callback function fired on load of the plugin that allows to restore saved variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        path = None
        try:
            path = instance_settings.value("output_directory")
        except Exception as e:
            # Best effort: fall back to the default directory
            rospy.logdebug("Could not restore output directory: %s", e)
        self._set_output_directory(path)

        labels = None
        try:
            labels = instance_settings.value("labels")
        except Exception as e:
            # Best effort: fall back to no labels
            rospy.logdebug("Could not restore labels: %s", e)
        self._set_labels(labels)

        self._create_subscriber(
            str(instance_settings.value("topic_name", "/usb_cam/image_raw")))
class TestPlugin(Plugin):
    def __init__(self, context):
        """
        Test plugin that sends a user-selected ROI to an object_recognition_srvs/Recognize service
        :param context: QT context, aka parent
        """
        super(TestPlugin, self).__init__(context)

        # Widget setup
        self.setObjectName('Test Plugin')

        self._widget = QWidget()
        context.add_widget(self._widget)

        # Layout and attach to widget
        layout = QVBoxLayout()
        self._widget.setLayout(layout)

        self._image_widget = ImageWidget(self._widget, self.image_roi_callback)
        layout.addWidget(self._image_widget)

        # Input field
        grid_layout = QGridLayout()
        layout.addLayout(grid_layout)

        self._info = QLineEdit()
        self._info.setDisabled(True)
        self._info.setText(
            "Draw a rectangle on the screen to perform object recognition of that ROI"
        )
        layout.addWidget(self._info)

        # Bridge for opencv conversion
        self.bridge = CvBridge()

        # Set subscriber and service to None
        self._sub = None
        self._srv = None

    def image_roi_callback(self, roi_image):
        """
        Callback triggered when the user has drawn an ROI on the image: calls the
        Recognize service and shows the results
        :param roi_image: The opencv image in the ROI
        """
        if self._srv is None:
            warning_dialog(
                "No service specified!",
                "Please first specify a service via the options button (top-right gear wheel)"
            )
            return

        try:
            result = self._srv(
                image=self.bridge.cv2_to_imgmsg(roi_image, "bgr8"))
        except Exception as e:
            warning_dialog("Service Exception", str(e))
            return

        text_array = [
            "%s: %.2f" % (r.label, r.probability) for r in result.recognitions
        ]

        if text_array:
            self._image_widget.set_text(
                text_array[0])  # Show first option in the image
            option_dialog("Classification results",
                          text_array)  # Show all results in a dropdown

    def _image_callback(self, msg):
        """
        Called when a new sensor_msgs/Image is coming in
        :param msg: The image message
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
        except CvBridgeError as e:
            rospy.logerr(e)
            # Continuing here would use an unbound cv_image
            return

        self._image_widget.set_image(cv_image)

    def trigger_configuration(self):
        """
        Callback when the configuration button is clicked: lets the user pick an
        image topic and an object_recognition_srvs/Recognize service
        """
        topic_name, ok = QInputDialog.getItem(
            self._widget, "Select topic name", "Topic name",
            rostopic.find_by_type('sensor_msgs/Image'))
        if ok:
            self._create_subscriber(topic_name)

        available_rosservices = []
        for s in rosservice.get_service_list():
            try:
                if rosservice.get_service_type(
                        s) == "object_recognition_srvs/Recognize":
                    available_rosservices.append(s)
            except Exception as e:
                # Best effort: a service that cannot be queried is just not offered
                rospy.logdebug("Could not query type of service '%s': %s", s, e)

        srv_name, ok = QInputDialog.getItem(self._widget,
                                            "Select service name",
                                            "Service name",
                                            available_rosservices)
        if ok:
            self._create_service_client(srv_name)

    def _create_subscriber(self, topic_name):
        """
        Method that creates a subscriber to a sensor_msgs/Image topic, replacing any
        previous subscriber
        :param topic_name: The topic_name
        """
        if self._sub:
            self._sub.unregister()
        self._sub = rospy.Subscriber(topic_name, Image, self._image_callback)
        rospy.loginfo("Listening to %s -- spinning .." % self._sub.name)
        self._widget.setWindowTitle("Test plugin, listening to (%s)" %
                                    self._sub.name)

    def _create_service_client(self, srv_name):
        """
        Method that creates a Recognize service proxy, closing any previous one
        :param srv_name: Name of the service
        """
        if self._srv:
            self._srv.close()
        rospy.loginfo("Creating proxy for service '%s'" % srv_name)
        self._srv = rospy.ServiceProxy(srv_name, Recognize)

    def shutdown_plugin(self):
        """
        Callback function when shutdown is requested: unregisters the image
        subscriber and closes the service proxy. The references are kept so that
        their names remain readable by save_settings.
        """
        if self._sub:
            self._sub.unregister()
        if self._srv:
            self._srv.close()

    def save_settings(self, plugin_settings, instance_settings):
        """
        Callback function on shutdown to store the local plugin variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        if self._sub:
            instance_settings.set_value("topic_name", self._sub.name)
        if self._srv:
            # restore_settings reads this key back
            instance_settings.set_value("service_name", self._srv.resolved_name)

    def restore_settings(self, plugin_settings, instance_settings):
        """
        Callback function fired on load of the plugin that allows to restore saved variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        self._create_subscriber(
            str(instance_settings.value("topic_name", "/usb_cam/image_raw")))
        self._create_service_client(
            str(
                instance_settings.value("service_name",
                                        "/object_recognition/blaat")))
class AnnotationPlugin(Plugin):
    def __init__(self, context):
        """
        Annotation plugin to create data sets or test the Annotate.srv service
        :param context: Parent QT widget
        """
        super(AnnotationPlugin, self).__init__(context)

        # Widget setup
        self.setObjectName('Label Plugin')

        self._widget = QWidget()
        context.add_widget(self._widget)

        # Layout and attach to widget
        layout = QVBoxLayout()
        self._widget.setLayout(layout)

        self._image_widget = ImageWidget(self._widget,
                                         self.image_roi_callback,
                                         clear_on_click=True)
        layout.addWidget(self._image_widget)

        # Input field
        grid_layout = QGridLayout()
        layout.addLayout(grid_layout)

        self._edit_path_button = QPushButton("Edit path")
        self._edit_path_button.clicked.connect(self._get_output_directory)
        grid_layout.addWidget(self._edit_path_button, 1, 1)

        self._output_path_edit = QLineEdit()
        self._output_path_edit.setDisabled(True)
        grid_layout.addWidget(self._output_path_edit, 1, 2)

        self._labels_edit = QLineEdit()
        self._labels_edit.setDisabled(True)
        grid_layout.addWidget(self._labels_edit, 2, 2)

        self._edit_labels_button = QPushButton("Edit labels")
        self._edit_labels_button.clicked.connect(self._get_labels)
        grid_layout.addWidget(self._edit_labels_button, 2, 1)

        self._save_button = QPushButton("Annotate again!")
        self._save_button.clicked.connect(self.annotate_again_clicked)
        grid_layout.addWidget(self._save_button, 2, 3)

        # Bridge for opencv conversion
        self.bridge = CvBridge()

        # Set subscriber and service to None
        self._sub = None
        self._srv = None

        self.labels = []
        self.label = ""  # Last chosen label
        self.output_directory = ""

    def image_roi_callback(self, roi_image):
        """
        Callback from the image widget when the user has selected a ROI
        :param roi_image: The opencv image of the ROI
        """
        if not self.labels:
            warning_dialog(
                "No labels specified!",
                "Please first specify some labels using the 'Edit labels' button"
            )
            return

        height, width = roi_image.shape[:2]

        option = option_dialog("Label", self.labels)
        if option:
            self.label = option
            self._image_widget.add_detection(0, 0, width, height, option)
            self.annotate(roi_image)

    def annotate_again_clicked(self):
        """
        Triggered when button clicked: annotates the currently selected ROI again
        """
        roi_image = self._image_widget.get_roi_image()
        if roi_image is not None:
            self.annotate(roi_image)

    def annotate(self, roi_image):
        """
        Create an annotation: call the Annotate service and store the image
        :param roi_image: The image we want to annotate
        """
        self.annotate_srv(roi_image)
        self.store_image(roi_image)

    def annotate_srv(self, roi_image):
        """
        Call the selected Annotate.srv with a single annotation covering the whole image
        :param roi_image: The full opencv image we want to annotate
        """
        if roi_image is not None and self.label is not None and self._srv is not None:
            height, width = roi_image.shape[:2]
            try:
                self._srv(image=self.bridge.cv2_to_imgmsg(roi_image, "bgr8"),
                          annotations=[
                              Annotation(label=self.label,
                                         roi=RegionOfInterest(x_offset=0,
                                                              y_offset=0,
                                                              width=width,
                                                              height=height))
                          ])
            except Exception as e:
                warning_dialog("Service Exception", str(e))

    def _create_service_client(self, srv_name):
        """
        Create a service client proxy, closing any previous one
        :param srv_name: Name of the service; if it is not currently advertised, no proxy is kept
        """
        if self._srv:
            self._srv.close()
            # Do not keep a reference to a closed proxy when srv_name turns out to be unavailable
            self._srv = None

        if srv_name in rosservice.get_service_list():
            rospy.loginfo("Creating proxy for service '%s'" % srv_name)
            self._srv = rospy.ServiceProxy(
                srv_name, rosservice.get_service_class_by_name(srv_name))
        else:
            rospy.logwarn("Service '%s' is not available", srv_name)

    def store_image(self, roi_image):
        """
        Store the image
        :param roi_image: Image we would like to store
        """
        if roi_image is not None and self.label is not None and self.output_directory is not None:
            image_writer.write_annotated(self.output_directory, roi_image,
                                         self.label, True)

    def _get_output_directory(self):
        """
        Gets and sets the output directory via a QFileDialog
        """
        self._set_output_directory(
            QFileDialog.getExistingDirectory(self._widget,
                                             "Select output directory"))

    def _set_output_directory(self, path):
        """
        Sets the output directory
        :param path: The path of the directory; falls back to /tmp when empty
        """
        if not path:
            path = "/tmp"

        self.output_directory = path
        self._output_path_edit.setText("Saving images to %s" % path)

    def _get_labels(self):
        """
        Gets and sets the labels
        """
        text, ok = QInputDialog.getText(
            self._widget, 'Text Input Dialog',
            'Type labels semicolon separated, e.g. banana;apple:',
            QLineEdit.Normal, ";".join(self.labels))
        if ok:
            # Sanitize to alphanumeric, exclude labels that end up empty
            sanitized = (_sanitize(label) for label in str(text).split(";"))
            labels = set(label for label in sanitized if label)
            self._set_labels(labels)

    def _set_labels(self, labels):
        """
        Sets the labels and shows them in the (read-only) labels field
        :param labels: label string array; a falsy value results in an empty list
        """
        if not labels:
            labels = []

        self.labels = labels
        self._labels_edit.setText("%s" % labels)

    def _image_callback(self, msg):
        """
        Called when a new sensor_msgs/Image is coming in
        :param msg: The image message
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
        except CvBridgeError as e:
            rospy.logerr(e)
            # Continuing here would use an unbound cv_image
            return

        self._image_widget.set_image(cv_image)

    def trigger_configuration(self):
        """
        Callback when the configuration button is clicked: lets the user pick an
        image topic and a service of one of the _SUPPORTED_SERVICES types
        """
        topic_name, ok = QInputDialog.getItem(
            self._widget, "Select topic name", "Topic name",
            rostopic.find_by_type('sensor_msgs/Image'))
        if ok:
            self._create_subscriber(topic_name)

        available_rosservices = []
        for s in rosservice.get_service_list():
            try:
                if rosservice.get_service_type(s) in _SUPPORTED_SERVICES:
                    available_rosservices.append(s)
            except Exception as e:
                # Best effort: a service that cannot be queried is just not offered
                rospy.logdebug("Could not query type of service '%s': %s", s, e)

        srv_name, ok = QInputDialog.getItem(self._widget,
                                            "Select service name",
                                            "Service name",
                                            available_rosservices)
        if ok:
            self._create_service_client(srv_name)

    def _create_subscriber(self, topic_name):
        """
        Method that creates a subscriber to a sensor_msgs/Image topic, replacing any
        previous subscriber
        :param topic_name: The topic_name
        """
        if self._sub:
            self._sub.unregister()
        self._sub = rospy.Subscriber(topic_name, Image, self._image_callback)
        rospy.loginfo("Listening to %s -- spinning .." % self._sub.name)
        self._widget.setWindowTitle("Label plugin, listening to (%s)" %
                                    self._sub.name)

    def shutdown_plugin(self):
        """
        Callback function when shutdown is requested: unregisters the image
        subscriber and closes the service proxy. The references are kept so that
        their names remain readable by save_settings.
        """
        if self._sub:
            self._sub.unregister()
        if self._srv:
            self._srv.close()

    def save_settings(self, plugin_settings, instance_settings):
        """
        Callback function on shutdown to store the local plugin variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        instance_settings.set_value("output_directory", self.output_directory)
        instance_settings.set_value("labels", self.labels)
        if self._sub:
            instance_settings.set_value("topic_name", self._sub.name)
        if self._srv:
            # restore_settings reads this key back
            instance_settings.set_value("service_name", self._srv.resolved_name)

    def restore_settings(self, plugin_settings, instance_settings):
        """
        Callback function fired on load of the plugin that allows to restore saved variables
        :param plugin_settings: Plugin settings
        :param instance_settings: Settings of this instance
        """
        path = None
        try:
            path = instance_settings.value("output_directory")
        except Exception as e:
            # Best effort: fall back to the default directory
            rospy.logdebug("Could not restore output directory: %s", e)
        self._set_output_directory(path)

        labels = None
        try:
            labels = instance_settings.value("labels")
        except Exception as e:
            # Best effort: fall back to no labels
            rospy.logdebug("Could not restore labels: %s", e)
        self._set_labels(labels)

        self._create_subscriber(
            str(instance_settings.value("topic_name", "/usb_cam/image_raw")))
        self._create_service_client(
            str(
                instance_settings.value("service_name",
                                        "/image_recognition/my_service")))
示例#6
0
class AnnotationPlugin(Plugin):
    def __init__(self, context):
        """
        Annotation plugin to create and edit data sets either by manual annotation or automatically, e.g. using
        a tracker, and generating larger data sets with data augmentation.

        Builds the absolutely positioned widget layout, connects the widget
        signals to their handler methods and initialises the plugin state.
        Every widget is stored under its own attribute name so that no widget
        reference is overwritten by a later one.
        :param context: Parent QT widget
        """
        super(AnnotationPlugin, self).__init__(context)

        # Widget setup
        self.setObjectName('Label Plugin')

        self.widget = QWidget()
        context.add_widget(self.widget)
        self.widget.resize(1800, 1000)

        # Left side: live image and "Grab frame" button.
        # Both image widgets report drawn ROIs to roi_callback.
        self.cur_im_widget = ImageWidget(self.widget,
                                         self.roi_callback,
                                         clear_on_click=False)
        self.cur_im_widget.setGeometry(QRect(20, 20, 640, 480))

        self.grab_img_button = QPushButton(self.widget)
        self.grab_img_button.setText("Grab frame")
        self.grab_img_button.clicked.connect(self.grab_frame)
        self.grab_img_button.setGeometry(QRect(20, 600, 100, 50))

        # Right side: selected workspace image with its annotations.
        self.sel_im_widget = ImageWidget(self.widget,
                                         self.roi_callback,
                                         clear_on_click=True)
        self.sel_im_widget.setGeometry(QRect(720, 20, 640, 480))

        # List widgets for annotations and images.
        self.annotation_list_widget = QListWidget(self.widget)
        self.annotation_list_widget.setGeometry(QRect(1400, 50, 150, 200))
        self.annotation_list_widget.setObjectName("annotation_list_widget")
        self.annotation_list_widget.setSelectionMode(
            QAbstractItemView.ExtendedSelection)
        self.annotation_list_widget.currentItemChanged.connect(
            self.select_annotation)

        self.image_list_widget = QListWidget(self.widget)
        self.image_list_widget.setGeometry(QRect(1550, 50, 250, 500))
        self.image_list_widget.setObjectName("image_list_widget")
        self.image_list_widget.setSelectionMode(
            QAbstractItemView.ExtendedSelection)
        self.image_list_widget.currentItemChanged.connect(self.select_image)

        # Workspace path display (read-only) and "set ws" button.
        self.output_path_edit = QLineEdit(self.widget)
        self.output_path_edit.setGeometry(QRect(1400, 20, 300, 30))
        self.output_path_edit.setDisabled(True)

        self.edit_path_button = QPushButton(self.widget)
        self.edit_path_button.setText("set ws")
        self.edit_path_button.setGeometry(QRect(1700, 20, 100, 30))
        self.edit_path_button.clicked.connect(self.get_workspace)

        # Buttons for adding or deleting annotations.
        self.add_annotation_button = QPushButton(self.widget)
        self.add_annotation_button.setText("add")
        self.add_annotation_button.setGeometry(QRect(1400, 250, 75, 30))
        self.add_annotation_button.clicked.connect(self.add_annotation)

        self.remove_annotation_button = QPushButton(self.widget)
        self.remove_annotation_button.setText("del")
        self.remove_annotation_button.setGeometry(QRect(1475, 250, 75, 30))
        self.remove_annotation_button.clicked.connect(
            self.remove_current_annotation)

        # Label combo box, line edit and button for adding labels.
        self.option_selector = QComboBox(self.widget)
        self.option_selector.currentIndexChanged.connect(self.class_change)
        self.option_selector.setGeometry(1400, 280, 150, 30)

        self.label_edit = QLineEdit(self.widget)
        self.label_edit.setGeometry(QRect(1400, 310, 100, 30))
        self.label_edit.setDisabled(False)

        self.edit_label_button = QPushButton(self.widget)
        self.edit_label_button.setText("add")
        self.edit_label_button.setGeometry(QRect(1500, 310, 50, 30))
        self.edit_label_button.clicked.connect(self.add_label)

        # Button for image deletion.
        self.remove_image_button = QPushButton(self.widget)
        self.remove_image_button.setText("delete image")
        self.remove_image_button.setGeometry(QRect(1550, 550, 150, 30))
        self.remove_image_button.clicked.connect(self.remove_current_image)

        # Export data.
        self.export_ws_label = QLabel(self.widget)
        self.export_ws_label.setText("Export workspace: ")
        self.export_ws_label.setGeometry(QRect(1550, 650, 250, 50))

        self.export_ws_button = QPushButton(self.widget)
        self.export_ws_button.setText("Export")
        self.export_ws_button.setGeometry(QRect(1550, 700, 125, 50))
        self.export_ws_button.clicked.connect(self.export_workspace_to_tf)

        self.conf_export_button = QPushButton(self.widget)
        self.conf_export_button.setText("Configure")
        self.conf_export_button.setGeometry(QRect(1675, 700, 125, 50))
        self.conf_export_button.clicked.connect(self.set_export_parameters)

        # Generate augmented data.
        self.gen_data_label = QLabel(self.widget)
        self.gen_data_label.setText("Generate augmented dataset:")
        self.gen_data_label.setGeometry(QRect(1550, 800, 250, 50))

        self.gen_data_button = QPushButton(self.widget)
        self.gen_data_button.setText("Generate")
        self.gen_data_button.setGeometry(QRect(1550, 850, 125, 50))
        self.gen_data_button.clicked.connect(self.generate_augmented_data)

        self.conf_gen_data_button = QPushButton(self.widget)
        self.conf_gen_data_button.setText("Configure")
        self.conf_gen_data_button.setGeometry(QRect(1675, 850, 125, 50))
        self.conf_gen_data_button.clicked.connect(
            self.set_data_augmentation_parameters)

        # Functional state.
        self.bridge = CvBridge()

        self.sub = None  # image subscriber, (re)created by create_subscriber

        self.class_id = -1  # index of the current label in self.labels; -1 = none
        self.label = ""
        self.changes_done = False  # unsaved annotation edits on the current image

        self.workspace = None
        self.labels = []
        self.images_with_annotations = []
        self.cur_annotated_image = None
        self.cur_annotation_index = -1  # -1 = no annotation selected

        self.class_change()

        # export parameters
        self.default_config_path = None
        self.pretrained_graph = None
        self.p_test = 0.2
        self.batch_size = 12

        # data augmentation parameters
        self.gen_dir = None
        self.num_illuminate = 1
        self.num_scale = 1
        self.num_blur = 1

    def set_data_augmentation_parameters(self):
        """
        Ask the user for the data augmentation settings.

        Asks for the output directory, then for the number of illumination,
        scaling and blurring variations per image. Cancelling the directory
        dialog keeps the previously chosen directory (if any). Each count
        dialog is pre-filled with the current value, and cancelling it keeps
        that value.
        """
        directory = QFileDialog.getExistingDirectory(
            self.widget, "Select output directory")
        # getExistingDirectory returns an empty string when cancelled
        if directory:
            self.gen_dir = directory

        self.num_illuminate = self._ask_count(
            "Illumination changes per image", self.num_illuminate)
        self.num_scale = self._ask_count("Scaling changes per image",
                                         self.num_scale)
        self.num_blur = self._ask_count("Blurring changes per image",
                                        self.num_blur)

    def _ask_count(self, title, current):
        """
        Ask the user for a non-negative integer with an integer input dialog.
        :param title: Dialog title
        :param current: Value shown initially and returned on cancel
        :return: The entered value, or current if the dialog was cancelled
        """
        # positional args: parent, title, label, value, minimum
        value, ok = QInputDialog.getInt(self.widget, title,
                                        "Number per image:", current, 0)
        return value if ok else current

    def generate_augmented_data(self):
        """
        Generate an augmented data set from the workspace into gen_dir via
        data_augmentation.multiply_dataset, using the illumination, scaling
        and blurring counts set by set_data_augmentation_parameters.
        """
        # gen_dir is None when never configured and "" when the directory
        # dialog was cancelled; both are unusable as an output directory.
        if not self.gen_dir:
            warning_dialog("Warning", "Set parameters first")
            return
        if not self.workspace:
            warning_dialog("Warning", "select workspace first")
            return
        data_augmentation.multiply_dataset(self.workspace, self.gen_dir, None,
                                           0, self.num_illuminate,
                                           self.num_scale, self.num_blur, 0)
        print("data augmentation done")

    def set_export_parameters(self):
        """
        Set the export variables via QFileDialog and QInputDialog.

        Asks for the default ``.config`` file, the directory of the pretrained
        graph (its ``model.ckpt`` is used), the batch size and the test
        fraction.
        - Cancelling the config dialog changes nothing and stops.
        - A file with another extension clears the config path, warns and stops.
        - Cancelling the graph directory dialog keeps the previous graph.
        - Cancelling a number dialog keeps the current value.
        """
        # The binding in use returns a (path, selected_filter) tuple here,
        # hence the [0]; the path is empty when the dialog was cancelled.
        config_path = QFileDialog.getOpenFileName(self.widget,
                                                  "Select default config")
        config_path = str(config_path[0])
        if not config_path:
            return
        if os.path.splitext(config_path)[1] != ".config":
            self.default_config_path = None
            warning_dialog("warning", "invalid file extension")
            return
        self.default_config_path = config_path

        # pretrained graph; an empty result (cancel) must not become "/model.ckpt"
        pretrained_graph_dir = QFileDialog.getExistingDirectory(
            self.widget, "Select pretrained graph directory")
        if pretrained_graph_dir:
            self.pretrained_graph = os.path.join(pretrained_graph_dir,
                                                 "model.ckpt")

        # batch size (at least 1); positional args: parent, title, label,
        # value, minimum
        batch_size, ok = QInputDialog.getInt(self.widget, "Set batch size",
                                             "Batch size:", self.batch_size,
                                             1)
        if ok:
            self.batch_size = batch_size

        # fraction of the data used for testing, in [0, 1];
        # positional args: parent, title, label, value, min, max, decimals
        p_test, ok = QInputDialog.getDouble(self.widget,
                                            "Set test percentage",
                                            "Test fraction (0-1):",
                                            self.p_test, 0.0, 1.0, 2)
        if ok:
            self.p_test = p_test

    def export_workspace_to_tf(self):
        """
        Export the workspace to training formats through tf_utils.

        Refuses (with a warning) unless both the default config path and the
        pretrained graph have been set. After exporting, ROI images are
        created with tf_utils.create_roi_images.
        """
        parameters_missing = (self.default_config_path is None
                              or self.pretrained_graph is None)
        if parameters_missing:
            warning_dialog("Warning",
                           "define default config and pretrained graph first")
            return

        workspace = self.workspace
        annotated_images = self.images_with_annotations
        labels = self.labels
        tf_utils.export_data_to_tf(workspace, annotated_images, labels,
                                   self.p_test, self.default_config_path,
                                   self.batch_size, self.pretrained_graph,
                                   True)
        tf_utils.create_roi_images(workspace, annotated_images, labels)
        print("Export done")

    def add_label(self):
        """
        Add the label typed in the label line edit.

        The label is appended to the label list and the combo box, and the
        list is written to ``<workspace>/labels.txt``. Nothing happens for an
        empty or whitespace-only entry. A warning is shown, and nothing
        changes, when no workspace is set or the label already exists.
        """
        new_label = str(self.label_edit.text()).strip()
        if not new_label:
            return
        # Check before touching any state: the label file path needs a workspace.
        if not self.workspace:
            warning_dialog("Warning", "select workspace first")
            return
        if any(label[0] == new_label for label in self.labels):
            warning_dialog("warning",
                           "label \"" + new_label + "\" already exists")
            return

        # Entries are [name, value] pairs; the second element is presumably an
        # annotation count (initialised to 0) -- confirm against utils.read_labels.
        self.labels.append([new_label, 0])
        # Append before addItem: adding the first item fires
        # currentIndexChanged -> class_change, which looks the name up in self.labels.
        self.option_selector.addItem(new_label)

        label_file = self.workspace + "/labels.txt"
        utils.write_labels(label_file, self.labels)

    def add_annotation(self):
        """
        Append a new annotation with the current label to the selected image.

        The bounding box is a placeholder meant to be replaced by drawing a
        ROI. The new entry is added to the annotation list widget, selected,
        and the image is flagged as changed.
        """
        label_missing = self.class_id == -1 or self.label == ""
        if label_missing:
            warning_dialog("Warning", "select label first")
            return
        image = self.cur_annotated_image
        if image is None:
            warning_dialog("Warning", "select image first")
            return

        label_text = self.option_selector.currentText()
        # NOTE(review): arguments presumably are (class id, confidence,
        # x_center, y_center, width, height) -- confirm in utils.AnnotationWithBbox
        placeholder = utils.AnnotationWithBbox(self.class_id, 1.0, 0.5, 0.5,
                                               1, 1)
        annotations = image.annotation_list
        annotations.append(placeholder)
        new_index = len(annotations) - 1
        self.add_annotation_to_list_widget(new_index, label_text, True)

        self.changes_done = True

    def class_change(self):
        """
        Called when another label is selected in the combo box.

        Sets the current label to the combo box text and the class id to that
        label's index in the label list.
        - If the label list is empty, the class id is left unchanged.
        - If the text is not a known label, the class id becomes -1 instead of
          raising ValueError.
        """
        self.label = self.option_selector.currentText()
        if not self.labels:
            return
        names = [entry[0] for entry in self.labels]
        self.class_id = names.index(self.label) if self.label in names else -1

    def select_annotation(self):
        """
        Handle a change of the selected annotation.

        The number before the first ':' of the item text becomes the current
        annotation index. The image is then redrawn (the selected bounding
        box is drawn thicker) and drawing on the selection widget is enabled.
        Without a current item, drawing is disabled and the widget cleared.
        """
        current = self.annotation_list_widget.currentItem()
        if current is not None:
            index_text = str(current.text()).partition(":")[0]
            self.cur_annotation_index = int(index_text)
            self.show_image_from_workspace()
            self.sel_im_widget.set_active(True)
            return
        self.sel_im_widget.set_active(False)
        self.sel_im_widget.clear()

    def select_image(self):
        """
        Called when an image from the list is selected.

        Pending annotation changes of the previously selected image are saved
        first (only if an image is actually selected). Then the image whose
        path equals workspace + item text is made current and shown with its
        annotation list.
        """
        if self.changes_done:
            # changes_done may be set while no image is selected; there is
            # nothing to save in that case, but the flag must still be reset.
            if self.cur_annotated_image is not None:
                utils.save_annotations(self.cur_annotated_image.image_file,
                                       self.cur_annotated_image.annotation_list)
            self.changes_done = False

        item = self.image_list_widget.currentItem()
        if item is None:
            return
        image_file = self.workspace + str(item.text())
        # If several entries share the path, the last one wins.
        self.cur_annotated_image = None
        for img in self.images_with_annotations:
            if img.image_file == image_file:
                self.cur_annotated_image = img
        if self.cur_annotated_image is not None:
            self.show_image_from_workspace()
            self.set_annotation_list()

    def set_annotation_list(self):
        """
        Fill the annotation list widget with the annotations of the current image.

        Each entry is shown as "<index>:<label name>". Annotations whose class
        id is not a valid index into the label list are shown as "unknown".
        The annotation selection is reset and the image redrawn.
        """
        self.annotation_list_widget.clear()
        self.cur_annotation_index = -1
        num_labels = len(self.labels)
        for i, annotation in enumerate(self.cur_annotated_image.annotation_list):
            class_id = int(annotation.label)
            # Also reject negative ids: they would silently index from the end.
            if 0 <= class_id < num_labels:
                name = self.labels[class_id][0]
            else:
                name = "unknown"
            self.add_annotation_to_list_widget(i, name)
        self.show_image_from_workspace()

    def add_annotation_to_list_widget(self, index, label, select=False):
        """
        Append an "<index>:<label>" entry to the annotation QListWidget.
        :param index: index in annotation list
        :param label: label of annotation
        :param select: if True, make the new entry current and handle the selection
        """
        entry = QListWidgetItem()
        entry.setText(":".join((str(index), label)))
        self.annotation_list_widget.addItem(entry)
        if not select:
            return
        self.annotation_list_widget.setCurrentItem(entry)
        self.select_annotation()

    def add_image_to_list_widget(self, file_name, select=False):
        """
        Append an entry for an image to the image QListWidget.
        :param file_name: image path relative to the workspace, used as item text
        :param select: if True, make the new entry current and load that image
        """
        entry = QListWidgetItem()
        entry.setText(file_name)
        images_widget = self.image_list_widget
        images_widget.addItem(entry)
        if select:
            images_widget.setCurrentItem(entry)
            self.select_image()

    def refresh_image_list_widget(self):
        """
        Clear the image list widget and refill it with every loaded image.

        Each image is shown by its path relative to the workspace.
        """
        self.image_list_widget.clear()
        prefix = self.workspace
        for img in self.images_with_annotations:
            name = img.image_file
            # Strip only the leading workspace path. str.replace would also
            # remove later occurrences, e.g. workspace "/data" and image
            # "/data/images/data.jpg" would give "/images.jpg".
            if prefix and name.startswith(prefix):
                name = name[len(prefix):]
            self.add_image_to_list_widget(name)

    def remove_current_annotation(self):
        """
        Remove the currently selected annotation from the current image.

        The annotation list is rebuilt afterwards. The image is flagged as
        changed so the deletion is written out like any other annotation
        edit (the changes_done flag).
        """
        if self.cur_annotation_index < 0 or self.cur_annotated_image is None:
            warning_dialog("Warning", "no annotation selected")
            return
        del self.cur_annotated_image.annotation_list[self.cur_annotation_index]
        # Without this flag the deletion was never saved.
        self.changes_done = True
        self.set_annotation_list()

    def remove_current_image(self):
        """
        Remove the currently selected image.

        Drops the image from the loaded images, deletes its image file and
        label file (when present), resets the selection and rebuilds the
        image list.
        """
        if self.cur_annotated_image is None:
            warning_dialog("Warning", "no image selected")
            return
        image_file = self.cur_annotated_image.image_file
        # Same image -> label path mapping as load_workspace (.jpg and .png).
        label_file = image_file.replace("images", "labels").replace(
            ".jpg", ".txt").replace(".png", ".txt")

        self.images_with_annotations.remove(self.cur_annotated_image)
        # Forget the removed image before the list widget is rebuilt.
        # Clearing that widget can emit currentItemChanged -> select_image,
        # which would otherwise save the pending annotations and recreate the
        # deleted label file.
        self.cur_annotated_image = None
        self.cur_annotation_index = -1
        self.changes_done = False
        self.annotation_list_widget.clear()
        self.sel_im_widget.clear()

        for path in (image_file, label_file):
            if os.path.isfile(path):
                os.remove(path)

        self.refresh_image_list_widget()

    def get_workspace(self):
        """
        Let the user pick the workspace directory with a QFileDialog and load it.

        Pending annotation changes are saved first. Cancelling the dialog keeps
        the current workspace; load_workspace is not asked to fall back to
        its default directory.
        """
        if self.changes_done:
            if self.cur_annotated_image is not None:
                utils.save_annotations(self.cur_annotated_image.image_file,
                                       self.cur_annotated_image.annotation_list)
            self.changes_done = False

        path = QFileDialog.getExistingDirectory(self.widget,
                                                "Select output directory")
        # getExistingDirectory returns an empty string when cancelled
        if not path:
            return
        self.load_workspace(path)

    def load_workspace(self, path):
        """
        Sets the workspace directory. Checks for missing directories and files,
        then loads all images, annotations & label list.

        Falls back to /tmp for an empty/None path. All previous state
        (labels, images, selection, pending-changes flag) and the combo box
        and list widgets are cleared first.
        :param path: The path of the directory
        """
        if not path:
            path = "/tmp"

        self.workspace = path
        self.output_path_edit.setText(path)

        # clear all lists and references
        self.labels = []
        self.cur_annotation_index = -1
        self.cur_annotated_image = None
        self.images_with_annotations = []
        self.changes_done = False
        self.class_id = -1
        self.label = ""

        # clear combo box and listWidgets
        self.option_selector.clear()
        self.image_list_widget.clear()
        self.annotation_list_widget.clear()

        if not utils.check_workspace(self.workspace):
            return

        image_dir = self.workspace + "/images"
        labels_file = self.workspace + "/labels.txt"

        self.labels = utils.read_labels(labels_file)
        for label in self.labels:
            self.option_selector.addItem(label[0])

        for dirname, dirnames, filenames in os.walk(image_dir):
            dirnames.sort()  # visit sub-directories in a deterministic order
            for filename in sorted(filenames):
                image_file = dirname + '/' + filename
                annotation_file = image_file.replace(
                    "images", "labels").replace(".jpg", ".txt").replace(
                        ".png", ".txt")
                annotated_image = utils.read_annotated_image(
                    image_file, annotation_file)
                self.images_with_annotations.append(annotated_image)
                # The list text must be workspace-relative so select_image can
                # rebuild the full path; this also holds for files in
                # sub-directories of images/.
                self.add_image_to_list_widget(image_file[len(self.workspace):])
        print("workspace loaded successfully")

    def grab_frame(self):
        """
        Grab the current frame, save it to the workspace and select it.

        The frame comes from the live image widget. It is written as a
        timestamped JPEG into the workspace's images directory, then added to
        the image list and selected.
        """
        if self.workspace is None or self.workspace == "":
            warning_dialog("Warning", "select workspace first")
            return
        cv_image = self.cur_im_widget.get_image()
        if cv_image is None:
            warning_dialog("warning", "no frame to grab")
            return
        file_name = "/images/{}.jpg".format(
            datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S_%f"))
        image_path = self.workspace + file_name
        # cv2.imwrite reports failure (e.g. missing directory) via its return value
        if not cv2.imwrite(image_path, cv_image):
            warning_dialog("warning", "could not write " + image_path)
            return
        self.images_with_annotations.append(
            utils.AnnotatedImage(image_path, []))
        # Selecting the new entry lets select_image first save the pending
        # changes of the previously selected image and then switch to the new
        # one. cur_annotated_image must therefore not be replaced beforehand:
        # select_image would then save the new image's (empty) annotation list
        # instead, and the old image's changes would be lost.
        self.add_image_to_list_widget(file_name, True)

    def roi_callback(self):
        """
        Called when a ROI is drawn on an image widget. Edit annotation if possible.

        The selected annotation's bounding box is replaced by the
        (x_center, y_center, width, height) returned by
        sel_im_widget.get_normalized_roi(). The annotation also gets the label
        currently chosen in the combo box. If no label or no annotation is
        selected, a warning is shown and nothing is modified, so nothing is
        flagged as changed.
        """
        if self.class_id == -1 or self.label == "":
            warning_dialog("Warning", "select label first")
            return
        if self.cur_annotation_index < 0 or self.cur_annotated_image is None:
            warning_dialog("Warning", "select annotation first")
            return

        # set bounding box
        x_center, y_center, width, height = \
            self.sel_im_widget.get_normalized_roi()
        annotation = self.cur_annotated_image.annotation_list[
            self.cur_annotation_index]
        annotation.bbox = utils.BoundingBox(x_center, y_center, width, height)
        # the annotation has been modified from here on
        self.changes_done = True

        # set selected label
        self.label = self.option_selector.currentText()
        if not self.labels:
            return
        self.class_id = [i[0] for i in self.labels].index(self.label)
        annotation.label = self.class_id

        self.set_annotation_list()
        self.show_image_from_workspace()
        self.sel_im_widget.clear()

    def show_image_from_workspace(self):
        """
        Display the selected image with its annotations on the selection widget.

        The current annotation index is passed along. Nothing happens when no
        image is selected or the file cannot be read (cv2.imread returns None).
        """
        current = self.cur_annotated_image
        if current is None:
            return
        frame = cv2.imread(current.image_file)
        if frame is None:
            return
        self.sel_im_widget.set_image(frame, current.annotation_list,
                                     self.cur_annotation_index)

    def image_callback(self, msg):
        """
        Called when a new sensor_msgs/Image is coming in; shows it on the live image widget.
        :param msg: The image message
        """
        try:
            cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
        except CvBridgeError as e:
            rospy.logerr(e)
            # cv_image is unbound here: skip this frame instead of failing
            # with UnboundLocalError below.
            return

        # NOTE(review): this runs in the subscriber callback; confirm that
        # ImageWidget.set_image is safe to call outside the Qt GUI thread.
        self.cur_im_widget.set_image(cv_image, None, None)

    def trigger_configuration(self):
        """
        Callback when the configuration button is clicked.

        Lets the user pick one of the sensor_msgs/Image topics and subscribes
        to it; cancelling the dialog does nothing.
        """
        candidates = rostopic.find_by_type('sensor_msgs/Image')
        topic_name, accepted = QInputDialog.getItem(
            self.widget, "Select topic name", "Topic name", candidates)
        if not accepted:
            return
        self.create_subscriber(topic_name)

    def create_subscriber(self, topic_name):
        """
        Subscribe to a sensor_msgs/Image topic.

        Any previous subscription is unregistered first. The topic is logged
        and shown in the window title.
        :param topic_name: The topic_name
        """
        previous = self.sub
        if previous:
            previous.unregister()
        subscriber = rospy.Subscriber(topic_name, Image, self.image_callback)
        self.sub = subscriber
        rospy.loginfo("Listening to %s -- spinning .." % subscriber.name)
        self.widget.setWindowTitle("Label plugin, listening to (%s)" %
                                   subscriber.name)

    def shutdown_plugin(self):
        """
        Callback function when shutdown is requested.

        Saves pending annotation changes of the current image (if one is
        selected). It then unregisters the image subscriber so no more frames
        reach the closed plugin. The subscriber reference is kept so its name
        remains readable.
        """
        if self.changes_done and self.cur_annotated_image is not None:
            utils.save_annotations(self.cur_annotated_image.image_file,
                                   self.cur_annotated_image.annotation_list)
        self.changes_done = False
        if self.sub:
            self.sub.unregister()

    def save_settings(self, plugin_settings, instance_settings):
        """
        Callback function on shutdown to store the local plugin variables.

        Stores the workspace directory and, when subscribed, the image topic.
        restore_settings reads the topic back under "topic_name".
        :param plugin_settings: Plugin settings (not used here)
        :param instance_settings: Settings of this instance
        """
        instance_settings.set_value("workspace_dir", self.workspace)
        if self.sub:
            instance_settings.set_value("topic_name", self.sub.name)

    def restore_settings(self, plugin_settings, instance_settings):
        """
        Callback function fired on load of the plugin that allows to restore saved variables.

        Loads the saved workspace; None is passed to load_workspace when the
        value cannot be read. Then subscribes to the saved image topic,
        defaulting to /xtion/rgb/image_raw.
        :param plugin_settings: Plugin settings (not used here)
        :param instance_settings: Settings of this instance
        """
        workspace = None
        try:
            workspace = instance_settings.value("workspace_dir")
        except Exception:
            # Best effort: a broken entry must not prevent loading, but a bare
            # except would also swallow KeyboardInterrupt / SystemExit.
            pass
        self.load_workspace(workspace)
        self.create_subscriber(
            str(instance_settings.value("topic_name", "/xtion/rgb/image_raw")))