class PoseObservationPublisher(SimulatorPlugin):
    """Publishes noisy pose observations of simulated bodies that a camera link can see.

    Once every ``1 / frequency`` (accumulated ``deltaT``), every body in the simulator
    other than ``multibody`` is tested for visibility. For ``MultiBody`` instances each
    link is tested separately under the name ``'<body name>/<link name>'``. A pose at
    distance ``d`` from the camera link is visible when ``near < d < far`` and it lies
    inside the cone of half-angle ``fov / 2`` around the camera link's x-axis.
    Each visible pose is perturbed by random rotation/translation noise with spread
    ``2 ** (noise_exp * d) - 1``. All visible poses of one tick are published together
    as a single message on ``'<topic_prefix>/pose_obs'``.
    """

    def __init__(self, multibody, camera_link, fov, near, far, noise_exp, topic_prefix='', frequency=30, debug=False):
        """
        :param multibody: Body that carries the camera. It is never observed itself.
        :param camera_link: Link of ``multibody`` whose world frame is used as camera frame.
        :param fov: Full opening angle of the view cone. Half of it is passed to
                    ``gm.cos``, so presumably radians.
        :param near: Exclusive lower bound on the observation distance.
        :param far: Exclusive upper bound on the observation distance.
        :param noise_exp: Growth rate of the noise with distance.
        :param topic_prefix: Prefix of the published ``pose_obs`` topic.
        :param frequency: Number of publications per unit of accumulated ``deltaT``.
        :param debug: If true, the noisy poses are also drawn with a ``ROSVisualizer``.
        """
        super(PoseObservationPublisher, self).__init__('PoseObservationPublisher')
        self.topic_prefix = topic_prefix
        self.publisher = rospy.Publisher('{}/pose_obs'.format(topic_prefix), PSAMsg, queue_size=1, tcp_nodelay=True)
        # One reusable PoseStampedMsg per observed name; its header.frame_id is set once.
        self.message_templates = {}
        self.multibody   = multibody
        self.camera_link = camera_link
        self.fov         = fov
        self.near        = near
        self.far         = far
        self.noise_exp   = noise_exp
        # Kept so that to_dict()/factory() round-trip the publication rate.
        self.frequency   = frequency
        self._enabled    = True
        self.visualizer  = ROSVisualizer('pose_obs_viz', 'world') if debug else None
        # Start far above the wait time so that the very first physics step publishes.
        self._last_update = 1000
        self._update_wait = 1.0 / frequency


    def post_physics_update(self, simulator, deltaT):
        """Implements post physics step behavior.

        Accumulates ``deltaT``. Once at least ``1 / frequency`` has passed, it publishes
        one message with the noisy poses of all currently visible objects.

        :type simulator: BasicSimulator
        :type deltaT: float
        """
        if not self._enabled:
            return

        self._last_update += deltaT
        if self._last_update < self._update_wait:
            return
        self._last_update = 0

        cf_tuple = self.multibody.get_link_state(self.camera_link).worldFrame
        camera_frame = gm.frame3_quaternion(cf_tuple.position.x, cf_tuple.position.y, cf_tuple.position.z, *cf_tuple.quaternion)
        # Loop invariants of the visibility test.
        camera_pos   = gm.pos_of(camera_frame)
        camera_axis  = gm.x_of(camera_frame)
        cos_half_fov = gm.cos(self.fov * 0.5)

        out = PSAMsg()
        poses = []  # noisy poses as matrices, only used for debug drawing

        if self.visualizer is not None:
            self.visualizer.begin_draw_cycle()

        for name, body in simulator.bodies.items():
            if body == self.multibody:
                continue

            if isinstance(body, MultiBody):
                poses_to_process = [('{}/{}'.format(name, l), body.get_link_state(l).worldFrame) for l in body.links]
            else:
                poses_to_process = [(name, body.pose())]

            for pname, pose in poses_to_process:
                msg = self.message_templates.get(pname)
                if msg is None:
                    msg = PoseStampedMsg()
                    msg.header.frame_id = pname
                    self.message_templates[pname] = msg

                c2o  = gm.point3(*pose.position) - camera_pos
                dist = gm.norm(c2o)
                # Visible: strictly between near and far, and the projection of c2o onto the
                # camera x-axis exceeds cos(fov / 2) * |c2o| (i.e. inside the view cone).
                if not (dist < self.far and dist > self.near and gm.dot_product(c2o, camera_axis) > cos_half_fov * dist):
                    continue

                noise = 2 ** (self.noise_exp * dist) - 1
                (n_quat, )  = np_random_quat_normal(1, 0, noise)
                (n_trans, ) = np_random_normal_offset(1, 0, noise)

                # The noise transform is right-multiplied, i.e. applied in the object's frame.
                n_pose = pb.Transform(pb.Quaternion(*pose.quaternion), pb.Vector3(*pose.position)) *\
                         pb.Transform(pb.Quaternion(*n_quat), pb.Vector3(*n_trans[:3]))

                if self.visualizer is not None:
                    poses.append(transform_to_matrix(n_pose))
                msg.pose.position.x = n_pose.origin.x
                msg.pose.position.y = n_pose.origin.y
                msg.pose.position.z = n_pose.origin.z
                msg.pose.orientation.x = n_pose.rotation.x
                msg.pose.orientation.y = n_pose.rotation.y
                msg.pose.orientation.z = n_pose.rotation.z
                msg.pose.orientation.w = n_pose.rotation.w
                out.poses.append(msg)

        # Publish exactly once per tick, after all bodies were checked
        # (previously this ran once per body with a partially filled message).
        self.publisher.publish(out)

        if self.visualizer is not None:
            self.visualizer.draw_poses('debug', gm.se.eye(4), 0.1, 0.02, poses)
            self.visualizer.render()


    def disable(self, simulator):
        """Stops the execution of this plugin and unregisters its publisher.
        :type simulator: BasicSimulator
        """
        self._enabled = False
        self.publisher.unregister()


    def to_dict(self, simulator):
        """Serializes this plugin to a dictionary that factory() can consume.
        :type simulator: BasicSimulator
        :rtype: dict
        """
        return {'body':         simulator.get_body_id(self.multibody.bId()),
                'camera_link':  self.camera_link,
                'fov':          self.fov,
                'near':         self.near,
                'far':          self.far,
                'noise_exp':    self.noise_exp,
                'topic_prefix': self.topic_prefix,
                'frequency':    self.frequency}

    @classmethod
    def factory(cls, simulator, init_dict):
        """Instantiates the plugin from a dictionary produced by to_dict().

        ``frequency`` is optional so that dictionaries written before it was
        serialized can still be loaded.

        :raises Exception: If the referenced body does not exist in ``simulator``.
        """
        body = simulator.get_body(init_dict['body'])
        if body is None:
            raise Exception('Body "{}" does not exist in the context of the given simulation.'.format(init_dict['body']))
        return cls(body,
                   init_dict['camera_link'],
                   init_dict['fov'],
                   init_dict['near'],
                   init_dict['far'],
                   init_dict['noise_exp'],
                   init_dict['topic_prefix'],
                   frequency=init_dict.get('frequency', 30))


    def reset(self, simulator):
        """Nothing to reset for this plugin."""
        pass
# Example #2
# 0
    # NOTE(review): excerpt of a larger script/function whose header is not visible here.
    # Visualizer that draws in the 'world' frame.
    vis = ROSVisualizer('axis_vis', 'world')

    # NOTE(review): the names are crossed -- `az` is bound to the symbol 'ax' and `ay`
    # to the symbol 'az'. Only the Python names are used below, so this is internally
    # consistent, but confirm which axes were intended.
    az, ay = [Position(x) for x in 'ax az'.split(' ')]
    # Symbolic frame at the origin with a zero first angle and symbolic second/third
    # angles; presumably frame3_rpy(roll, pitch, yaw, position) -- TODO confirm.
    frame_rpy = frame3_rpy(0, ay, az, point3(0,0,0))

    # Values substituted for the two symbols when evaluating the frame.
    state = {ay: 0, az: 0}

    # Sample get_rot_vector() of the frame for (ay, az) = (sin v, cos v) at 51 values
    # of v spanning [0, 2 * 3.14512].
    # NOTE(review): 3.14512 looks like a typo for pi (3.14159...); confirm.
    points = [point3(0,0,0) + get_rot_vector(frame_rpy.subs({ay: sin(v), az: cos(v)})) for v in [(3.14512 / 25) * x for x in range(51)]]

    # Draw the sampled vectors once as a strip in the 'points' namespace.
    vis.begin_draw_cycle('points')
    vis.draw_strip('points', se.eye(4), 0.03, points)
    vis.render('points')

    rospy.sleep(1)

    # NOTE(review): `timer` is never reassigned inside the loop, so after the first
    # 0.02 s the throttle condition below is always true and the loop redraws on every
    # iteration. If Time is rospy.Time, Time() is time zero and the condition holds
    # from the start. Presumably `timer = now` was intended after each redraw.
    timer = Time()
    while not rospy.is_shutdown():
        now = Time.now()

        if (now - timer).to_sec() >= 0.02:
            # Move both angles along a circle driven by the current ROS time.
            state[ay] = sin(now.to_sec())
            state[az] = cos(now.to_sec())

            frame = frame_rpy.subs(state)
            axis  = get_rot_vector(frame)

            # Draw the current frame and the get_rot_vector() result at its origin.
            vis.begin_draw_cycle('visuals')
            vis.draw_poses('visuals', se.eye(4), 0.4, 0.02, [frame])
            vis.draw_vector('visuals', pos_of(frame), axis)
            vis.render('visuals')
# Example #3
# 0
class ROSQPEManager(object):
    """ROS front end for a ``Kineverse6DQPTracker``.

    Subscribes to ``~observations`` (transform arrays) and ``~controls`` (value maps)
    and feeds both into the tracker. A timer running at ``update_freq`` Hz publishes
    the tracker's estimated state on ``~state_estimate`` whenever at least one new
    observation has been processed since the previous publication. If a model is
    given, the estimate is also broadcast through ``ModelTFBroadcaster_URDF``.
    """

    def __init__(self,
                 tracker: Kineverse6DQPTracker,
                 model=None,
                 model_path=None,
                 reference_frame='world',
                 urdf_param='/qp_description_check',
                 update_freq=30,
                 observation_alias=None):
        """
        :param tracker: Tracker that consumes observations and controls.
        :param model: Optional model. When given, the estimated state is broadcast.
        :param model_path: Path of the model, forwarded to the broadcaster.
        :param reference_frame: Frame that all observations are expressed in before processing.
        :param urdf_param: Parameter name forwarded to the broadcaster.
        :param update_freq: Rate of the state-publishing timer in Hz.
        :param observation_alias: Optional mapping ``{observation_name: alias}``. Incoming
                                  transforms whose ``child_frame_id`` equals ``alias`` are
                                  treated as ``observation_name``. Entries whose key is not
                                  already a known name are ignored.
        """
        self.tracker = tracker
        # Number of successfully processed observation messages. cb_update compares it
        # against last_update so that it only publishes after new observations.
        self.last_observation = 0
        self.last_update = 0
        self.last_control = {}
        self.reference_frame = reference_frame
        # Maps an incoming child_frame_id to the tracker's observation name.
        self.observation_aliases = {o: o for o in tracker.observation_names}
        if observation_alias is not None:
            for path, alias in observation_alias.items():
                if path in self.observation_aliases:
                    self.observation_aliases[alias] = path

        self.str_controls = {str(s) for s in self.tracker.get_controls()}

        self.vis = ROSVisualizer('~vis', reference_frame)

        if model is not None:
            self.broadcaster = ModelTFBroadcaster_URDF(urdf_param, model_path,
                                                       model, Path('ekf'))
        else:
            self.broadcaster = None

        self.tf_buffer = tf2_ros.Buffer()
        self.listener = tf2_ros.TransformListener(self.tf_buffer)

        self.pub_state = rospy.Publisher('~state_estimate',
                                         ValueMapMsg,
                                         queue_size=1,
                                         tcp_nodelay=True)
        self.sub_observations = rospy.Subscriber('~observations',
                                                 TransformStampedArrayMsg,
                                                 callback=self.cb_obs,
                                                 queue_size=1)
        self.sub_controls = rospy.Subscriber('~controls',
                                             ValueMapMsg,
                                             callback=self.cb_controls,
                                             queue_size=1)

        self.timer = rospy.Timer(rospy.Duration(1 / update_freq),
                                 self.cb_update)

    @staticmethod
    def _transform_to_matrix(transform):
        """Converts a transform message (``translation``/``rotation`` fields) via ``np_frame3_quaternion``."""
        return np_frame3_quaternion(transform.translation.x,
                                    transform.translation.y,
                                    transform.translation.z,
                                    transform.rotation.x,
                                    transform.rotation.y,
                                    transform.rotation.z,
                                    transform.rotation.w)

    def _lookup_reference_transform(self, frame_id, cache):
        """Returns the TF transform from ``frame_id`` into ``reference_frame`` as a matrix.

        Successful lookups are memoized in ``cache``. On a TF lookup, connectivity or
        extrapolation error the exception is printed and ``None`` is returned.
        """
        if frame_id in cache:
            return cache[frame_id]
        try:
            ref_trans = self.tf_buffer.lookup_transform(
                self.reference_frame, frame_id, rospy.Time(0))
        except (tf2_ros.LookupException,
                tf2_ros.ConnectivityException,
                tf2_ros.ExtrapolationException) as e:
            print(
                f'Exception raised while looking up {frame_id} -> {self.reference_frame}:\n{e}'
            )
            return None
        cache[frame_id] = self._transform_to_matrix(ref_trans.transform)
        return cache[frame_id]

    def cb_obs(self, transform_stamped_array_msg):
        """Processes one array of observed transforms.

        Transforms with an unknown ``child_frame_id`` are ignored. The rest are expressed
        in ``reference_frame`` (via TF when their header frame differs) and handed to the
        tracker as a single observation.

        - If a TF lookup fails, the whole message is dropped (not processed), but the
          poses collected so far are still drawn.
        - If no known frame was present, nothing happens.
        - If the solver raises ``QPSolverException``, the message is skipped without drawing.
        """
        ref_frames = {}
        observation = {}
        lookup_failed = False

        for trans in transform_stamped_array_msg.transforms:
            if trans.child_frame_id not in self.observation_aliases:
                continue

            matrix = self._transform_to_matrix(trans.transform)

            if trans.header.frame_id != self.reference_frame:
                np_ref_trans = self._lookup_reference_transform(
                    trans.header.frame_id, ref_frames)
                if np_ref_trans is None:
                    # Do not process a partial observation.
                    lookup_failed = True
                    break
                matrix = np_ref_trans.dot(matrix)

            observation[self.observation_aliases[trans.child_frame_id]] = matrix

        if not lookup_failed:
            if not observation:
                return
            try:
                self.tracker.process_observation(observation)
            except QPSolverException as e:
                print(
                    f'Solver crashed during observation update. Skipping observation... Error:\n{e}'
                )
                return
            self.last_observation += 1

        self.vis.begin_draw_cycle('observations')
        self.vis.draw_poses('observations', np.eye(4), 0.2, 0.01,
                            observation.values())
        self.vis.render('observations')

    def cb_controls(self, map_message):
        """Forwards the values of known control symbols from ``map_message`` to the tracker."""
        self.last_control = {
            gm.Symbol(str_symbol): v
            for str_symbol, v in zip(map_message.symbol, map_message.value)
            if str_symbol in self.str_controls
        }
        self.tracker.process_control(self.last_control)

    def cb_update(self, *args):
        """Timer callback that publishes the estimated state after new observations.

        Does nothing if no observation was processed since the last call, or if the
        tracker's estimate is empty.
        """
        if self.last_observation == self.last_update:
            return

        self.last_update = self.last_observation

        est_state = self.tracker.get_estimated_state()
        if not est_state:
            # zip(*{}.items()) yields nothing, so unpacking it into
            # (symbol, value) would raise ValueError.
            return

        state_msg = ValueMapMsg()
        state_msg.header.stamp = rospy.Time.now()
        state_msg.symbol, state_msg.value = zip(*est_state.items())
        state_msg.symbol = [str(s) for s in state_msg.symbol]
        self.pub_state.publish(state_msg)

        if self.broadcaster is not None:
            self.broadcaster.update_state(est_state)
            self.broadcaster.publish_state()
# Example #4
# 0
class ROSEKFManager(object):
    """ROS front end for a ``Kineverse6DEKFTracker``.

    Collects transform observations from ``~observations`` and control values from
    ``~controls``. Once every observation name of the tracker has been received and
    every control of the tracker has a value, one tracker update is run. The
    estimated state is then published on ``~state_estimate``, and broadcast through
    ``ModelTFBroadcaster_URDF`` if a model was given.
    """

    def __init__(self,
                 tracker: Kineverse6DEKFTracker,
                 model=None,
                 model_path=None,
                 urdf_param='/ekf_description_check'):
        """
        :param tracker: Tracker that consumes observations and controls.
        :param model: Optional model. When given, the estimated state is broadcast.
        :param model_path: Path of the model, forwarded to the broadcaster.
        :param urdf_param: Parameter name forwarded to the broadcaster.
        """
        self.tracker = tracker
        # Observation matrices received since the last update, keyed by child_frame_id.
        self.last_observation = {}
        # ROS time of the last tracker update, or None before the first one.
        self.last_update = None
        self.last_control = {s: 0.0 for s in self.tracker.controls}
        self.str_controls = {str(s) for s in self.tracker.controls}

        self.vis = ROSVisualizer('~vis', 'world')

        if model is not None:
            self.broadcaster = ModelTFBroadcaster_URDF(urdf_param, model_path,
                                                       model, Path('ekf'))
        else:
            self.broadcaster = None

        self.pub_state = rospy.Publisher('~state_estimate',
                                         ValueMapMsg,
                                         queue_size=1,
                                         tcp_nodelay=True)
        self.sub_observations = rospy.Subscriber('~observations',
                                                 TransformStampedArrayMsg,
                                                 callback=self.cb_obs,
                                                 queue_size=1)
        self.sub_controls = rospy.Subscriber('~controls',
                                             ValueMapMsg,
                                             callback=self.cb_controls,
                                             queue_size=1)

    def cb_obs(self, transform_stamped_array_msg):
        """Stores all received transforms, draws them and attempts an update."""
        for trans in transform_stamped_array_msg.transforms:
            matrix = np_frame3_quaternion(
                trans.transform.translation.x, trans.transform.translation.y,
                trans.transform.translation.z, trans.transform.rotation.x,
                trans.transform.rotation.y, trans.transform.rotation.z,
                trans.transform.rotation.w)
            self.last_observation[trans.child_frame_id] = matrix

        self.vis.begin_draw_cycle('observations')
        self.vis.draw_poses('observations', np.eye(4), 0.2, 0.01,
                            self.last_observation.values())
        self.vis.render('observations')
        self.try_update()

    def cb_controls(self, map_message):
        """Replaces the stored controls with the known ones in ``map_message`` and attempts an update."""
        self.last_control = {
            gm.Symbol(str_symbol): v
            for str_symbol, v in zip(map_message.symbol, map_message.value)
            if str_symbol in self.str_controls
        }
        self.try_update()

    def try_update(self):
        """Runs one tracker update if all observations and controls are available.

        The time step is the ROS time elapsed since the previous update, or 0.05 s for
        the first update. After a successful update with a non-empty estimate, the
        state is published and the collected observations are cleared.
        """
        # all() is True for an empty requirement list; min() on it raised ValueError.
        if not all(p in self.last_observation
                   for p in self.tracker.observation_names):
            return

        if not all(s in self.last_control for s in self.tracker.controls):
            return

        now = rospy.Time.now()
        dt = 0.05 if self.last_update is None else (now -
                                                    self.last_update).to_sec()

        self.tracker.process_update(self.last_observation, self.last_control,
                                    dt)
        # Remember when this update happened so the next dt is the real elapsed time
        # (previously never assigned, so dt was always 0.05).
        self.last_update = now

        est_state = self.tracker.get_estimated_state()
        if len(est_state) == 0:
            return

        state_msg = ValueMapMsg()
        state_msg.header.stamp = now
        state_msg.symbol, state_msg.value = zip(*est_state.items())
        state_msg.symbol = [str(s) for s in state_msg.symbol]
        self.pub_state.publish(state_msg)

        self.last_observation = {}

        if self.broadcaster is not None:
            self.broadcaster.update_state(est_state)
            self.broadcaster.publish_state()