Example #1
class GazeEstimatorROS(GazeEstimatorBase):
    def __init__(self, device_id_gaze, model_files):
        super(GazeEstimatorROS, self).__init__(device_id_gaze, model_files)
        self.bridge = CvBridge()
        self.subjects_bridge = SubjectListBridge()

        self.tf_broadcaster = TransformBroadcaster()
        self.tf_listener = TransformListener()

        self.tf_prefix = rospy.get_param("~tf_prefix", "gaze")
        self.headpose_frame = self.tf_prefix + "/head_pose_estimated"
        self.ros_tf_frame = rospy.get_param("~ros_tf_frame",
                                            "/kinect2_ros_frame")

        self.image_subscriber = rospy.Subscriber("/subjects/images",
                                                 MSG_SubjectImagesList,
                                                 self.image_callback,
                                                 queue_size=3,
                                                 buff_size=2**24)
        self.subjects_gaze_img = rospy.Publisher("/subjects/gazeimages",
                                                 Image,
                                                 queue_size=3)

        self.visualise_eyepose = rospy.get_param("~visualise_eyepose",
                                                 default=True)

        self._last_time = rospy.Time().now()

    def publish_image(self, image, image_publisher, timestamp):
        """This image publishes the `image` to the `image_publisher` with the given `timestamp`."""
        image_ros = self.bridge.cv2_to_imgmsg(image, "rgb8")
        image_ros.header.stamp = timestamp
        image_publisher.publish(image_ros)

    def image_callback(self, subject_image_list):
        """This method is called whenever new input arrives. The input is first converted in a format suitable
        for the gaze estimation network (see :meth:`input_from_image`), then the gaze is estimated (see
        :meth:`estimate_gaze`. The estimated gaze is overlaid on the input image (see :meth:`visualize_eye_result`),
        and this image is published along with the estimated gaze vector (see :meth:`publish_image` and
        :func:`publish_gaze`)"""
        timestamp = subject_image_list.header.stamp

        subjects_dict = self.subjects_bridge.msg_to_images(subject_image_list)
        input_r_list = []
        input_l_list = []
        input_head_list = []
        valid_subject_list = []
        for subject_id, s in subjects_dict.items():
            try:
                (trans_head, rot_head) = self.tf_listener.lookupTransform(
                    self.ros_tf_frame, self.headpose_frame + str(subject_id),
                    timestamp)
                euler_angles_head = list(
                    tf.transformations.euler_from_quaternion(rot_head))
                euler_angles_head = gaze_tools.limit_yaw(euler_angles_head)

                phi_head, theta_head = gaze_tools.get_phi_theta_from_euler(
                    euler_angles_head)
                input_head_list.append([theta_head, phi_head])
                input_r_list.append(self.input_from_image(s.right))
                input_l_list.append(self.input_from_image(s.left))
                valid_subject_list.append(subject_id)
            except (tf.LookupException, tf.ConnectivityException,
                    tf.ExtrapolationException, tf.Exception):
                # head pose transform unavailable for this subject at this
                # timestamp; skip it for this frame
                pass

        if len(valid_subject_list) == 0:
            return

        gaze_est = self.estimate_gaze_twoeyes(
            inference_input_left_list=input_l_list,
            inference_input_right_list=input_r_list,
            inference_headpose_list=input_head_list)

        subjects_gaze_img_list = []
        for subject_id, gaze in zip(valid_subject_list, gaze_est.tolist()):
            self.publish_gaze(gaze, timestamp, subject_id)

            if self.visualise_eyepose:
                s = subjects_dict[subject_id]
                r_gaze_img = self.visualize_eye_result(s.right, gaze)
                l_gaze_img = self.visualize_eye_result(s.left, gaze)
                s_gaze_img = np.concatenate((r_gaze_img, l_gaze_img), axis=1)
                subjects_gaze_img_list.append(s_gaze_img)

        if len(subjects_gaze_img_list) > 0:
            gaze_img_msg = self.bridge.cv2_to_imgmsg(
                np.hstack(subjects_gaze_img_list).astype(np.uint8), "bgr8")
            gaze_img_msg.header.stamp = timestamp
            self.subjects_gaze_img.publish(gaze_img_msg)

        _now = rospy.Time().now()
        _freq = 1.0 / (_now - self._last_time).to_sec()
        self._last_time = _now
        tqdm.write(
            '\033[2K\033[1;32mTime now: {:.2f} message color: {:.2f} diff: {:.2f}s for {} subjects {:.0f}Hz\033[0m'
            .format((_now.to_sec()), timestamp.to_sec(),
                    _now.to_sec() - timestamp.to_sec(),
                    len(valid_subject_list), _freq),
            end="\r")

    def publish_gaze(self, est_gaze, msg_stamp, subject_id):
        """Publish the gaze vector as a PointStamped."""
        theta_gaze = est_gaze[0]
        phi_gaze = est_gaze[1]
        euler_angle_gaze = gaze_tools.get_euler_from_phi_theta(
            phi_gaze, theta_gaze)
        quaternion_gaze = tf.transformations.quaternion_from_euler(
            *euler_angle_gaze)
        self.tf_broadcaster.sendTransform(
            (0, 0, 0.05),  # publish it 5cm above the head pose's origin (nose tip)
            quaternion_gaze,
            msg_stamp,
            self.tf_prefix + "/world_gaze" + str(subject_id),
            self.headpose_frame + str(subject_id))
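
A minimal launch sketch for a node like the one above. The node name and the parameter names for the device id and model files are illustrative assumptions, not taken from the example itself:

import rospy

if __name__ == "__main__":
    rospy.init_node("estimate_gaze")  # hypothetical node name
    # hypothetical parameter names; adapt them to the actual launch file
    model_files = rospy.get_param("~model_files")
    device_id = rospy.get_param("~device_id_gazeestimation", "/gpu:0")
    gaze_node = GazeEstimatorROS(device_id, model_files)  # class from Example #1
    rospy.spin()  # process image callbacks until shutdown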
Example #2
class GazeEstimatorROS(object):
    def __init__(self, device_id_gaze, model_files):
        self.bridge = CvBridge()
        self.subjects_bridge = SubjectListBridge()

        self.tf2_broadcaster = tf2_ros.TransformBroadcaster()
        self.tf2_buffer = tf2_ros.Buffer()
        self.tf2_listener = tf2_ros.TransformListener(self.tf2_buffer)

        self.tf_prefix = rospy.get_param("~tf_prefix", "gaze")
        self.headpose_frame = self.tf_prefix + "/head_pose_estimated"
        self.gaze_backend = rospy.get_param("~gaze_backend", "tensorflow")

        if self.gaze_backend == "tensorflow":
            from rt_gene.estimate_gaze_tensorflow import GazeEstimator
            self._gaze_estimator = GazeEstimator(device_id_gaze, model_files)
        elif self.gaze_backend == "pytorch":
            from rt_gene.estimate_gaze_pytorch import GazeEstimator
            self._gaze_estimator = GazeEstimator(device_id_gaze, model_files)
        else:
            raise ValueError(
                "Incorrect gaze_base backend, choices are: tensorflow or pytorch"
            )

        self.image_subscriber = rospy.Subscriber("/subjects/images",
                                                 MSG_SubjectImagesList,
                                                 self.image_callback,
                                                 queue_size=3,
                                                 buff_size=2**24)
        self.subjects_gaze_img = rospy.Publisher("/subjects/gazeimages",
                                                 Image,
                                                 queue_size=3)
        self.gaze_publishers = rospy.Publisher("/subjects/gaze",
                                               MSG_GazeList,
                                               queue_size=3)

        self.visualise_eyepose = rospy.get_param("~visualise_eyepose",
                                                 default=True)

        self._last_time = rospy.Time().now()
        self._freq_deque = collections.deque(
            maxlen=30)  # average frequency statistic over roughly one second
        self._latency_deque = collections.deque(maxlen=30)

    def publish_image(self, image, image_publisher, timestamp):
        """This image publishes the `image` to the `image_publisher` with the given `timestamp`."""
        image_ros = self.bridge.cv2_to_imgmsg(image, "rgb8")
        image_ros.header.stamp = timestamp
        image_publisher.publish(image_ros)

    def image_callback(self, subject_image_list):
        """This method is called whenever new input arrives. The input is first converted in a format suitable
        for the gaze estimation network (see :meth:`input_from_image`), then the gaze is estimated (see
        :meth:`estimate_gaze`. The estimated gaze is overlaid on the input image (see :meth:`visualize_eye_result`),
        and this image is published along with the estimated gaze vector (see :meth:`publish_image` and
        :func:`publish_gaze`)"""
        timestamp = subject_image_list.header.stamp
        camera_frame = subject_image_list.header.frame_id

        subjects_dict = self.subjects_bridge.msg_to_images(subject_image_list)
        input_r_list = []
        input_l_list = []
        input_head_list = []
        valid_subject_list = []
        for subject_id, s in subjects_dict.items():
            try:
                transform_msg = self.tf2_buffer.lookup_transform(
                    camera_frame, self.headpose_frame + str(subject_id),
                    timestamp)
                rot_head = transform_msg.transform.rotation
                # build a rotation matrix from the head pose quaternion and
                # (presumably) rotate it from the camera's optical frame into
                # the ROS frame convention before extracting Euler angles
                _m = transformations.quaternion_matrix(
                    [rot_head.x, rot_head.y, rot_head.z, rot_head.w])
                euler_angles_head = list(
                    transformations.euler_from_matrix(
                        np.dot(ros_tools.camera_to_ros, _m)))

                euler_angles_head = gaze_tools.limit_yaw(euler_angles_head)

                phi_head, theta_head = gaze_tools.get_phi_theta_from_euler(
                    euler_angles_head)
                input_head_list.append([theta_head, phi_head])
                input_r_list.append(
                    self._gaze_estimator.input_from_image(s.right))
                input_l_list.append(
                    self._gaze_estimator.input_from_image(s.left))
                valid_subject_list.append(subject_id)
            except (tf2_ros.LookupException, tf2_ros.ConnectivityException,
                    tf2_ros.ExtrapolationException,
                    tf2_ros.TransformException):
                pass

        if len(valid_subject_list) == 0:
            return

        gaze_est = self._gaze_estimator.estimate_gaze_twoeyes(
            inference_input_left_list=input_l_list,
            inference_input_right_list=input_r_list,
            inference_headpose_list=input_head_list)
        subject_subset = {k: subjects_dict[k] for k in valid_subject_list}
        self.publish_gaze_msg(subject_image_list.header, subject_subset,
                              gaze_est.tolist())

        subjects_gaze_img_list = []
        for subject_id, gaze in zip(valid_subject_list, gaze_est.tolist()):
            subjects_dict[subject_id].gaze = gaze
            self.publish_gaze(gaze, timestamp, subject_id)

            if self.visualise_eyepose:
                s = subjects_dict[subject_id]
                r_gaze_img = self._gaze_estimator.visualize_eye_result(
                    s.right, gaze)
                l_gaze_img = self._gaze_estimator.visualize_eye_result(
                    s.left, gaze)
                s_gaze_img = np.concatenate((r_gaze_img, l_gaze_img), axis=1)
                subjects_gaze_img_list.append(s_gaze_img)

        if len(subjects_gaze_img_list) > 0:
            gaze_img_msg = self.bridge.cv2_to_imgmsg(
                np.hstack(subjects_gaze_img_list).astype(np.uint8), "bgr8")
            gaze_img_msg.header.stamp = timestamp
            self.subjects_gaze_img.publish(gaze_img_msg)

        _now = rospy.Time().now()
        _freq = 1.0 / (_now - self._last_time).to_sec()
        self._freq_deque.append(_freq)
        self._latency_deque.append(_now.to_sec() - timestamp.to_sec())
        self._last_time = _now
        tqdm.write(
            '\033[2K\033[1;32mTime now: {:.2f} message color: {:.2f} latency: {:.2f}s for {} subject(s) {:.0f}Hz\033[0m'
            .format((_now.to_sec()), timestamp.to_sec(),
                    np.mean(self._latency_deque), len(valid_subject_list),
                    np.mean(self._freq_deque)),
            end="\r")

    def publish_gaze_msg(self, header, subjects, gazes):
        gaze_msg_list = MSG_GazeList()
        gaze_msg_list.header = header
        for subject_id, gaze in zip(subjects.keys(), gazes):
            gaze_msg = MSG_Gaze()
            gaze_msg.subject_id = subject_id
            gaze_msg.theta = gaze[0]
            gaze_msg.phi = gaze[1]
            gaze_msg_list.subjects.append(gaze_msg)

        self.gaze_publishers.publish(gaze_msg_list)

    def publish_gaze(self, est_gaze, msg_stamp, subject_id):
        """Publish the gaze vector as a PointStamped."""
        theta_gaze = est_gaze[0]
        phi_gaze = est_gaze[1]
        euler_angle_gaze = gaze_tools.get_euler_from_phi_theta(
            phi_gaze, theta_gaze)
        quaternion_gaze = transformations.quaternion_from_euler(
            *euler_angle_gaze)

        t = TransformStamped()
        t.header.stamp = msg_stamp
        t.header.frame_id = self.headpose_frame + str(subject_id)
        t.child_frame_id = self.tf_prefix + "/world_gaze" + str(subject_id)
        t.transform.translation.x = 0
        t.transform.translation.y = 0
        t.transform.translation.z = 0.05  # publish it 5cm above the head pose's origin (nose tip)
        t.transform.rotation.x = quaternion_gaze[0]
        t.transform.rotation.y = quaternion_gaze[1]
        t.transform.rotation.z = quaternion_gaze[2]
        t.transform.rotation.w = quaternion_gaze[3]

        try:
            self.tf2_broadcaster.sendTransform([t])
        except rospy.ROSException as exc:
            if str(exc) == "publish() to a closed topic":
                pass
            else:
                raise exc
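
Example #2 publishes per-subject gaze angles on /subjects/gaze as a MSG_GazeList (see publish_gaze_msg above). A minimal consumer sketch, assuming only the fields that publish_gaze_msg fills in; the import path is an assumption:

import rospy
from rt_gene.msg import MSG_GazeList  # assumed import path

def on_gaze(msg):
    # each entry carries the subject id and the estimated (theta, phi) angles
    for subject in msg.subjects:
        print("subject %s: theta=%.3f phi=%.3f"
              % (subject.subject_id, subject.theta, subject.phi))

rospy.init_node("gaze_listener")  # hypothetical node name
rospy.Subscriber("/subjects/gaze", MSG_GazeList, on_gaze, queue_size=1)
rospy.spin()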
Example #3
class GazeEstimator(object):
    """This class encapsulates a deep neural network for gaze estimation.

    It retrieves two image streams, one containing the left eye and another containing the right eye.
    It synchronizes these two images with the estimated head pose.
    The images are then converted into a suitable format, and a forward pass (one per eye) of the deep neural network
    results in the estimated gaze for this frame. The estimated gaze is then published in the (theta, phi) notation.
    Additionally, two images with the gaze overlaid on the eye images are published."""
    def __init__(self):
        self.image_height = rospy.get_param("~image_height", 36)
        self.image_width = rospy.get_param("~image_width", 60)
        self.bridge = CvBridge()
        self.subjects_bridge = SubjectListBridge()

        self.tf_broadcaster = TransformBroadcaster()
        self.tf_listener = TransformListener()

        self.use_last_headpose = rospy.get_param("~use_last_headpose", True)
        self.tf_prefix = rospy.get_param("~tf_prefix", "gaze")
        self.last_phi_head, self.last_theta_head = None, None

        self.rgb_frame_id_ros = rospy.get_param("~rgb_frame_id_ros",
                                                "/kinect2_nonrotated_link")

        self.headpose_frame = self.tf_prefix + "/head_pose_estimated"

        config = tensorflow.ConfigProto(inter_op_parallelism_threads=1,
                                        intra_op_parallelism_threads=1)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
        config.log_device_placement = False
        self.sess = tensorflow.Session(config=config)
        set_session(self.sess)

        model_files = rospy.get_param("~model_files")

        self.models = []
        for model_file in model_files:
            tqdm.write('Load model ' + model_file)
            model = load_model(os.path.join(
                rospkg.RosPack().get_path('rt_gene'), model_file),
                               custom_objects={
                                   'accuracy_angle': accuracy_angle,
                                   'angle_loss': angle_loss
                               })
            # noinspection PyProtectedMember
            model._make_predict_function()  # must be initialized before threading
            self.models.append(model)
        tqdm.write('Loaded ' + str(len(self.models)) + ' models')

        self.graph = tensorflow.get_default_graph()

        self.image_subscriber = rospy.Subscriber('/subjects/images',
                                                 MSG_SubjectImagesList,
                                                 self.image_callback,
                                                 queue_size=3)
        self.subjects_gaze_img = rospy.Publisher('/subjects/gazeimages',
                                                 Image,
                                                 queue_size=3)

        # moving-average weights over the last five gaze estimates; the newest
        # sample (appended last) gets the largest weight
        self.average_weights = np.array([0.1, 0.125, 0.175, 0.2, 0.4])
        self.gaze_buffer_c = {}

    def __del__(self):
        if self.sess is not None:
            self.sess.close()

    def estimate_gaze_twoeyes(self, test_input_left, test_input_right,
                              headpose):
        test_headpose = headpose.reshape(1, 2)
        with self.graph.as_default():
            predictions = []
            for model in self.models:
                predictions.append(
                    model.predict({
                        'img_input_L': test_input_left,
                        'img_input_R': test_input_right,
                        'headpose_input': test_headpose
                    })[0])
            mean_prediction = np.mean(np.array(predictions), axis=0)
            if len(self.models) == 1:
                # only apply the offset for a single model, not for ensembles
                mean_prediction[1] += 0.11
            return mean_prediction

    def visualize_eye_result(self, eye_image, est_gaze):
        """Here, we take the original eye eye_image and overlay the estimated gaze."""
        output_image = np.copy(eye_image)

        center_x = self.image_width / 2
        center_y = self.image_height / 2

        endpoint_x, endpoint_y = gaze_tools.get_endpoint(
            est_gaze[0], est_gaze[1], center_x, center_y, 50)

        cv2.line(output_image, (int(center_x), int(center_y)),
                 (int(endpoint_x), int(endpoint_y)), (255, 0, 0))
        return output_image

    def publish_image(self, image, image_publisher, timestamp):
        """This image publishes the `image` to the `image_publisher` with the given `timestamp`."""
        image_ros = self.bridge.cv2_to_imgmsg(image, "rgb8")
        image_ros.header.stamp = timestamp
        image_publisher.publish(image_ros)

    def input_from_image(self, eye_img_msg, flip=False):
        """This method converts an eye_img_msg provided by the landmark estimator, and converts it to a format
        suitable for the gaze network."""
        cv_image = eye_img_msg
        #cv_image = self.bridge.imgmsg_to_cv2(eye_img_msg, "rgb8")
        if flip:
            cv_image = cv2.flip(cv_image, 1)
        currimg = cv_image.reshape(self.image_height,
                                   self.image_width,
                                   3,
                                   order='F')
        currimg = currimg.astype(np.float32)
        testimg = np.zeros((1, self.image_height, self.image_width, 3))
        # assumption: these constants are the standard Caffe/VGG per-channel
        # means, i.e. the same normalization the network was trained with
        testimg[0, :, :, 0] = currimg[:, :, 0] - 103.939
        testimg[0, :, :, 1] = currimg[:, :, 1] - 116.779
        testimg[0, :, :, 2] = currimg[:, :, 2] - 123.68
        return testimg

    def compute_eye_gaze_estimation(self, subject_id, timestamp, input_r,
                                    input_l):
        """
        subject_id : integer,  id of the subject
        input_x    : cv_image, input image of x eye
        (phi_x)    : double,   phi angle estimated using pupil detection
        (theta_x)  : double,   theta angle estimated using pupil detection
        """
        try:
            lct = self.tf_listener.getLatestCommonTime(
                self.rgb_frame_id_ros, self.headpose_frame + str(subject_id))
            if (timestamp - lct).to_sec() < 0.25:
                tqdm.write('Time diff: ' + str((timestamp - lct).to_sec()))

                (trans_head, rot_head) = self.tf_listener.lookupTransform(
                    self.rgb_frame_id_ros,
                    self.headpose_frame + str(subject_id), lct)
                euler_angles_head = gaze_tools.get_head_pose(
                    trans_head, rot_head)

                phi_head, theta_head = gaze_tools.get_phi_theta_from_euler(
                    euler_angles_head)
                self.last_phi_head, self.last_theta_head = phi_head, theta_head
            else:
                if self.use_last_headpose and self.last_phi_head is not None:
                    tqdm.write('Big time diff, use last known headpose! ' +
                               str((timestamp - lct).to_sec()))
                    phi_head, theta_head = self.last_phi_head, self.last_theta_head
                else:
                    tqdm.write(
                        'Too big time diff for head pose, do not estimate gaze! '
                        + str((timestamp - lct).to_sec()))
                    return

            start_time = time.time()

            est_gaze_c = self.estimate_gaze_twoeyes(
                input_l, input_r, np.array([theta_head, phi_head]))

            self.gaze_buffer_c[subject_id].append(est_gaze_c)

            if len(self.gaze_buffer_c[subject_id]) == len(self.average_weights):
                est_gaze_c_med = np.average(
                    np.array(self.gaze_buffer_c[subject_id]),
                    axis=0,
                    weights=self.average_weights)
                self.publish_gaze(est_gaze_c_med, timestamp, subject_id)
                tqdm.write('est_gaze_c: ' + str(est_gaze_c_med))
                return est_gaze_c_med

            tqdm.write('Elapsed: ' + str(time.time() - start_time))
        except (tf.LookupException, tf.ConnectivityException,
                tf.ExtrapolationException, tf.Exception) as tf_e:
            print(tf_e)
        except rospy.ROSException as ros_e:
            if str(ros_e) == "publish() to a closed topic":
                print("See ya")
        return None

    def image_callback(self, subject_image_list, masked_list=None):
        """This method is called whenever new input arrives. The input is first converted in a format suitable
        for the gaze estimation network (see :meth:`input_from_image`), then the gaze is estimated (see
        :meth:`estimate_gaze`. The estimated gaze is overlaid on the input image (see :meth:`visualize_eye_result`),
        and this image is published along with the estimated gaze vector (see :meth:`publish_image` and
        :func:`publish_gaze`)"""
        timestamp = subject_image_list.header.stamp
        subjects_gaze_img = None

        subjects_dict = self.subjects_bridge.msg_to_images(subject_image_list)
        for subject_id, s in subjects_dict.items():
            if subject_id not in self.gaze_buffer_c.keys():
                self.gaze_buffer_c[subject_id] = collections.deque(maxlen=5)

            input_r = self.input_from_image(s.right, flip=False)
            input_l = self.input_from_image(s.left, flip=False)
            gaze_est = self.compute_eye_gaze_estimation(
                subject_id, timestamp, input_r, input_l)

            if gaze_est is not None:
                r_gaze_img = self.visualize_eye_result(s.right, gaze_est)
                l_gaze_img = self.visualize_eye_result(s.left, gaze_est)
                s_gaze_img = np.concatenate((r_gaze_img, l_gaze_img), axis=1)
                if subjects_gaze_img is None:
                    subjects_gaze_img = s_gaze_img
                else:
                    subjects_gaze_img = np.concatenate(
                        (subjects_gaze_img, s_gaze_img), axis=0)

        if subjects_gaze_img is not None:
            gaze_img_msg = self.bridge.cv2_to_imgmsg(
                subjects_gaze_img.astype(np.uint8), "bgr8")
            self.subjects_gaze_img.publish(gaze_img_msg)

    def publish_gaze(self, est_gaze, msg_stamp, subject_id):
        """Publish the gaze vector as a PointStamped."""
        theta_gaze = est_gaze[0]
        phi_gaze = est_gaze[1]
        euler_angle_gaze = gaze_tools.get_euler_from_phi_theta(
            phi_gaze, theta_gaze)
        quaternion_gaze = tf.transformations.quaternion_from_euler(
            *euler_angle_gaze)
        self.tf_broadcaster.sendTransform(
            (0, 0, 0.05),  # publish it 5cm above the head pose's origin (nose tip)
            quaternion_gaze,
            msg_stamp,
            self.tf_prefix + "/world_gaze" + str(subject_id),
            self.headpose_frame + str(subject_id))
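
All of the gaze examples publish gaze in (theta, phi) notation. Turning those angles into a 3D ray requires a convention; the sketch below uses a common one from the appearance-based gaze literature (theta as pitch, phi as yaw), which is not necessarily the exact convention gaze_tools uses internally:

import numpy as np

def gaze_vector_from_theta_phi(theta, phi):
    """Unit vector pointing along the gaze under the stated convention."""
    return np.array([-np.cos(theta) * np.sin(phi),
                     -np.sin(theta),
                     -np.cos(theta) * np.cos(phi)])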
Example #4
class BlinkEstimatorROS(object):
    def __init__(self, device_id_blink, model_files, threshold):
        self.cv_bridge = CvBridge()
        self.bridge = SubjectListBridge()
        self.viz = rospy.get_param("~viz", True)

        blink_backend = rospy.get_param("~blink_backend", default="pytorch")
        model_type = rospy.get_param("~model_type", default="resnet18")

        if blink_backend == "tensorflow":
            from rt_bene.estimate_blink_tensorflow import BlinkEstimatorTensorflow
            self._blink_estimator = BlinkEstimatorTensorflow(device_id_blink, model_files, model_type, threshold)
        elif blink_backend == "pytorch":
            from rt_bene.estimate_blink_pytorch import BlinkEstimatorPytorch
            self._blink_estimator = BlinkEstimatorPytorch(device_id_blink, model_files, model_type, threshold)
        else:
            raise ValueError("Incorrect gaze_base backend, choices are: tensorflow or pytorch")

        self._last_time = rospy.Time().now()
        self._freq_deque = collections.deque(maxlen=30)  # average frequency statistic over roughly one second
        self._latency_deque = collections.deque(maxlen=30)

        self.blink_publisher = rospy.Publisher("/subjects/blink", MSG_BlinkList, queue_size=3)
        if self.viz:
            self.viz_pub = rospy.Publisher(rospy.get_param("~viz_topic", "/subjects/blink_images"), Image, queue_size=3)

        self.sub = rospy.Subscriber("/subjects/images", MSG_SubjectImagesList, self.callback, queue_size=1,
                                    buff_size=2 ** 24)

    def callback(self, msg):
        subjects = self.bridge.msg_to_images(msg)
        left_eyes = []
        right_eyes = []

        for subject in subjects.values():
            _left, _right = self._blink_estimator.inputs_from_images(subject.left, subject.right)
            left_eyes.append(_left)
            right_eyes.append(_right)

        if len(left_eyes) == 0:
            return

        probs = self._blink_estimator.predict(left_eyes, right_eyes)

        self.publish_msg(msg.header, subjects, probs)

        if self.viz:
            blink_image_list = []
            for subject, p in zip(subjects.values(), probs):
                resized_face = cv2.resize(subject.face, dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
                blink_image_list.append(self._blink_estimator.overlay_prediction_over_img(resized_face, p))

            if len(blink_image_list) > 0:
                blink_viz_img = self.cv_bridge.cv2_to_imgmsg(np.hstack(blink_image_list), encoding="bgr8")
                blink_viz_img.header.stamp = msg.header.stamp
                self.viz_pub.publish(blink_viz_img)

        _now = rospy.Time().now()
        timestamp = msg.header.stamp

        _freq = 1.0 / (_now - self._last_time).to_sec()
        self._freq_deque.append(_freq)
        self._latency_deque.append(_now.to_sec() - timestamp.to_sec())
        self._last_time = _now
        tqdm.write(
            '\033[2K\033[1;32mTime now: {:.2f} message color: {:.2f} latency: {:.2f}s for {} subject(s) {:.0f}Hz\033[0m'.format(
                (_now.to_sec()), timestamp.to_sec(), np.mean(self._latency_deque), len(subjects),
                np.mean(self._freq_deque)), end="\r")

    def publish_msg(self, header, subjects, probabilities):
        blink_msg_list = MSG_BlinkList()
        blink_msg_list.header = header
        for subject_id, p in zip(subjects.keys(), probabilities):
            blink_msg = MSG_Blink()
            blink_msg.subject_id = str(subject_id)
            blink_msg.blink = bool(p >= self._blink_estimator.threshold)
            blink_msg.probability = p
            blink_msg_list.subjects.append(blink_msg)

        self.blink_publisher.publish(blink_msg_list)
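
publish_msg emits a MSG_BlinkList whose entries carry subject_id, the thresholded blink flag, and the raw probability. A matching consumer sketch, with the same caveat as the gaze listener above that the import path is an assumption:

import rospy
from rt_gene.msg import MSG_BlinkList  # assumed import path

def on_blink(msg):
    for subject in msg.subjects:
        state = "closed" if subject.blink else "open"
        print("subject %s: eyes %s (p=%.2f)"
              % (subject.subject_id, state, subject.probability))

rospy.init_node("blink_listener")  # hypothetical node name
rospy.Subscriber("/subjects/blink", MSG_BlinkList, on_blink, queue_size=1)
rospy.spin()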
Example #5
class BlinkEstimatorNode(BlinkEstimatorBase):
    def __init__(self, device_id_blink, model_files, threshold):
        super(BlinkEstimatorNode, self).__init__(device_id_blink, model_files,
                                                 threshold, (96, 96))
        self.sub = rospy.Subscriber("/subjects/images",
                                    MSG_SubjectImagesList,
                                    self.callback,
                                    queue_size=1,
                                    buff_size=2**24)
        self.cv_bridge = CvBridge()
        self.bridge = SubjectListBridge()
        self.viz = rospy.get_param("~viz", True)

        self._last_time = rospy.Time().now()
        self._freq_deque = collections.deque(
            maxlen=30)  # average frequency statistic over roughly one second
        self._latency_deque = collections.deque(maxlen=30)

        self.blink_publisher = rospy.Publisher("/subjects/blink",
                                               MSG_BlinkList,
                                               queue_size=3)
        if self.viz:
            self.viz_pub = rospy.Publisher(rospy.get_param(
                "~viz_topic", "/subjects/blink_images"),
                                           Image,
                                           queue_size=3)

    def callback(self, msg):
        subjects = self.bridge.msg_to_images(msg)
        left_eyes = []
        right_eyes = []

        for subject in subjects.values():
            left_eyes.append(self.resize_img(subject.left))
            right_eyes.append(cv2.flip(self.resize_img(subject.right), 1))

        if len(left_eyes) == 0:
            return

        probs, _ = self.predict(left_eyes, right_eyes)

        self.publish_msg(msg.header.stamp, subjects, probs)

        if self.viz:
            blink_image_list = []
            for subject, p in zip(subjects.values(), probs):
                resized_face = cv2.resize(subject.face,
                                          dsize=(224, 224),
                                          interpolation=cv2.INTER_CUBIC)
                blink_image_list.append(
                    self.overlay_prediction_over_img(resized_face, p))

            if len(blink_image_list) > 0:
                blink_viz_img = self.cv_bridge.cv2_to_imgmsg(
                    np.hstack(blink_image_list), encoding="bgr8")
                blink_viz_img.header.stamp = msg.header.stamp
                self.viz_pub.publish(blink_viz_img)

        _now = rospy.Time().now()
        timestamp = msg.header.stamp
        _freq = 1.0 / (_now - self._last_time).to_sec()
        self._freq_deque.append(_freq)
        self._latency_deque.append(_now.to_sec() - timestamp.to_sec())
        self._last_time = _now
        tqdm.write(
            '\033[2K\033[1;32mTime now: {:.2f} message color: {:.2f} latency: {:.2f}s for {} subject(s) {:.0f}Hz\033[0m'
            .format((_now.to_sec()), timestamp.to_sec(),
                    np.mean(self._latency_deque), len(subjects),
                    np.mean(self._freq_deque)),
            end="\r")

    def publish_msg(self, timestamp, subjects, probabilities):
        blink_msg_list = MSG_BlinkList()
        blink_msg_list.header.stamp = timestamp
        blink_msg_list.header.frame_id = '0'
        for subject_id, p in zip(subjects.keys(), probabilities):
            blink_msg = MSG_Blink()
            blink_msg.subject_id = str(subject_id)
            blink_msg.blink = p >= self.threshold
            blink_msg.probability = p
            blink_msg_list.subjects.append(blink_msg)

        self.blink_publisher.publish(blink_msg_list)
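
Examples #2, #4 and #5 share the same performance bookkeeping: two 30-sample deques smooth the instantaneous frame rate and the message latency before they are printed. A standalone sketch of what that computes, outside ROS, with made-up timestamps:

import collections

import numpy as np

freq_deque = collections.deque(maxlen=30)
latency_deque = collections.deque(maxlen=30)
last_time = 0.0
# (now, stamp) pairs stand in for rospy.Time().now() and the message stamp
for now, stamp in [(0.10, 0.06), (0.21, 0.17), (0.31, 0.27)]:
    freq_deque.append(1.0 / (now - last_time))  # instantaneous Hz
    latency_deque.append(now - stamp)           # end-to-end delay in seconds
    last_time = now
print(np.mean(freq_deque), np.mean(latency_deque))  # smoothed Hz and latency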