Example #1
def load_trajectory_from_bag(file, fast=True, fast_only=False):
    traj = []
    fast_traj = []

    bag = Bag(file)

    normal_topics = ['raven_state', 'raven_state/array']
    normal_topics = normal_topics + ['/' + topic for topic in normal_topics]

    fast_topic = ['raven_state/1000Hz']
    #fast_topic = fast_topic + ['/' + topic for topic in fast_topic]

    if fast_only:
        for topic, msg, t in bag.read_messages(topics=fast_topic):
            fast_traj.append(msg)
        return get_trajectory(fast_traj)

    for topic, msg, t in bag.read_messages(topics=normal_topics):
        traj.append(msg)

    if fast:
        for topic, msg, t in bag.read_messages(topics=fast_topic):
            fast_traj.append(msg)

    if fast_traj:
        return get_trajectory(fast_traj, traj)
    else:
        return get_trajectory(traj)
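A minimal usage sketch for the loader above (hypothetical bag path; `Bag` is rosbag's Bag class, and `get_trajectory` comes from the surrounding project):

from rosbag import Bag

# Read only the 1000Hz topic and build the trajectory from it alone.
traj = load_trajectory_from_bag('/tmp/run.bag', fast_only=True)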
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--bag",
                        dest="bag",
                        help="Path to ROS bag.",
                        required=True)
    parser.add_argument("--left_image_0",
                        help="Path to 'left_image_00000.png'")
    parser.add_argument("--right_image_0",
                        help="Path to 'right_image_00000.png'")
    parser.add_argument("--start_time",
                        help="A start time to check",
                        type=float,
                        default=None)
    args = parser.parse_args()

    bridge = CvBridge()
    bag = Bag(args.bag)
    t_start = bag.get_start_time()
    n_msgs = 0

    if args.left_image_0 is None and args.right_image_0 is None:
        raise ValueError(
            "You must provide either left_image_0 or right_image_0")
    isLeft = args.left_image_0 is not None
    ref_img = args.left_image_0 if isLeft else args.right_image_0
    print("Using ref image %s" % ref_img)
    ref_image = cv2.imread(ref_img)
    assert ref_image is not None, "Failed to read ref image"
    ref_image_0 = ref_image[:, :, 0]

    topic = '/davis/left/image_raw' if isLeft else '/davis/right/image_raw'
    if args.start_time:
        topic, msg, t = next(
            bag.read_messages(topics=[topic],
                              start_time=rospy.Time(args.start_time +
                                                    t_start)))
        width = msg.width
        height = msg.height
        image = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
        image = np.reshape(image, (height, width))

        print("Correct? {}".format(bool(np.all(np.isclose(image,
                                                          ref_image_0)))))

    else:
        for topic, msg, t in bag.read_messages(topics=[topic]):
            n_msgs += 1
            width = msg.width
            height = msg.height
            image = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
            image = np.reshape(image, (height, width))

            if bool(np.all(np.isclose(image, ref_image_0))):
                print("Ref image found at {} ({} from start)".format(
                    t.to_sec(),
                    t.to_sec() - t_start))
                return

        print("Processed {} images, ref not found!".format(n_msgs))
Example #4
def get_start_stamp(bag: rosbag.Bag):
    pos_prev = []
    t_prev = 0
    count = 0
    t_start = 0
    for topic, msg, t in bag.read_messages('/tf'):
        if msg.transforms[0].child_frame_id == "Puck":
            pos = np.array([
                msg.transforms[0].transform.translation.x,
                msg.transforms[0].transform.translation.y,
                msg.transforms[0].transform.translation.z
            ])
            t_sec = msg.transforms[0].header.stamp.to_sec()
            if t_prev != 0:
                vel = np.linalg.norm((pos - pos_prev) / (t_sec - t_prev))
                if vel > 0.1:
                    if count == 0:
                        t_start = t
                    count += 1
                    if count > 10:
                        return t_start
                else:
                    count = 0
            pos_prev = pos
            t_prev = t_sec
Example #5
def get_info(bag_file, topic_filter=None):
    bag = Bag(bag_file)
    topics = bag.get_type_and_topic_info().topics
    for topic in topics:
        if topic_filter and topics[topic].msg_type not in topic_filter:
            continue
        print("{}: {} Hz".format(topic, round(topics[topic].frequency, 3)))
        print(topics[topic].message_count)
        times = np.ones(shape=bag.get_message_count(topic_filters=topic))
        n = 0
        for _, msg, t in bag.read_messages(topics=topic):
            times[n] = msg.header.stamp.to_sec()
            n += 1
        times = 1 / np.gradient(times)
        times = times[np.where((times > np.percentile(times, 10)) & (times < np.percentile(times, 90)))]
        print("mean: {}, median: {}".format(np.mean(times), np.median(times)))
        print("min: {}, max: {}".format(np.min(times), np.max(times)))
        # plt.scatter(times, np.gradient(times))
        plt.hist(times)
        plt.yscale("log")
        plt.title("{}: {}".format(os.path.basename(bag_file), topic))
        if not os.path.exists("images"):
            os.makedirs("images")
        plt.savefig(os.path.join("images/", "{}.{}.png".format(os.path.basename(bag_file), topic.replace("/", ""))))
        plt.cla()
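The rate estimate in get_info inverts np.gradient of the header stamps and then trims the tails; condensed, assuming a 1-D array of stamps in seconds:

import numpy as np

stamps = np.array([0.00, 0.11, 0.20, 0.31, 0.40])
rates = 1.0 / np.gradient(stamps)  # instantaneous Hz per message
lo, hi = np.percentile(rates, 10), np.percentile(rates, 90)
inliers = rates[(rates > lo) & (rates < hi)]  # drop outlier rates in the tails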
Example #6
def dump(input_file: Bag, topics: list, output_file: 'file' = None) -> None:
    """
    Dump messages from a bag.

    Args:
        input_file: the input bag to dump topics from
        topics: the topics to dump
        output_file: an optional file to dump to

    Returns:
        None

    """
    # create a progress bar for iterating over the messages in the bag
    with tqdm(total=input_file.get_message_count(
            topic_filters=topics)) as prog:
        # iterate over the messages in this input bag
        for topic, msg, _ in input_file.read_messages(topics=topics):
            # update the progress bar with a single iteration
            prog.update(1)
            # create the line to print
            line = '{} {}\n\n'.format(topic, msg)
            # print the line to the terminal
            print(line)
            # if there is an output file, write the line to it
            if output_file is not None:
                output_file.write(line)
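Calling dump is straightforward (hypothetical paths; output_file can be any writable text file object, or omitted to print only):

from rosbag import Bag

with open('dump.txt', 'w') as out:
    dump(Bag('/tmp/run.bag'), ['/tf', '/odom'], output_file=out)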
Example #7
    def _assert_bag_valid(self,
                          filename,
                          topics=None,
                          start_time=None,
                          stop_time=None):
        '''
        Open the bagfile at the specified filename and read it to ensure topic limits were
        enforced and the optional topic list and start/stop times are also enforced.
        '''
        bag = Bag(filename)
        topics_dict = bag.get_type_and_topic_info()[1]
        bag_topics = set(topics_dict.keys())
        param_topics = set(self.topic_limits.keys())
        if topics:
            self.assertEqual(bag_topics, set(topics))
        self.assertTrue(bag_topics.issubset(param_topics))
        for topic in topics_dict:
            size = topics_dict[topic].message_count * 8  # Calculate stored message size as each message is 8 bytes
            count = topics_dict[topic].message_count
            gen = bag.read_messages(topics=topic)
            _, _, first_time = next(gen)
            last_time = first_time  # in case the next for loop does not execute
            if start_time:
                self.assertGreaterEqual(first_time, start_time)
            for _, _, last_time in gen:  # Read through all messages so last_time is valid
                pass
            if stop_time:
                self.assertLessEqual(last_time, stop_time)
            duration = last_time - first_time
            self._assert_limits_enforced(topic, duration, size, count)
Example #8
def process_bag(bag_in_fn, bag_out_fn, conf_file_fn):
    bag_in = Bag(bag_in_fn)
    bag_out = Bag(bag_out_fn, 'w')

    include_rules, exclude_rules, time_rules = read_rules(conf_file_fn)
    topic_rules = include_rules + exclude_rules

    # Set the UNIX time at the start of the bag file
    for r in topic_rules:
        r.set_begin_time(bag_in.get_start_time())
    for r in time_rules:
        r.set_begin_time(bag_in.get_start_time())

    # Find start and end times that are actually required
    t_start, t_end = get_begin_end(time_rules, bag_in)
    # Find topics that are actually required
    bag_topics = bag_in.get_type_and_topic_info().topics.keys()
    topics = get_topics(topic_rules, bag_topics)

    messages = bag_in.read_messages(topics=topics,
                                    start_time=t_start,
                                    end_time=t_end,
                                    return_connection_header=True)
    for topic, msg, t, conn_header in messages:
        # Check default enforcement for this message
        if FilterRule.is_tf_topic(topic):
            default = FilterRule.DEFAULT_ENFORCEMENT_TF
        else:
            default = FilterRule.DEFAULT_ENFORCEMENT_TOPIC

        # When default is to include, only check whether the exclusion
        # rules are satisfied, and if all of them are ok, write it out
        if default == FilterRule.INCLUDE:
            # Check exclusions
            ok = True
            for r in exclude_rules:
                if r.msg_match(topic, msg, t):
                    ok = False
        # When default is to exclude, check if the message matches any
        # of the inclusion rules and write it out if it does
        else:
            # Check inclusions
            ok = False
            for r in include_rules:
                if r.msg_match(topic, msg, t):
                    ok = True

        # Time rules
        time_ok = True
        for r in time_rules:
            if not r.is_ok_with(topic, msg, t):
                time_ok = False

        # Write to file
        if ok and time_ok:
            bag_out.write(topic, msg, t, connection_header=conn_header)

    bag_out.close()
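Invocation is just the three paths (hypothetical names; read_rules, FilterRule, get_begin_end, and get_topics come from the surrounding project):

process_bag('input.bag', 'filtered.bag', 'filter_rules.conf')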
Example #9
def extract_depth(
    bag_file: Bag,
    rgb_directory: str,
    camera_info: str,
    depth: str
) -> None:
    """
    Extract depth data from a bag file.

    Args:
        bag_file: the bag file to read depth data from
        rgb_directory: the directory to find RGB image to match depths to
        camera_info: the topic to use to read metadata about the camera
        depth: the topic to use to read 32-bit floating point depth measures

    Returns:
        None

    """
    # extract the camera dimensions from the bag
    dims = get_camera_dimensions(bag_file, camera_info)
    # get the images from a glob
    images = glob.glob(os.path.join(rgb_directory, 'X', 'data', '*.png'))
    # convert the images to numbers
    path_to_int = lambda x: int(os.path.basename(x).replace('.png', ''))
    images = sorted([path_to_int(image) for image in images])
    # create the output directory
    output_dir = os.path.join(rgb_directory, 'D', 'data')
    try:
        os.makedirs(output_dir)
    except FileExistsError:
        pass
    # iterate over the messages
    progress = tqdm(total=len(images))
    for _, msg, time in bag_file.read_messages(topics=depth):
        # if there are no more images left, break out of the loop
        if not len(images):
            break
        # if the time is less than the current image, continue
        if int(str(time)) < images[0]:
            continue
        # update the progress bar
        progress.update(1)
        # get the depth image
        img = get_depth_image(msg.data, dims, as_rgb=False)
        # save the depth image to disk
        output_file = os.path.join(output_dir, '{}-{}.npz'.format(images[0], time))
        np.savez_compressed(output_file, y=img)
        # remove the first item from the list of times
        images.pop(0)
    # close the progress bar
    progress.close()
Example #10
def bag_to_video(
    input_file: Bag,
    output_file: str,
    topic: str,
    fps: float = 30,
    codec: str = 'MJPG',
) -> None:
    """
    Convert a ROS bag with image topic to a video file.

    Args:
        input_file: the bag file to get image data from
        output_file: the path to an output video file to create
        topic: the topic to read image data from
        fps: the frame rate of the video
        codec: the codec to use when outputting to the video file

    Returns:
        None

    """
    # create an empty reference for the output video file
    video = None
    # get the total number of frames to write
    total = input_file.get_message_count(topic_filters=topic)
    # get an iterator for the topic with the frame data
    iterator = input_file.read_messages(topics=topic)
    # iterate over the image messages of the given topic
    for _, msg, _ in tqdm(iterator, total=total):
        # open the video file if it isn't open
        if video is None:
            # create the video codec
            codec = cv2.VideoWriter_fourcc(*codec)
            # open the output video file
            cv_dims = (msg.width, msg.height)
            video = cv2.VideoWriter(output_file, codec, fps, cv_dims)
        # read the image data into a NumPy tensor
        img = get_camera_image(msg.data, (msg.height, msg.width))
        # write the image to the video file
        video.write(img)

    # if the video file is open, close it
    if video is not None:
        video.release()

    # read the start time of the video from the bag file
    start = input_file.get_start_time()
    # set the access and modification times for the video file
    os.utime(output_file, (start, start))
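Example invocation (hypothetical file names; the Bag must stay open while frames are read, and is closed afterwards):

from rosbag import Bag

bag = Bag('/tmp/run.bag')
bag_to_video(bag, 'out.avi', '/camera/image_raw', fps=30.0, codec='MJPG')
bag.close()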
Example #11
def ipfs_download(multihash: Multihash) -> (dict, Bag):
    rospy.wait_for_service('/ipfs/get_file')
    download = rospy.ServiceProxy('/ipfs/get_file', IpfsDownloadFile)
    tmpfile = NamedTemporaryFile(delete=False)
    res = download(multihash, Filepath(tmpfile.name))
    tmpfile.close()
    if not res.success:
        raise Exception(res.error_msg)
    messages = {}
    bag = Bag(tmpfile.name, 'r')
    for topic, msg, timestamp in bag.read_messages():
        if topic not in messages:
            messages[topic] = [msg]
        else:
            messages[topic].append(msg)
    # the backing file is unlinked here, so the returned Bag references a
    # deleted file (still readable on POSIX while the handle stays open)
    os.unlink(tmpfile.name)
    return (messages, bag)
Example #12
def process_bag(bag_in_fn, bag_out_fn, conf_file_fn):
    bag_in = Bag(bag_in_fn)
    bag_out = Bag(bag_out_fn, 'w')
    topic_rules, tf_rules = RenameRule.parse(conf_file_fn)

    messages = bag_in.read_messages(return_connection_header=True)
    for topic, msg, t, conn_header in messages:
        if topic == '/tf' or topic == '/tf_static':
            new_topic = topic
            new_msg = modify_tf(msg, tf_rules)
        else:
            new_topic = modify_topic(topic, topic_rules)
            # Modify the frame_id in header if header exists
            new_msg = modify_msg(msg, tf_rules)
        bag_out.write(new_topic, new_msg, t, connection_header=conn_header)

    bag_in.close()
    bag_out.close()
Example #13
def writepc2frombag2file(input_bagfile, pc2_topic, out_file_path, start_time=None, end_time=None, msg_count=None):
    try:
        output_file_name = get_output_file_name(input_bagfile) + '_xyz.txt'
        if out_file_path.endswith('/'):
            output_file = out_file_path + output_file_name
        else:
            output_file = out_file_path + '/' + output_file_name

        output_file_fh = open(output_file, 'w')
        print('writing output_file at', output_file)

        # check start and end time condition 
        if start_time is not None and end_time is not None:
            assert (end_time - start_time) > 0, "end_time must be greater than start_time"

        if msg_count is None:
            use_msg_count = False
        else:
            use_msg_count = True
            # msg_count was already cast to int in argparse; check again for sanity in case this function is reused on its own
            msg_count = int(msg_count)
            assert msg_count > 0, "should have positive msg_count"
        
        input_bag = Bag(input_bagfile, 'r')
        print('bag load success')

        count = 0
        for topic, msg, ts in input_bag.read_messages(topics=[pc2_topic], start_time=start_time, end_time=end_time):

            output_file_fh.write('msg_timestamp: %f\n' % msg.header.stamp.to_sec())
            output_file_fh.write('%d\n' % msg.width)
            for data_pts in pc2.read_points(msg, field_names=("x", "y", "z", "intensity"), skip_nans=True):
                output_file_fh.write('%f %f %f %d\n' % (data_pts[0], data_pts[1], data_pts[2], data_pts[3]))
            
            count += 1
            print(count, msg_count)
            if use_msg_count and count >= msg_count:
                break

        input_bag.close()
        output_file_fh.close()

    except Exception:
        print(traceback.format_exc())
Example #14
def pinata_rosbag(ipfs_hash: str) -> tuple:

    data = requests.get(f'https://gateway.pinata.cloud/ipfs/{ipfs_hash}')
    rospy.loginfo(data.status_code)
    if data.status_code != 200:
        rospy.loginfo(f'Error downloading rosbag! {data.content}')
        return

    else:
        tmpfile = NamedTemporaryFile(delete=False)
        tmpfile.write(data.content)
        tmpfile.flush()  # ensure all bytes are on disk before rosbag opens the file by name
        bag = Bag(tmpfile.name)

        messages = {}
        for topic, msg, timestamp in bag.read_messages():
            if topic not in messages:
                messages[topic] = [msg]
            else:
                messages[topic].append(msg)

        os.unlink(tmpfile.name)
    return (messages, bag)
Example #15
def get_initial_state(bag: rosbag.Bag):
    i = 0
    j = 0
    tf_table = None
    tf_puck_start = None
    tf_puck_10 = None
    for topic, msg, t in bag.read_messages("/tf"):
        if msg.transforms[0].child_frame_id == "Puck":
            if i == 0:
                tf_puck_start = msg.transforms[0]
            if i == 10:
                tf_puck_10 = msg.transforms[0]
                break
            i += 1

        if msg.transforms[0].child_frame_id == "Table":
            if j == 0:
                tf_table = msg.transforms[0]
            j += 1

    puck_start_T = tf2_geometry_msgs.transform_to_kdl(tf_table).Inverse() * \
                   tf2_geometry_msgs.transform_to_kdl(tf_puck_start)
    puck_10_T = tf2_geometry_msgs.transform_to_kdl(tf_table).Inverse() * \
                tf2_geometry_msgs.transform_to_kdl(tf_puck_10)
    _, _, yaw_start = puck_start_T.M.GetRPY()
    _, _, yaw_10 = puck_10_T.M.GetRPY()

    t_start = tf_puck_start.header.stamp.to_sec()
    t_10 = tf_puck_10.header.stamp.to_sec()

    p_start = np.array([puck_start_T.p.x(), puck_start_T.p.y()])
    p_10 = np.array([puck_10_T.p.x(), puck_10_T.p.y()])
    lin_vel_start = (p_10 - p_start) / (t_10 - t_start)
    ang_vel_start = angles.shortest_angular_distance(yaw_start,
                                                     yaw_10) / (t_10 - t_start)

    return np.hstack([p_start,
                      yaw_start]), np.hstack([lin_vel_start,
                                              ang_vel_start]), tf_table
Example #16
def play_images(bag_file: Bag, topics: list) -> None:
    """
    Play the data in a bag file.

    Args:
        bag_file: the bag file to play
        topics: the list of topics to play

    Returns:
        None

    """
    # open windows to stream the camera and a priori image data to
    windows = {topic: None for topic in topics}
    # iterate over the messages
    progress = tqdm(total=bag_file.get_message_count(topic_filters=topics))
    for topic, msg, time in bag_file.read_messages(topics=topics):
        # if topic is camera, unwrap and send to the camera window
        if topic in topics:
            # update the progress bar with an iteration
            progress.update(1)
            # update the progress with a post fix
            progress.set_postfix(time=time)
            # if the camera window isn't open, open it
            if windows[topic] is None:
                title = '{} ({})'.format(bag_file.filename, topic)
                windows[topic] = Window(title, msg.height, msg.width)
            # get the pixels of the camera image and display them
            img = get_camera_image(msg.data, windows[topic].shape)
            if msg.encoding == 'bgr8':
                img = img[..., ::-1]
            windows[topic].show(img[..., :3])

    # shut down the viewer windows
    for window in windows.values():
        if window is not None:
            window.close()
Example #17
outimgs_path = '/home/kaue/data/extracted_images/newextracted'


from geometry_msgs.msg import _TwistStamped

t1 = time.time()

# bagpath = "/home/kauevestena/data/rosbags/2019-07-11-16-21-46.bag"
bagpath = msc.joinToHome('data/rosbags/2019-07-11-16-21-46.bag')
print('opened', bagpath)

current_bag = Bag(bagpath)

# print(current_bag.get_type_and_topic_info())

messages = current_bag.read_messages()

t0 = next(messages).timestamp.to_sec()

topic_names = {
    "imu" : "imu/data",
    "img" : "/raspicam_node/image/compressed",
    "gnss" : "/fix"
}

# imu_list = []
# img_list = []
# gnss_list = []

# # # # the big dict that will store all the relevant data
pickle_outpath = os.path.join(msc.PICKLESPATH, msc.filenameFromPathWtithoutExt(bagpath) + '_idx.pickle')
Example #18
    def on_enter(self, userdata):
        self._failed = False
        try:
            # Initialization

            # userdata
            self._bag_filename = userdata.bag_filename
            self._trajectories_command = userdata.trajectories_command

            # joint definitions
            l_arm_range = range(16, 23)
            r_arm_range = range(23, 30)
            atlasJointNames = [
                'back_bkz', 'back_bky', 'back_bkx', 'neck_ry', 'l_leg_hpz',
                'l_leg_hpx', 'l_leg_hpy', 'l_leg_kny', 'l_leg_aky',
                'l_leg_akx', 'r_leg_hpz', 'r_leg_hpx', 'r_leg_hpy',
                'r_leg_kny', 'r_leg_aky', 'r_leg_akx', 'l_arm_shz',
                'l_arm_shx', 'l_arm_ely', 'l_arm_elx', 'l_arm_wry',
                'l_arm_wrx', 'l_arm_wry2', 'r_arm_shz', 'r_arm_shx',
                'r_arm_ely', 'r_arm_elx', 'r_arm_wry', 'r_arm_wrx',
                'r_arm_wry2'
            ]
            joint_position_cmd = [0] * len(l_arm_range)
            # starting and stopping time of measurement data to be taken into account for calibration
            time_start = 0
            time_end = 0
            time_tf = 0

            # gravity vector in world frame
            g_world = np.resize(np.array([0, 0, -9.81]), (3, 1))

            # take this number of data points for each pose.
            max_data_points_per_pose = 100

            # rotation of the force torque sensor in world frame (quaternions)
            ft_rotation = [0, 0, 0, 0]

            # loop over all ft sensors to calibrate

            for chain in self._calibration_chain:
                tf_chain = ['/pelvis', 'ltorso', 'mtorso', 'utorso']
                if chain == 'left_arm':
                    joint_range = l_arm_range
                    tfname = 'l_hand'
                    tf_chain.extend([
                        'l_clav', 'l_scap', 'l_uarm', 'l_larm', 'l_ufarm',
                        'l_lfarm', 'l_hand'
                    ])
                elif chain == 'right_arm':
                    joint_range = r_arm_range
                    tf_chain.extend([
                        'r_clav', 'r_scap', 'r_uarm', 'r_larm', 'r_ufarm',
                        'r_lfarm', 'r_hand'
                    ])

                    tfname = 'r_hand'
                else:
                    Logger.logwarn(
                        'CalculateForceTorqueCalibration: Undefined chain %s',
                        chain)

                # initialize transformations
                tf_data = [
                ]  # frame to frame transformation with numerical indexes from the tf_chain
                tf_data_cum = [
                ]  # world to frame transformation with numerical indexes from the tf_chain

                for i in range(len(tf_chain)):
                    tf_data.append([0, 0, 0, 0])
                    tf_data_cum.append([0, 0, 0, 0])

                # get number of poses from the commanded trajectories. 2 Points per Pose
                number_of_poses = len(
                    self._trajectories_command[chain].points) // 2
                # define information matrix for calibration
                InfMat = np.zeros(
                    (6 * max_data_points_per_pose * number_of_poses, 10))
                InfMat_i = 0  # Index in this Matrix
                # define measurement vector for calibration
                MeasVec = np.zeros(
                    (6 * max_data_points_per_pose * number_of_poses, 1))

                # commanded joint trajectory for current chain
                # all commanded positions in this trajectory
                current_traj_cmd = self._trajectories_command[chain]

                # flag if the time interval was already set:
                timesetflag = False

                point_index_cmd = 0  # start with index 0
                transformation_available = False

                # read time series from bag file
                # check in desired input trajectory, at which time a position is reached
                # take the data in a defined time period after the commanded new position
                bag_from_robot = Bag(os.path.expanduser(
                    self._bag_filename))  # open bag file
                Logger.loginfo(
                    'CalculateForceTorqueCalibration: Calibrate %s from %s. Using %d different positions from trajectory command. Expecting 2 commanded positions per pose'
                    %
                    (chain, self._bag_filename, len(current_traj_cmd.points)))
                for topic, msg, t in bag_from_robot.read_messages(topics=[
                        '/flor/controller/atlas_state', '/tf',
                        '/flor/controller/joint_command'
                ]):
                    # Check if all desired poses have been reached
                    if point_index_cmd > len(current_traj_cmd.points) - 1:
                        break
                    current_position_cmd = current_traj_cmd.points[
                        point_index_cmd].positions
                    ####################################################
                    # Check tf message: remember last transformation from world to ft sensor
                    if topic == '/tf':
                        # loop over all transformations inside this tf message
                        for j in range(len(msg.transforms)):
                            data = msg.transforms[j]
                            tr = data.transform  # current transformation
                            header = data.header
                            # check if the frame matches one of the saved frames
                            for i in range(len(tf_chain)):
                                if data.child_frame_id == tf_chain[i]:
                                    tf_data[i] = np.array([
                                        tr.rotation.x, tr.rotation.y,
                                        tr.rotation.z, tr.rotation.w
                                    ])

                        # calculate the newest transformation to the force torque sensor
                        tf_data_cum[0] = tf_data[0]
                        for i in range(1, len(tf_chain)):
                            tf_data_cum[i] = tf.transformations.quaternion_multiply(
                                tf_data_cum[i - 1], tf_data[i])
                            if i == len(tf_chain) - 1:
                                time_tf = msg.transforms[
                                    0].header.stamp.to_sec()
                        # nothing to do with tf message
                        # check if all transformations are available. (up to the last link)
                        if np.any(
                                tf_data_cum[-1]
                        ):  # real part of quaternion unequal to zero: data exists
                            transformation_available = True
                        continue
                    else:
                        time_msg = msg.header.stamp.to_sec()

                    ####################################################
                    # check if timestamp is interesting. Then check the
                    if not (time_msg > time_start
                            and time_msg < time_end) and not timesetflag:
                        # the timestamp is not in evaluation interval. Look for reaching of the pose
                        # record data (see below)
                        if topic == '/flor/controller/joint_command':
                            # get the commanded position
                            for i in range(len(joint_range)):
                                osrf_ndx = joint_range[i]
                                joint_position_cmd[i] = msg.position[osrf_ndx]

                            # Check if position command matches the currently expected commanded position
                            pos_reached = True

                            for i in range(len(joint_range)):
                                if abs(joint_position_cmd[i] -
                                       current_position_cmd[i]) > 0.001:
                                    pos_reached = False
                                    break
                            if pos_reached:
                                # end time of the values regarded for calibration data: take the latest time possible  for this pose
                                time_end = msg.header.stamp.to_sec(
                                ) + self._settlingtime
                                # starting time for calibration: Take 100ms
                                time_start = time_end - 0.1
                                data_points_for_this_pose = 0
                                timesetflag = True
                                # take the next point next time. Each pose consists of two trajectory points (one for reaching, one for settling).
                                # take the second one
                                point_index_cmd = point_index_cmd + 2
                            if not pos_reached:
                                # The commanded position has not been reached. Skip
                                continue
                        continue  # continue here, because data acquisition is triggered by the time_start, time_end

                    if time_msg > time_end:
                        timesetflag = False  # prepare for new evaluation interval

                    ####################################################
                    # check if enough datapoints for this pose have been collected
                    if data_points_for_this_pose > max_data_points_per_pose:
                        # already enough data points for this pose
                        continue

            ####################################################
            # Check if message is atlas_state
            # IF this is the case, fill information matrix
                    if topic != '/flor/controller/atlas_state':
                        continue

                    # Extract measured force and torque
                    if chain == 'left_arm':
                        FT = [
                            msg.l_hand.force.x, msg.l_hand.force.y,
                            msg.l_hand.force.z, msg.l_hand.torque.x,
                            msg.l_hand.torque.y, msg.l_hand.torque.z
                        ]
                    elif chain == 'right_arm':
                        FT = [
                            msg.r_hand.force.x, msg.r_hand.force.y,
                            msg.r_hand.force.z, msg.r_hand.torque.x,
                            msg.r_hand.torque.y, msg.r_hand.torque.z
                        ]

                    # calculate gravitation vector
                    if not transformation_available:
                        Logger.logwarn(
                            'No tf messages available at time %1.4f.' %
                            time_msg)
                        continue

                    R = tf.transformations.quaternion_matrix(tf_data_cum[-1])
                    g = np.dot((R[0:3, 0:3]).transpose(), g_world)

                    gx = g[0]
                    gy = g[1]
                    gz = g[2]

                    # fill information matrix for this data point. Source: [1], equ. (7)
                    M = np.zeros((6, 10))
                    M[0, 0] = gx
                    M[1, 0] = gy
                    M[2, 0] = gz

                    M[3, 2] = gz
                    M[3, 3] = -gy
                    M[4, 1] = -gz
                    M[4, 3] = gx
                    M[5, 1] = gy
                    M[5, 2] = -gx
                    M[0, 4] = 1.0
                    M[1, 5] = 1.0
                    M[2, 6] = 1.0
                    M[3, 7] = 1.0
                    M[4, 8] = 1.0
                    M[5, 9] = 1.0

                    # fill big information matrix and vector (stack small information matrices)
                    for i in range(6):
                        for j in range(10):
                            InfMat[InfMat_i * 6 + i, j] = M[i, j]
                        MeasVec[InfMat_i * 6 + i, 0] = FT[i]

                    InfMat_i = InfMat_i + 1  # increase index
                    data_points_for_this_pose = data_points_for_this_pose + 1

                # shorten big information matrix
                if InfMat_i < max_data_points_per_pose * number_of_poses:
                    InfMat_calc = InfMat[0:6 * InfMat_i, :]
                    MeasVec_calc = MeasVec[0:6 * InfMat_i, :]
                else:
                    InfMat_calc = InfMat
                    MeasVec_calc = MeasVec

                # calculate calibration data
                # calculate calibration with given static parameters
                if chain in self._static_calibration_data:
                    # bring the columns with the first four parameters to the other side of the equation
                    if len(self._static_calibration_data[chain]) != 4:
                        Logger.logwarn(
                            "CalculateForceTorqueCalibration: Given static calibration data for %s has length %d. Required 4 entries."
                            % (chain,
                               len(self._static_calibration_data[chain])))
                        self._failed = True
                        return
                    # convert physical parameters to identification parameters (mass, 1st moment)
                    m = self._static_calibration_data[chain][0]
                    if m == 0:
                        mom_x = 0
                        mom_y = 0
                        mom_z = 0
                    elif m > 0:
                        mom_x = self._static_calibration_data[chain][1] / m
                        mom_y = self._static_calibration_data[chain][2] / m
                        mom_z = self._static_calibration_data[chain][3] / m
                    else:
                        Logger.logwarn(
                            "CalculateForceTorqueCalibration: Negative mass (%f) for calibration requested. Abort."
                            % m)
                        self._failed = True
                        return
                    k_fix = np.resize(np.array([m, mom_x, mom_y, mom_z]),
                                      (4, 1))
                    Logger.loginfo(
                        "CalculateForceTorqueCalibration:static calibration data for %s given: %s. Reduce equation system"
                        % (chain, str(self._static_calibration_data[chain])))
                    MeasVec_calc_corr = np.subtract(
                        np.array(MeasVec_calc),
                        np.dot(InfMat_calc[:, 0:4], k_fix))
                    InfMat_calc_corr = InfMat_calc[:, 4:10]  # only the last 6 columns, which correspond to the sensor offsets
                    k = np.linalg.lstsq(
                        InfMat_calc_corr,
                        MeasVec_calc_corr)[0]  # solve reduced equation system
                    k_calibration = self._static_calibration_data[chain]
                    k_calibration.extend(k)
                else:  # calculate normally with all parameters unknown
                    k = np.linalg.lstsq(InfMat_calc, MeasVec_calc)[0]
                    # convert to physical parameters (first moment -> mass)
                    if k[0] > 0:
                        k_calibration = [
                            k[0], k[1] / k[0], k[2] / k[0], k[3] / k[0], k[4],
                            k[5], k[6], k[7], k[8], k[9]
                        ]
                    else:
                        Logger.loginfo(
                            "CalculateForceTorqueCalibration:Calibration brought negative mass %f"
                            % k[0])
                        k_calibration = [
                            0.0, 0.0, 0.0, 0.0, k[4], k[5], k[6], k[7], k[8],
                            k[9]
                        ]

                Logger.loginfo(
                    "CalculateForceTorqueCalibration:calibration data for %s" %
                    chain)
                Logger.loginfo("CalculateForceTorqueCalibration:mass: %f" %
                               float(k_calibration[0]))
                Logger.loginfo(
                    "CalculateForceTorqueCalibration:center of mass: %f %f %f"
                    % (k_calibration[1], k_calibration[2], k_calibration[3]))
                Logger.loginfo(
                    "CalculateForceTorqueCalibration:F offset: %f %f %f" %
                    (k_calibration[4], k_calibration[5], k_calibration[6]))
                Logger.loginfo(
                    "CalculateForceTorqueCalibration:M offset: %f %f %f" %
                    (k_calibration[7], k_calibration[8], k_calibration[9]))
                self._ft_calib_data[chain] = k_calibration

                bag_from_robot.close()

            userdata.ft_calib_data = self._ft_calib_data
            Logger.loginfo(
                'CalculateForceTorqueCalibration:Calibration finished')
            self._done = True

        except Exception as e:
            Logger.logwarn(
                'CalculateForceTorqueCalibration:Unable to calculate calibration:\n%s'
                % str(e))
            self._failed = True
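The numerical core of the calibration above is the linear least-squares solve MeasVec ≈ InfMat · k, where k[0] is the mass, k[1:4] the first moments, and k[4:10] the sensor offsets. A condensed sketch with placeholder data (rcond=None is an assumption for newer NumPy; the code above relies on the older default):

import numpy as np

InfMat = np.random.randn(600, 10)   # stacked 6x10 information matrices (placeholder data)
MeasVec = np.random.randn(600, 1)   # stacked force/torque measurements (placeholder data)
k = np.linalg.lstsq(InfMat, MeasVec, rcond=None)[0]
mass = float(k[0])
# first moments -> center of mass, guarded against a non-physical negative mass
com = (k[1:4] / k[0]).ravel() if mass > 0 else np.zeros(3)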
Example #19
def main():
    parser = argparse.ArgumentParser(
        description=("Extracts grayscale and event images from a ROS bag and "
                     "saves them as TFRecords for training in TensorFlow."))
    parser.add_argument("--bag", dest="bag",
                        help="Path to ROS bag.",
                        required=True)
    parser.add_argument("--prefix", dest="prefix",
                        help="Output file prefix.",
                        required=True)
    parser.add_argument("--output_folder", dest="output_folder",
                        help="Output folder.",
                        required=True)
    parser.add_argument("--max_aug", dest="max_aug",
                        help="Maximum number of images to combine for augmentation.",
                        type=int,
                        default=6)
    parser.add_argument("--n_skip", dest="n_skip",
                        help="Maximum number of images to combine for augmentation.",
                        type=int,
                        default=1)
    parser.add_argument("--start_time", dest="start_time",
                        help="Time to start in the bag.",
                        type=float,
                        default=0.0)
    parser.add_argument("--end_time", dest="end_time",
                        help="Time to end in the bag.",
                        type=float,
                        default=-1.0)
    parser.add_argument("--save_rgb_images", default=True,
                        const=False, nargs="?")
    parser.add_argument("--debug", default=False,
                        const=True, nargs="?")
    parser.add_argument("--whitelist_imageids_txt",
                        type=str,
                        default=None)

    args = parser.parse_args()

    bridge = CvBridge()

    n_msgs = 0
    left_start_event_offset = 0
    right_start_event_offset = 0

    left_event_image_iter = 0
    right_event_image_iter = 0
    left_image_iter = 0
    right_image_iter = 0
    first_left_image_time = -1
    first_right_image_time = -1

    left_events = []
    right_events = []
    left_images = []
    right_images = []
    left_image_times = []
    right_image_times = []
    left_event_count_images = []
    left_event_time_images = []
    left_event_image_times = []

    right_event_count_images = []
    right_event_time_images = []
    right_event_image_times = []

    left_image_event_start_idx = []
    left_image_event_end_idx = []
    right_image_event_start_idx = []
    right_image_event_end_idx = []

    whitelist_imageids = None
    if args.whitelist_imageids_txt is not None:
        with open(args.whitelist_imageids_txt, 'r') as fp:
            whitelist_imageids = fp.read().splitlines()
        whitelist_imageids = [int(l) for l in whitelist_imageids]

    cols = 346
    rows = 260
    print("Processing bag")
    bag = Bag(args.bag)
    h5_left, h5_right = None, None
    if args.debug:
        import h5py
        h5_file = h5py.File(args.bag[:-len("bag")] + "hdf5", 'r')
        h5_left = h5_file['davis']['left']['events']
        h5_right = h5_file['davis']['right']['events']

    options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
    left_tf_writer = tf.python_io.TFRecordWriter(
        os.path.join(args.output_folder, args.prefix, "left_event_images.tfrecord"),
        options=options)
    right_tf_writer = tf.python_io.TFRecordWriter(
        os.path.join(args.output_folder, args.prefix, "right_event_images.tfrecord"),
        options=options)

    # Get actual time for the start of the bag.
    t_start = bag.get_start_time()
    t_start_ros = rospy.Time(t_start)
    # Set the time at which the bag reading should end.
    if args.end_time == -1.0:
        t_end = bag.get_end_time()
    else:
        t_end = t_start + args.end_time

    for topic, msg, t in bag.read_messages(
            topics=['/davis/left/image_raw',
                    '/davis/right/image_raw',
                    '/davis/left/events',
                    '/davis/right/events'],
            # **** MOD **** #
            # NOTE: we always start reading from the start in order
            #       to count the number of events that have to be
            #       discarded in the HDF5 file
            # start_time=rospy.Time(args.start_time + t_start),
            end_time=rospy.Time(t_end)):
        # Check to make sure we're working with stereo messages.
        if not ('left' in topic or 'right' in topic):
            print('ERROR: topic {} does not contain left or right, is this stereo?'
                  'If not, you will need to modify the topic names in the code.'.
                  format(topic))
            return

        n_msgs += 1
        if n_msgs % 500 == 0:
            print("Processed {} msgs, {} images, time is {}.".format(n_msgs,
                                                                     left_event_image_iter,
                                                                     t.to_sec() - t_start))

        isLeft = 'left' in topic
        # **** MOD **** # /*start
        # If we are still not reading the part
        # we are interested in, we just count
        # the number of events
        if t.to_sec() < args.start_time + t_start:
            if 'events' in topic and msg.events:
                if isLeft:
                    left_start_event_offset += len(msg.events)
                else:
                    right_start_event_offset += len(msg.events)
            continue  # read the next msg
        # **** MOD **** # end*/

        if 'image' in topic:
            width = msg.width
            height = msg.height
            if width != cols or height != rows:
                print("Image dimensions are not what we expected: set: ({} {}) vs  got:({} {})"
                      .format(cols, rows, width, height))
                return
            time = msg.header.stamp
            image = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
            image = np.reshape(image, (height, width))

            if isLeft:
                if whitelist_imageids is None or left_image_iter in whitelist_imageids:
                    if args.save_rgb_images:
                        cv2.imwrite(os.path.join(args.output_folder,
                                                 args.prefix,
                                                 "left_image{:05d}.png".format(left_image_iter)),
                                    image)
                    if left_image_iter > 0:
                        left_image_times.append(time)
                    else:
                        first_left_image_time = time
                        left_event_image_times.append(time.to_sec())
                left_image_iter += 1
            else:
                if whitelist_imageids is None or right_image_iter in whitelist_imageids:
                    if args.save_rgb_images:
                        cv2.imwrite(os.path.join(args.output_folder,
                                                 args.prefix,
                                                 "right_image{:05d}.png".format(right_image_iter)),
                                    image)
                    if right_image_iter > 0:
                        right_image_times.append(time)
                    else:
                        first_right_image_time = time
                        right_event_image_times.append(time.to_sec())
                right_image_iter += 1
        elif 'events' in topic and msg.events:
            for event in msg.events:
                ts = event.ts
                event = [event.x,
                         event.y,
                         (ts - t_start_ros).to_sec(),
                         (float(event.polarity) - 0.5) * 2]
                if isLeft:
                    if first_left_image_time != -1 and ts > first_left_image_time:
                        left_events.append(event)
                    else:
                        left_start_event_offset += 1
                else:
                    if first_right_image_time != -1 and ts > first_right_image_time:
                        right_events.append(event)
                    else:
                        right_start_event_offset += 1
            if isLeft:
                if len(left_image_times) >= args.max_aug and \
                        left_events[-1][2] > (left_image_times[args.max_aug - 1] - t_start_ros).to_sec():
                    left_event_image_iter, consumed = _save_events(left_events,
                                                                   left_image_times,
                                                                   left_event_count_images,
                                                                   left_event_time_images,
                                                                   left_event_image_times,
                                                                   left_start_event_offset,
                                                                   left_image_event_start_idx,
                                                                   left_image_event_end_idx,
                                                                   rows,
                                                                   cols,
                                                                   args.max_aug,
                                                                   args.n_skip,
                                                                   left_event_image_iter,
                                                                   args.prefix,
                                                                   'left',
                                                                   left_tf_writer,
                                                                   t_start_ros,
                                                                   h5_left,
                                                                   whitelist_imageids)
                    left_start_event_offset += consumed
            else:
                if len(right_image_times) >= args.max_aug and \
                        right_events[-1][2] > (right_image_times[args.max_aug - 1] - t_start_ros).to_sec():
                    right_event_image_iter, consumed = _save_events(right_events,
                                                                    right_image_times,
                                                                    right_event_count_images,
                                                                    right_event_time_images,
                                                                    right_event_image_times,
                                                                    right_start_event_offset,
                                                                    right_image_event_start_idx,
                                                                    right_image_event_end_idx,
                                                                    rows,
                                                                    cols,
                                                                    args.max_aug,
                                                                    args.n_skip,
                                                                    right_event_image_iter,
                                                                    args.prefix,
                                                                    'right',
                                                                    right_tf_writer,
                                                                    t_start_ros,
                                                                    h5_right,
                                                                    whitelist_imageids)
                    right_start_event_offset += consumed

    left_tf_writer.close()
    right_tf_writer.close()

    image_counter_file = open(os.path.join(args.output_folder, args.prefix, "n_images.txt"), 'w')
    image_counter_file.write("{} {}".format(left_event_image_iter, right_event_image_iter))
    image_counter_file.close()
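One detail worth noting from the event loop above: polarity is remapped from {0, 1} to {-1.0, +1.0} via (float(polarity) - 0.5) * 2, so positive and negative events are symmetric around zero:

assert (float(True) - 0.5) * 2 == 1.0    # positive event
assert (float(False) - 0.5) * 2 == -1.0  # negative event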
	def on_enter(self, userdata):
		self._failed = False
		try:
			# Initialization
		
			# userdata
			self._bag_filename = userdata.bag_filename
			self._trajectories_command = userdata.trajectories_command
			
			# joint definitions
			l_arm_range = range(16,23);
			r_arm_range = range(23,30);
			atlasJointNames = [
			    'back_bkz', 'back_bky', 'back_bkx', 'neck_ry',
			    'l_leg_hpz', 'l_leg_hpx', 'l_leg_hpy', 'l_leg_kny', 'l_leg_aky', 'l_leg_akx',
			    'r_leg_hpz', 'r_leg_hpx', 'r_leg_hpy', 'r_leg_kny', 'r_leg_aky', 'r_leg_akx',
			    'l_arm_shz', 'l_arm_shx', 'l_arm_ely', 'l_arm_elx', 'l_arm_wry', 'l_arm_wrx', 'l_arm_wry2',
			    'r_arm_shz', 'r_arm_shx', 'r_arm_ely', 'r_arm_elx', 'r_arm_wry', 'r_arm_wrx', 'r_arm_wry2']
			joint_position_cmd = [0]*len(l_arm_range)
			# starting and stopping time of measurement data to be taken into account for cailbration
			time_start = 0
			time_end = 0
			time_tf = 0
			
			# gravity vector in world frame
			g_world = np.resize(np.array([0,0,-9.81]),(3,1))
		
			# take this number of data points for each pose.
			max_data_points_per_pose = 100
			
			# rotation of the force torque sensor in world frame (quaternions)
			ft_rotation = [0,0,0,0]
			
			# loop over all ft sensors to calibrate
			
			for chain in self._calibration_chain: 
				tf_chain=['/pelvis', 'ltorso', 'mtorso', 'utorso']
				if chain == 'left_arm':
					joint_range = l_arm_range
					tfname = 'l_hand'
					tf_chain.extend(['l_clav', 'l_scap', 'l_uarm', 'l_larm', 'l_ufarm', 'l_lfarm', 'l_hand'])
				elif chain == 'right_arm':
					joint_range = r_arm_range
					tf_chain.extend(['r_clav', 'r_scap', 'r_uarm', 'r_larm', 'r_ufarm', 'r_lfarm', 'r_hand'])
					
					tfname = 'r_hand'
				else:
					Logger.logwarn('CalculateForceTorqueCalibration: Undefined chain %s', chain)

				# initialize transformations
				tf_data = [] # frame to frame transformation with numerical indexes from the tf_chain
				tf_data_cum = [] # world to frame transformation with numerical indexes from the tf_chain

				for i in range(len(tf_chain)):
					tf_data.append([0,0,0,0])
					tf_data_cum.append([0,0,0,0])
									
				# get number of poses from the commanded trajectories. 2 Points per Pose
				number_of_poses = len(self._trajectories_command[chain].points) / 2
				# define information matrix for calibration
				InfMat = np.zeros((6*max_data_points_per_pose*number_of_poses, 10))
				InfMat_i = 0 # Index in this Matrix
				# definie measurement vector for calibration
				MeasVec = np.zeros((6*max_data_points_per_pose*number_of_poses, 1))
				
				# commanded joint trajectory for current chain
				# all commanded positions in this trajectory
				current_traj_cmd = self._trajectories_command[chain] 
				
				# flag if the time interval was already set:
				timesetflag = False
				
				point_index_cmd = 0 # start with index 0
				transformation_available = False
				
				# read time series from bag file
				# check in desired input trajectory, at which time a position is reached
				# take the data in a defined time period after the commanded new position
				bag_from_robot = Bag(os.path.expanduser(self._bag_filename)) # open bag file
				Logger.loginfo('CalculateForceTorqueCalibration: Calibrate %s from %s. Using %d different positions from trajectory command. Expecting 2 commanded positions per pose' % (chain, self._bag_filename, len(current_traj_cmd.points)) )
				for topic, msg, t in bag_from_robot.read_messages(topics=['/flor/controller/atlas_state', '/tf', '/flor/controller/joint_command']):
					# Check if all desired poses have been reached
					if point_index_cmd > len(current_traj_cmd.points)-1:
						break
					current_position_cmd = current_traj_cmd.points[point_index_cmd].positions
					####################################################
					# Check tf message: remember last transformation from world to ft sensor
					if topic == '/tf':
			  			# loop over all transformations inside this tf message
						for j in range(len(msg.transforms)):
							data = msg.transforms[j]
							tr = data.transform # current transformation
							header = data.header
							# check if frame matches on of the saved frames
							for i in range(len(tf_chain)):
								if data.child_frame_id == tf_chain[i]:
									tf_data[i] = np.array([tr.rotation.x, tr.rotation.y, tr.rotation.z, tr.rotation.w])

						# calculate the newest transformation to the force torque sensor
						tf_data_cum[0] = tf_data[0]
						for i in range(1,len(tf_chain)):
							tf_data_cum[i] = q = tf.transformations.quaternion_multiply(tf_data_cum[i-1], tf_data[i])
							if i == len(tf_chain)-1:
								time_tf = msg.transforms[0].header.stamp.to_sec()
						# nothing to do with tf message
						# check if all transformations are available. (up to the last link)
						if np.any(tf_data_cum[-1]): # real part of quaternion unequal to zero: data exists
							transformation_available = True
						continue
					else:
						time_msg = msg.header.stamp.to_sec()
					
					####################################################
					# check if the timestamp is inside the evaluation interval.
					# If it is not, look for the next commanded pose being reached.
					if not (time_msg > time_start and time_msg < time_end) and not timesetflag:
						# the timestamp is not in the evaluation interval: look for reaching of the pose
						if topic == '/flor/controller/joint_command':
							# get the commanded position
							for i in range(len(joint_range)):
								osrf_ndx = joint_range[i]
								joint_position_cmd[i] = msg.position[osrf_ndx]

							# check if the position command matches the currently expected commanded position
							pos_reached = True
							for i in range(len(joint_range)):
								if abs(joint_position_cmd[i]-current_position_cmd[i]) > 0.001:
									pos_reached = False
									break
							if pos_reached:
								# end time of the values regarded for calibration: take the latest time possible for this pose
								time_end = msg.header.stamp.to_sec() + self._settlingtime
								# start time for calibration: take the last 100ms before time_end
								time_start = time_end - 0.1
								data_points_for_this_pose = 0
								timesetflag = True
								# each pose consists of two trajectory points (one for reaching, one for settling),
								# so advance to the command of the next pose
								point_index_cmd = point_index_cmd + 2
						continue # data acquisition is triggered by time_start and time_end

					if time_msg > time_end:
						timesetflag = False # prepare for a new evaluation interval
			  			
					####################################################
					# check if enough data points for this pose have been collected
					if data_points_for_this_pose > max_data_points_per_pose:
						# already enough data points for this pose
						continue

					####################################################
					# Check if the message is an atlas_state message.
					# If this is the case, fill the information matrix
					if topic != '/flor/controller/atlas_state':
						continue

					# Extract measured force and torque
					if chain == 'left_arm':
						FT = [msg.l_hand.force.x, msg.l_hand.force.y, msg.l_hand.force.z, msg.l_hand.torque.x, msg.l_hand.torque.y, msg.l_hand.torque.z ]
					elif chain == 'right_arm':
						FT = [msg.r_hand.force.x, msg.r_hand.force.y, msg.r_hand.force.z, msg.r_hand.torque.x, msg.r_hand.torque.y, msg.r_hand.torque.z ]
	
					# calculate gravitation vector
					if not transformation_available:
						Logger.logwarn('No tf messages available at time %1.4f.' % time_msg)
						continue
						
					R = tf.transformations.quaternion_matrix(tf_data_cum[-1])
					g = np.dot((R[0:3,0:3]).transpose(), g_world)
	
					gx = g[0]
					gy = g[1]
					gz = g[2]
				
					# fill information matrix for this data point. Source: [1], equ. (7)
					M = np.zeros((6, 10))
					M[0,0] = gx
					M[1,0] = gy
					M[2,0] = gz
					
					M[3,2] = gz
					M[3,3] = -gy 
					M[4,1] = -gz
					M[4,3] = gx
					M[5,1] = gy
					M[5,2] = -gx
					M[0,4] = 1.0
					M[1,5] = 1.0
					M[2,6] = 1.0
					M[3,7] = 1.0
					M[4,8] = 1.0
					M[5,9] = 1.0
					
					# fill big information matrix and vector (stack the small information matrices)
					for i in range(6):
						for j in range(10):
							InfMat[InfMat_i*6+i,j] = M[i,j]
						MeasVec[InfMat_i*6+i,0] = FT[i]
						
					InfMat_i = InfMat_i + 1 # increase index
					data_points_for_this_pose = data_points_for_this_pose + 1
			
				# shorten the big information matrix to the rows actually filled
				if InfMat_i < max_data_points_per_pose*number_of_poses:
					InfMat_calc = InfMat[0:6*InfMat_i,:] # keep all 6*InfMat_i filled rows
					MeasVec_calc = MeasVec[0:6*InfMat_i,:]
				else:
					InfMat_calc = InfMat
					MeasVec_calc = MeasVec

				# calculate calibration data
				if chain in self._static_calibration_data: # calculate calibration with given static parameters
					# bring the columns of the fixed parameters to the other side of the equation
					if len(self._static_calibration_data[chain]) != 4:
						Logger.logwarn( "CalculateForceTorqueCalibration: Given static calibration data for %s has length %d. Required 4 entries." % (chain, len(self._static_calibration_data[chain])) )
						self._failed = True
						return
					# convert physical parameters to identification parameters (mass, 1st moment)
					m = self._static_calibration_data[chain][0]
					if m == 0:
						mom_x = 0
						mom_y = 0
						mom_z = 0
					elif m > 0:
						mom_x = self._static_calibration_data[chain][1]/m
						mom_y = self._static_calibration_data[chain][2]/m
						mom_z = self._static_calibration_data[chain][3]/m
					else:
						Logger.logwarn( "CalculateForceTorqueCalibration: Negative mass (%f) for calibration requested. Abort." % mass )
						self._failed = True
						return
					k_fix = np.resize(np.array([m, mom_x, mom_y, mom_z]),(4,1))
					Logger.loginfo( "CalculateForceTorqueCalibration: static calibration data for %s given: %s. Reducing the equation system" % (chain, str(self._static_calibration_data[chain])) )
					MeasVec_calc_corr = np.subtract(np.array(MeasVec_calc), np.dot(InfMat_calc[:,0:4], k_fix))
					InfMat_calc_corr = InfMat_calc[:,4:10] # only the last 6 columns, which correspond to the sensor offsets
					k = np.linalg.lstsq(InfMat_calc_corr, MeasVec_calc_corr)[0] # solve the reduced equation system
					k_calibration = list(self._static_calibration_data[chain]) # copy, so the stored static data is not modified
					k_calibration.extend(k.flatten())
				else: # calculate normally with all parameters unknown
					k = np.linalg.lstsq(InfMat_calc, MeasVec_calc)[0]
					# convert to physical parameters (first moment -> center of mass)
					if k[0] > 0:
						k_calibration = [k[0], k[1]/k[0], k[2]/k[0], k[3]/k[0], k[4], k[5], k[6], k[7], k[8], k[9]]
					else:
						Logger.loginfo("CalculateForceTorqueCalibration: Calibration yielded non-positive mass %f" % k[0])
						k_calibration = [0.0, 0.0, 0.0, 0.0, k[4], k[5], k[6], k[7], k[8], k[9]]
					
				Logger.loginfo("CalculateForceTorqueCalibration:calibration data for %s" % chain)
				Logger.loginfo("CalculateForceTorqueCalibration:mass: %f" % float(k_calibration[0]))
				Logger.loginfo("CalculateForceTorqueCalibration:center of mass: %f %f %f" % (k_calibration[1], k_calibration[2], k_calibration[3]))
				Logger.loginfo("CalculateForceTorqueCalibration:F offset: %f %f %f" % (k_calibration[4], k_calibration[5], k_calibration[6]))
				Logger.loginfo("CalculateForceTorqueCalibration:M offset: %f %f %f" % (k_calibration[7], k_calibration[8], k_calibration[9]))
				self._ft_calib_data[chain] = k_calibration
			
				bag_from_robot.close()
	
			userdata.ft_calib_data = self._ft_calib_data
			Logger.loginfo('CalculateForceTorqueCalibration:Calibration finished')
			self._done = True

		except Exception as e:
			Logger.logwarn('CalculateForceTorqueCalibration:Unable to calculate calibration:\n%s' % str(e))
			self._failed = True
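The state above stacks, for every sampled data point, one 6x10 block M (source [1], eq. (7)) relating the gravity vector in the sensor frame to the measured wrench, and solves the stacked system by least squares for k = [m, m*cx, m*cy, m*cz, F_offset, M_offset]. The following is a minimal, self-contained numpy sketch of that identification step on synthetic data; the function and variable names are illustrative and not part of the state implementation above.

import numpy as np

def information_block(g):
    """One 6x10 block M with M.dot(k) = [F; T] for
    k = [m, m*cx, m*cy, m*cz, Fx0, Fy0, Fz0, Mx0, My0, Mz0]."""
    gx, gy, gz = g
    M = np.zeros((6, 10))
    M[0, 0], M[1, 0], M[2, 0] = gx, gy, gz    # F = m*g + F_offset
    M[3, 2], M[3, 3] = gz, -gy                # T = (m*c) x g + M_offset
    M[4, 1], M[4, 3] = -gz, gx
    M[5, 1], M[5, 2] = gy, -gx
    M[np.arange(6), np.arange(4, 10)] = 1.0   # constant sensor offsets
    return M

# synthetic check: mass 1.5 kg, CoM at (0, 0, 0.1) m, small offsets
k_true = np.array([1.5, 0.0, 0.0, 0.15, 1.0, -2.0, 0.5, 0.1, 0.0, -0.1])
rng = np.random.default_rng(0)
gs = [9.81 * v / np.linalg.norm(v) for v in rng.normal(size=(20, 3))]
A = np.vstack([information_block(g) for g in gs])
b = A.dot(k_true) + rng.normal(scale=1e-3, size=6 * len(gs))
k_est = np.linalg.lstsq(A, b, rcond=None)[0]
print(np.round(k_est, 3))  # close to k_true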
Example #21
0
class Rosbag(ForeignDataWrapper):
    def __init__(self, options, columns=None):
        super(Rosbag, self).__init__(options, columns)
        Bag = import_bag(options)
        self.filename = options.pop('rosbag_path', "") + options.pop('rosbag')
        self.topic = options.pop('topic', None)
        pointcloud_formats = strtobool(options.pop('metadata', 'false'))

        self.patch_column = options.pop('patch_column', 'points').strip()
        self.patch_columns = options.pop('patch_columns', '*').strip()
        self.patch_columns = [
            col.strip() for col in self.patch_columns.split(',')
            if col.strip()
        ]

        self.patch_count_default = int(options.pop('patch_count_default',
                                                   1000))
        # 0 => 1 patch per message
        self.patch_count_pointcloud = int(
            options.pop('patch_count_pointcloud', 0))
        assert (self.patch_count_default > 0)
        assert (self.patch_count_pointcloud >= 0)
        self.pcid = int(options.pop('pcid', 0))
        self.bag = Bag(self.filename, 'r')
        self.topics = self.bag.get_type_and_topic_info().topics
        self.pointcloud_formats = None

        if pointcloud_formats:
            self.pointcloud_formats = []
            topics = self.topic.split(',') if self.topic else self.topics
            for i, topic in enumerate(self.topics):
                if topic not in topics:
                    continue
                infos = self.topics[topic]
                columns, patch_schema, patch_ply_header, _, patch_columns, patch_srid = \
                    get_columns(self.bag, topic, infos, self.pcid+i+1, self.patch_column,
                                self.patch_columns)
                self.pointcloud_formats.append({
                    'pcid': self.pcid + i + 1,
                    'srid': patch_srid,
                    'schema': patch_schema,
                    'format': columns[self.patch_column][3],
                    'rostype': infos.msg_type,
                    'columns': patch_columns,
                    'ply_header': patch_ply_header,
                })
            return

        if not self.topic:
            log_to_postgres('"topic" option is required', ERROR)

        self.infos = self.topics[self.topic]
        (self.columns, self.patch_schema, self.patch_ply_header, self.endianness,
         self.patch_columns, self.patch_srid) = \
            get_columns(self.bag, self.topic, self.infos, self.pcid, self.patch_column,
                        self.patch_columns)

        if columns:
            missing = set(columns) - set(self.columns.keys())
            columns = list(c for c in self.columns.keys() if c in columns)
            if missing:
                missing = ", ".join(sorted(missing))
                support = ", ".join(sorted(self.columns.keys()))
                log_to_postgres(
                    "extra unsupported columns : {}".format(missing),
                    WARNING,
                    hint="supported columns : {}".format(support))
            self.columns = {col: self.columns[col] for col in columns}

        if options:
            log_to_postgres(
                "extra unsupported options : {}".format(options.keys()),
                WARNING)

    @classmethod
    def import_schema(self, schema, srv_options, options, restriction_type,
                      restricts):
        Bag = import_bag(srv_options)
        pcid_str = options.pop('pcid', srv_options.pop('pcid', 0))
        pcid = int(pcid_str)
        patch_column = options.pop('patch_column',
                                   srv_options.pop('patch_column', 'points'))
        patch_columns = options.pop('patch_columns', '*').strip()
        patch_columns = [
            col.strip() for col in patch_columns.split(',') if col.strip()
        ]
        filename = srv_options.pop('rosbag_path', "") + options.pop(
            'rosbag_path', "") + schema
        bag = Bag(filename, 'r')

        tablecols = []
        topics = bag.get_type_and_topic_info().topics
        pcid_for_topic = {k: pcid + 1 + i for i, k in enumerate(topics.keys())}
        pointcloud_formats = True
        if restriction_type == 'limit':
            topics = {k: v for k, v in topics.items() if k in restricts}
            pointcloud_formats = 'pointcloud_formats' in restricts
        elif restriction_type == 'except':
            topics = {k: v for k, v in topics.items() if k not in restricts}
            pointcloud_formats = 'pointcloud_formats' not in restricts

        tabledefs = []
        if pointcloud_formats:
            tablecols = [
                ColumnDefinition('pcid', type_name='integer'),
                ColumnDefinition('srid', type_name='integer'),
                ColumnDefinition('schema', type_name='text'),
                ColumnDefinition('format', type_name='text'),
                ColumnDefinition('rostype', type_name='text'),
                ColumnDefinition('columns', type_name='text[]'),
                ColumnDefinition('ply_header', type_name='text'),
            ]
            tableopts = {
                'metadata': 'true',
                'rosbag': schema,
                'pcid': pcid_str
            }
            tabledefs.append(
                TableDefinition("pointcloud_formats",
                                columns=tablecols,
                                options=tableopts))

        for topic, infos in topics.items():
            pcid = pcid_for_topic[topic]
            columns, _, _, _, _, _ = get_columns(bag, topic, infos, pcid,
                                                 patch_column, patch_columns)
            tablecols = [get_column_def(k, *v) for k, v in columns.items()]
            tableopts = {'topic': topic, 'rosbag': schema, 'pcid': str(pcid)}
            tabledefs.append(
                TableDefinition(topic, columns=tablecols, options=tableopts))
        return tabledefs

    def execute(self, quals, columns):
        if self.pointcloud_formats is not None:
            for f in self.pointcloud_formats:
                yield f
            return
        self.patch_data = ''
        from rospy.rostime import Time
        tmin = None
        tmax = None
        for qual in quals:
            if qual.field_name == "time":
                t = int(qual.value)
                # integer division keeps this correct under Python 3 as well
                t = Time(t // 1000000000, t % 1000000000)
                if qual.operator in ['=', '>', '>=']:
                    tmin = t
                if qual.operator in ['=', '<', '<=']:
                    tmax = t
        for topic, msg, t in self.bag.read_messages(topics=self.topic,
                                                    start_time=tmin,
                                                    end_time=tmax):
            for row in self.get_rows(topic, msg, t, columns):
                yield row

        # flush leftover patch data
        if self.patch_data and self.last_row:
            count = int((len(self.patch_data) / self.point_size))
            # in replicating mode, a single leftover point must not be reported
            if count > 1 or self.patch_step_size == self.patch_size:
                res = self.last_row
                if self.patch_column in columns:
                    res[self.patch_column] = hexlify(
                        pack('=b3I', self.endianness, self.pcid, 0, count) +
                        self.patch_data)
                if self.patch_ply_header and 'ply' in columns:
                    self.ply_info['count'] = count
                    res['ply'] = self.patch_ply_header.format(
                        **self.ply_info) + self.patch_data
                yield res

    def get_rows(self, topic, msg, t, columns, toplevel=True):
        if toplevel and len(msg.__slots__) == 1:
            attr = getattr(msg, msg.__slots__[0])
            if isinstance(attr, list):
                for msg in attr:
                    for row in self.get_rows(topic, msg, t, columns, False):
                        yield row
                return
        res = {}
        data_columns = set(columns)
        if self.patch_column in columns:
            data_columns = data_columns.union(self.patch_columns) - set(
                [self.patch_column])
        if "filename" in data_columns:
            res["filename"] = self.filename
        if "topic" in data_columns:
            res["topic"] = topic
        if "time" in data_columns:
            res["time"] = t.to_nsec()
        if self.infos.msg_type == 'sensor_msgs/PointCloud2':
            self.patch_count = self.patch_count_pointcloud or (msg.width *
                                                               msg.height)
            self.point_size = msg.point_step
            self.patch_size = self.patch_count * self.point_size
            self.patch_step_size = self.patch_size
            self.endianness = 0 if msg.is_bigendian else 1
            data_columns = data_columns - set(['ply', self.patch_column])
            self.patch_data += msg.data

        data_columns = data_columns - set(res.keys())
        for column in data_columns:
            attr = msg
            for col in column.split('.'):
                if isinstance(attr, list):
                    attr = tuple(getattr(a, col) for a in attr)
                else:
                    attr = getattr(attr, col)
            if hasattr(attr, "to_nsec"):
                attr = attr.to_nsec()
            elif hasattr(attr, "x"):
                if hasattr(attr, "w"):
                    attr = (attr.x, attr.y, attr.z, attr.w)
                else:
                    attr = (attr.x, attr.y, attr.z)
            elif isinstance(attr, str):
                fmt = self.columns[column][3]
                if fmt:
                    attr = unpack(fmt, attr)
            res[column] = attr

        if self.patch_column in columns and not self.infos.msg_type == 'sensor_msgs/PointCloud2':
            fmt = self.columns[self.patch_column][3]
            self.patch_count = self.patch_count_default
            self.point_size = calcsize(fmt)
            self.patch_size = self.patch_count * self.point_size
            self.patch_step_size = self.patch_size - self.point_size
            self.patch_data += get_point_data(res, self.patch_columns, fmt)
            res = {k: v for k, v in res.items() if k not in self.patch_columns}

        if not self.patch_data:
            yield res
        else:
            # todo: ensure current res and previous res are equal if there is some leftover
            # patch_data
            while len(self.patch_data) >= self.patch_size:
                data = self.patch_data[0:self.patch_size]
                count = int(self.patch_size / self.point_size)
                res[self.patch_column] = hexlify(
                    pack('=b3I', self.endianness, self.pcid, 0, count) + data)
                if self.patch_ply_header and 'ply' in columns:
                    self.ply_info = {
                        'endianness': 'big' if self.endianness else 'little',
                        'filename': self.filename,
                        'topic': self.topic,
                        'count': count
                    }
                    res['ply'] = self.patch_ply_header.format(
                        **self.ply_info) + data
                self.patch_data = self.patch_data[self.patch_step_size:]
                yield res
            self.last_row = res
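The execute() method above maps SQL predicates on the "time" column to rosbag's native start_time/end_time filters, so range queries are pushed down to the bag index instead of scanning every message. Below is a standalone sketch of that pattern, assuming only the rosbag API; the bag path and topic in the usage comment are placeholders.

from rosbag import Bag
from rospy.rostime import Time

def read_window(path, topic, t_min_ns=None, t_max_ns=None):
    """Yield (time_ns, message) for messages within [t_min_ns, t_max_ns]."""
    def to_time(ns):
        return None if ns is None else Time(ns // 10**9, ns % 10**9)
    with Bag(path) as bag:
        for _, msg, t in bag.read_messages(topics=topic,
                                           start_time=to_time(t_min_ns),
                                           end_time=to_time(t_max_ns)):
            yield t.to_nsec(), msg

# e.g. the first second of /tf data (placeholder values):
# for t_ns, msg in read_window('example.bag', '/tf', 0, 10**9):
#     print(t_ns)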
Example #22
0
def semantic_segment(
    metadata: str,
    input_bag: Bag,
    model: 'keras.models.Model',
    predict: str,
    output_bag: Bag = None,
    output_dir: str = None,
    base: str = None,
    num_samples: int = 200,
    encoding: str = 'rgb',
) -> None:
    """
    Predict a stream of images from an input ROSbag.

    Args:
        metadata: the metadata about the semantic segmentations from the model
        input_bag: the input bag to predict targets from a topic
        model: the semantic segmentation model to use to make predictions
        predict: the topic to get a priori estimates from
        output_bag: the output bag to write the a priori estimates to
        output_dir: the output directory to write image pairs to
        base: the base-name for the prediction image topic
        num_samples: the number of image pairs to sample for output directory
        encoding: the encoding for the images to write

    Returns:
        None

    """
    # create the base endpoint for the topics
    base = '' if base is None else '{}'.format(base)

    # setup the output directories
    if output_dir is not None:
        x_dir = os.path.join(output_dir, 'X', 'data')
        if not os.path.isdir(x_dir):
            os.makedirs(x_dir)
        y_dir = os.path.join(output_dir, 'y', 'data')
        if not os.path.isdir(y_dir):
            os.makedirs(y_dir)

    # read the RGB map and vectorized method from the metadata file
    rgb_map, unmap_rgb = read_rgb_map(metadata)

    # write the color map metadata to the output bag
    if output_bag is not None:
        ros_stamp = rospy.rostime.Time(input_bag.get_start_time())
        msg = String(repr(rgb_map))
        output_bag.write('{}/rgb_map'.format(base), msg, ros_stamp)  # topic built from base, not from rgb_map

    # open a Window to play the video
    x_window = Window('img', model.input_shape[1], model.input_shape[2])
    y_window = Window('sem-seg', model.output_shape[1], model.output_shape[2])
    # create a progress bar for iterating over the messages in the bag
    total_messages = input_bag.get_message_count(topic_filters=predict)
    with tqdm(total=total_messages, unit='message') as prog:
        # iterate over the messages in this input bag
        for _, msg, time in input_bag.read_messages(topics=predict):
            # update the progress bar with a single iteration
            prog.update(1)
            if np.random.random() > num_samples / total_messages:
                continue
            # create a tensor from the raw pixel data
            pixels = get_camera_image(msg.data,
                                      (msg.height, msg.width))[..., :3]
            # flip the BGR image to RGB
            if encoding == 'bgr':
                pixels = pixels[..., ::-1]
            # resize the pixels to the shape of the model
            _pixels = resize(
                pixels,
                model.input_shape[1:],
                anti_aliasing=False,
                mode='symmetric',
                clip=False,
                preserve_range=True,
            ).astype('uint8')
            # pass the frame through the model
            y_pred = model.predict(_pixels[None, ...])[0]
            y_pred = np.stack(unmap_rgb(y_pred.argmax(axis=-1)), axis=-1)
            y_pred = y_pred.astype('uint8')
            # show the pixels on the windows
            x_window.show(_pixels)
            y_window.show(y_pred)
            # create an Image message and write it to the output ROSbag
            if output_bag is not None:
                msg = image_msg(y_pred, msg.header.stamp, y_pred.shape[:2],
                                'rgb8')
                output_bag.write('{}/image_raw'.format(base), msg,
                                 msg.header.stamp)
            # sample a number and write the image pair to disk
            if output_dir is not None:
                x_file = os.path.join(x_dir, '{}.png'.format(time))
                Image.fromarray(pixels).save(x_file)
                y_file = os.path.join(y_dir, '{}.png'.format(time))
                y_pred = resize(
                    y_pred,
                    pixels.shape[:2],
                    anti_aliasing=False,
                    mode='symmetric',
                    clip=False,
                    preserve_range=True,
                ).astype('uint8')
                Image.fromarray(y_pred).save(y_file)
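A hypothetical driver for semantic_segment; the model file, bag names, metadata path, and camera topic below are placeholders, and loading the model with keras.models.load_model is an assumption about how the model was saved.

from rosbag import Bag
from keras.models import load_model

model = load_model('model.h5')  # assumed: a trained segmentation network
with Bag('input.bag') as in_bag, Bag('output.bag', 'w') as out_bag:
    semantic_segment(
        metadata='metadata.csv',       # RGB map written at training time
        input_bag=in_bag,
        model=model,
        predict='/camera/image_raw',   # topic with the input images
        output_bag=out_bag,
        output_dir='samples',
        base='/segmentation',
        encoding='rgb',
    )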
Example #23
0
def main():
    parser = argparse.ArgumentParser(
        description=("Extracts grayscale and event images from a ROS bag and "
                     "saves them as TFRecords for training in TensorFlow."))
    parser.add_argument("--bag",
                        dest="bag",
                        help="Path to ROS bag.",
                        required=True)
    parser.add_argument("--prefix",
                        dest="prefix",
                        help="Output file prefix.",
                        required=True)
    parser.add_argument("--output_folder",
                        dest="output_folder",
                        help="Output folder.",
                        required=True)
    parser.add_argument(
        "--max_aug",
        dest="max_aug",
        help="Maximum number of images to combine for augmentation.",
        type=int,
        default=6)
    parser.add_argument(
        "--n_skip",
        dest="n_skip",
        help="Number of images to skip between consecutive saved event images.",
        type=int,
        default=1)
    parser.add_argument("--start_time",
                        dest="start_time",
                        help="Time to start in the bag.",
                        type=float,
                        default=0.0)
    parser.add_argument("--end_time",
                        dest="end_time",
                        help="Time to end in the bag.",
                        type=float,
                        default=-1.0)

    args = parser.parse_args()

    bridge = CvBridge()

    n_msgs = 0
    left_event_image_iter = 0
    right_event_image_iter = 0
    left_image_iter = 0
    right_image_iter = 0
    first_left_image_time = -1
    first_right_image_time = -1

    left_events = []
    right_events = []
    left_images = []
    right_images = []
    left_image_times = []
    right_image_times = []
    left_event_count_images = []
    left_event_time_images = []
    left_event_image_times = []

    right_event_count_images = []
    right_event_time_images = []
    right_event_image_times = []

    cols = 346
    rows = 260
    print("Processing bag")
    bag = Bag(args.bag)

    left_tf_writer = tf.python_io.TFRecordWriter(
        os.path.join(args.output_folder, args.prefix,
                     "left_event_images.tfrecord"))
    right_tf_writer = tf.python_io.TFRecordWriter(
        os.path.join(args.output_folder, args.prefix,
                     "right_event_images.tfrecord"))

    # Get actual time for the start of the bag.
    t_start = bag.get_start_time()
    t_start_ros = rospy.Time(t_start)
    # Set the time at which the bag reading should end.
    if args.end_time == -1.0:
        t_end = bag.get_end_time()
    else:
        t_end = t_start + args.end_time

    eps = 0.1
    for topic, msg, t in bag.read_messages(
            topics=[
                '/davis/left/image_raw', '/davis/right/image_raw',
                '/davis/left/events', '/davis/right/events'
            ],
            start_time=rospy.Time(max(args.start_time, eps) - eps + t_start),
            end_time=rospy.Time(t_end)):
        # Check to make sure we're working with stereo messages.
        if not ('left' in topic or 'right' in topic):
            print(
                'ERROR: topic {} does not contain left or right, is this stereo?'
                'If not, you will need to modify the topic names in the code.'.
                format(topic))
            return

        # Counter for status updates.
        n_msgs += 1
        if n_msgs % 500 == 0:
            print("Processed {} msgs, {} images, time is {}.".format(
                n_msgs, left_event_image_iter,
                t.to_sec() - t_start))

        isLeft = 'left' in topic
        if 'image' in topic:
            width = msg.width
            height = msg.height
            if width != cols or height != rows:
                print(
                    "Image dimensions are not what we expected: set: ({} {}) vs  got:({} {})"
                    .format(cols, rows, width, height))
                return
            time = msg.header.stamp
            if time.to_sec() - t_start < args.start_time:
                continue
            image = np.asarray(bridge.imgmsg_to_cv2(msg, msg.encoding))
            image = np.reshape(image, (height, width))

            if isLeft:
                cv2.imwrite(
                    os.path.join(
                        args.output_folder, args.prefix,
                        "left_image{:05d}.png".format(left_image_iter)), image)
                if left_image_iter > 0:
                    left_image_times.append(time)
                else:
                    first_left_image_time = time
                    left_event_image_times.append(time.to_sec())
                    # filter events we added previously
                    left_events = filter_events(
                        left_events, left_event_image_times[-1] - t_start)
                left_image_iter += 1
            else:
                cv2.imwrite(
                    os.path.join(
                        args.output_folder, args.prefix,
                        "right_image{:05d}.png".format(right_image_iter)),
                    image)
                if right_image_iter > 0:
                    right_image_times.append(time)
                else:
                    first_right_image_time = time
                    right_event_image_times.append(time.to_sec())
                    # filter events we added previously
                    right_events = filter_events(
                        right_events, right_event_image_times[-1] - t_start)

                right_image_iter += 1
        elif 'events' in topic and msg.events:
            # Add events to list.
            for event in msg.events:
                ts = event.ts
                event = [
                    event.x, event.y, (ts - t_start_ros).to_sec(),
                    (float(event.polarity) - 0.5) * 2
                ]
                if isLeft:
                    # add event if it was after the first image or we haven't seen the first image
                    if first_left_image_time == -1 or ts > first_left_image_time:
                        left_events.append(event)
                elif first_right_image_time == -1 or ts > first_right_image_time:
                    right_events.append(event)
            if isLeft:
                if len(left_image_times) >= args.max_aug and\
                   left_events[-1][2] > (left_image_times[args.max_aug-1]-t_start_ros).to_sec():
                    left_event_image_iter = _save_events(
                        left_events, left_image_times, left_event_count_images,
                        left_event_time_images, left_event_image_times, rows,
                        cols, args.max_aug, args.n_skip, left_event_image_iter,
                        args.prefix, 'left', left_tf_writer, t_start_ros)
            else:
                if len(right_image_times) >= args.max_aug and\
                   right_events[-1][2] > (right_image_times[args.max_aug-1]-t_start_ros).to_sec():
                    right_event_image_iter = _save_events(
                        right_events, right_image_times,
                        right_event_count_images, right_event_time_images,
                        right_event_image_times, rows, cols, args.max_aug,
                        args.n_skip, right_event_image_iter, args.prefix,
                        'right', right_tf_writer, t_start_ros)

    left_tf_writer.close()
    right_tf_writer.close()

    image_counter_file = open(
        os.path.join(args.output_folder, args.prefix, "n_images.txt"), 'w')
    image_counter_file.write("{} {}".format(left_event_image_iter,
                                            right_event_image_iter))
    image_counter_file.close()
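filter_events is called above but not defined in this excerpt. A minimal sketch consistent with how it is used (events are [x, y, t, polarity] lists with t in seconds relative to the bag start, and the second argument is a cutoff in the same time base) could be:

def filter_events(events, t_cutoff):
    """Keep only events that occurred after t_cutoff (seconds from bag start)."""
    return [e for e in events if e[2] > t_cutoff]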
Example #24
0
class Rosbag(ForeignDataWrapper):
    def __init__(self, options, columns=None):
        super(Rosbag, self).__init__(options, columns)
        Bag = import_bag(options)
        self.filename = options.pop('rosbag_path', "") + options.pop('rosbag')
        self.topic = options.pop('topic', None)
        pointcloud_formats = strtobool(options.pop('metadata', 'false'))

        self.patch_column = options.pop('patch_column', 'points').strip()
        self.patch_columns = options.pop('patch_columns', '*').strip()
        self.patch_columns = [col.strip() for col in self.patch_columns.split(',') if col.strip()]

        self.patch_count_default = int(options.pop('patch_count_default', 1000))
        # 0 => 1 patch per message
        self.patch_count_pointcloud = int(options.pop('patch_count_pointcloud', 0))
        assert(self.patch_count_default > 0)
        assert(self.patch_count_pointcloud >= 0)
        self.pcid = int(options.pop('pcid', 0))
        self.bag = Bag(self.filename, 'r')
        self.topics = self.bag.get_type_and_topic_info().topics
        self.pointcloud_formats = None

        if pointcloud_formats:
            self.pointcloud_formats = []
            topics = self.topic.split(',') if self.topic else self.topics
            for i, topic in enumerate(self.topics):
                if topic not in topics:
                    continue
                infos = self.topics[topic]
                columns, patch_schema, patch_ply_header, _, patch_columns, patch_srid = \
                    get_columns(self.bag, topic, infos, self.pcid+i+1, self.patch_column,
                                self.patch_columns)
                self.pointcloud_formats.append({
                    'pcid': self.pcid+i+1,
                    'srid': patch_srid,
                    'schema': patch_schema,
                    'format': columns[self.patch_column][3],
                    'rostype': infos.msg_type,
                    'columns': patch_columns,
                    'ply_header': patch_ply_header,
                })
            return

        if not self.topic:
            log_to_postgres('"topic" option is required', ERROR)

        self.infos = self.topics[self.topic]
        (self.columns, self.patch_schema, self.patch_ply_header, self.endianness,
         self.patch_columns, self.patch_srid) = \
            get_columns(self.bag, self.topic, self.infos, self.pcid, self.patch_column,
                        self.patch_columns)

        if columns:
            missing = set(columns) - set(self.columns.keys())
            columns = list(c for c in self.columns.keys() if c in columns)
            if missing:
                missing = ", ".join(sorted(missing))
                support = ", ".join(sorted(self.columns.keys()))
                log_to_postgres(
                    "extra unsupported columns : {}".format(missing), WARNING,
                    hint="supported columns : {}".format(support))
            self.columns = {col: self.columns[col] for col in columns}

        if options:
            log_to_postgres("extra unsupported options : {}".format(
                options.keys()), WARNING)

    @classmethod
    def import_schema(self, schema, srv_options, options,
                      restriction_type, restricts):
        Bag = import_bag(srv_options)
        pcid_str = options.pop('pcid', srv_options.pop('pcid', 0))
        pcid = int(pcid_str)
        patch_column = options.pop('patch_column', srv_options.pop('patch_column', 'points'))
        patch_columns = options.pop('patch_columns', '*').strip()
        patch_columns = [col.strip() for col in patch_columns.split(',') if col.strip()]
        filename = srv_options.pop('rosbag_path', "") + options.pop('rosbag_path', "") + schema
        bag = Bag(filename, 'r')

        tablecols = []
        topics = bag.get_type_and_topic_info().topics
        pcid_for_topic = {k: pcid+1+i for i, k in enumerate(topics.keys())}
        pointcloud_formats = True
        if restriction_type == 'limit':
            topics = {k: v for k, v in topics.items() if k in restricts}
            pointcloud_formats = 'pointcloud_formats' in restricts
        elif restriction_type == 'except':
            topics = {k: v for k, v in topics.items() if k not in restricts}
            pointcloud_formats = 'pointcloud_formats' not in restricts

        tabledefs = []
        if pointcloud_formats:
            tablecols = [
                ColumnDefinition('pcid', type_name='integer'),
                ColumnDefinition('srid', type_name='integer'),
                ColumnDefinition('schema', type_name='text'),
                ColumnDefinition('format', type_name='text'),
                ColumnDefinition('rostype', type_name='text'),
                ColumnDefinition('columns', type_name='text[]'),
                ColumnDefinition('ply_header', type_name='text'),
            ]
            tableopts = {'metadata': 'true', 'rosbag': schema, 'pcid': pcid_str}
            tabledefs.append(TableDefinition("pointcloud_formats", columns=tablecols,
                                             options=tableopts))

        for topic, infos in topics.items():
            pcid = pcid_for_topic[topic]
            columns, _, _, _, _, _ = get_columns(bag, topic, infos, pcid,
                                                 patch_column, patch_columns)
            tablecols = [get_column_def(k, *v) for k, v in columns.items()]
            tableopts = {'topic': topic, 'rosbag': schema, 'pcid': str(pcid)}
            tabledefs.append(TableDefinition(topic, columns=tablecols, options=tableopts))
        return tabledefs

    def execute(self, quals, columns):
        if self.pointcloud_formats is not None:
            for f in self.pointcloud_formats:
                yield f
            return
        self.patch_data = ''
        from rospy.rostime import Time
        tmin = None
        tmax = None
        for qual in quals:
            if qual.field_name == "time":
                t = int(qual.value)
                t = Time(t // 1000000000, t % 1000000000)
                if qual.operator in ['=', '>', '>=']:
                    tmin = t
                if qual.operator in ['=', '<', '<=']:
                    tmax = t
        for topic, msg, t in self.bag.read_messages(
                topics=self.topic, start_time=tmin, end_time=tmax):
            for row in self.get_rows(topic, msg, t, columns):
                yield row

        # flush leftover patch data
        if self.patch_data and self.last_row:
            count = int((len(self.patch_data) / self.point_size))
            # in replicating mode, a single leftover point must not be reported
            if count > 1 or self.patch_step_size == self.patch_size:
                res = self.last_row
                if self.patch_column in columns:
                    res[self.patch_column] = hexlify(
                            pack('=b3I', self.endianness, self.pcid, 0, count) + self.patch_data)
                if self.patch_ply_header and 'ply' in columns:
                    self.ply_info['count'] = count
                    res['ply'] = self.patch_ply_header.format(**self.ply_info) + self.patch_data
                yield res

    def get_rows(self, topic, msg, t, columns, toplevel=True):
        if toplevel and len(msg.__slots__) == 1:
            attr = getattr(msg, msg.__slots__[0])
            if isinstance(attr, list):
                for msg in attr:
                    for row in self.get_rows(topic, msg, t, columns, False):
                        yield row
                return
        res = {}
        data_columns = set(columns)
        if self.patch_column in columns:
            data_columns = data_columns.union(self.patch_columns) - set([self.patch_column])
        if "filename" in data_columns:
            res["filename"] = self.filename
        if "topic" in data_columns:
            res["topic"] = topic
        if "time" in data_columns:
            res["time"] = t.to_nsec()
        if self.infos.msg_type == 'sensor_msgs/PointCloud2':
            self.patch_count = self.patch_count_pointcloud or (msg.width*msg.height)
            self.point_size = msg.point_step
            self.patch_size = self.patch_count * self.point_size
            self.patch_step_size = self.patch_size
            self.endianness = 0 if msg.is_bigendian else 1
            data_columns = data_columns - set(['ply', self.patch_column])
            self.patch_data += msg.data

        data_columns = data_columns - set(res.keys())
        for column in data_columns:
            attr = msg
            for col in column.split('.'):
                if isinstance(attr, list):
                    attr = tuple(getattr(a, col) for a in attr)
                else:
                    attr = getattr(attr, col)
            if hasattr(attr, "to_nsec"):
                attr = attr.to_nsec()
            elif hasattr(attr, "x"):
                if hasattr(attr, "w"):
                    attr = (attr.x, attr.y, attr.z, attr.w)
                else:
                    attr = (attr.x, attr.y, attr.z)
            elif isinstance(attr, str):
                fmt = self.columns[column][3]
                if fmt:
                    attr = unpack(fmt, attr)
            res[column] = attr

        if self.patch_column in columns and not self.infos.msg_type == 'sensor_msgs/PointCloud2':
            fmt = self.columns[self.patch_column][3]
            self.patch_count = self.patch_count_default
            self.point_size = calcsize(fmt)
            self.patch_size = self.patch_count * self.point_size
            self.patch_step_size = self.patch_size - self.point_size
            self.patch_data += get_point_data(res, self.patch_columns, fmt)
            res = {k: v for k, v in res.items() if k not in self.patch_columns}

        if not self.patch_data:
            yield res
        else:
            # todo: ensure current res and previous res are equal if there is some leftover
            # patch_data
            while len(self.patch_data) >= self.patch_size:
                data = self.patch_data[0:self.patch_size]
                count = int(self.patch_size / self.point_size)
                res[self.patch_column] = hexlify(
                        pack('=b3I', self.endianness, self.pcid, 0, count) + data)
                if self.patch_ply_header and 'ply' in columns:
                    self.ply_info = {
                        'endianness': 'big' if self.endianness else 'little',
                        'filename': self.filename,
                        'topic': self.topic,
                        'count': count
                    }
                    res['ply'] = self.patch_ply_header.format(**self.ply_info) + data
                self.patch_data = self.patch_data[self.patch_step_size:]
                yield res
            self.last_row = res
Example #25
0
def data_collect(file_path, link_total):
    def read_msg(bag, topic, parser):
        # collect every message on "topic" from "bag" and extract its data with the "parser" function
        return np.array([
            parser(time, msg)
            for (topic_, msg, time) in bag
            if topic_ == topic
        ])
    bag = Bag(file_path)
    bag = [
        (topic, msg, time)
        for (topic, msg, time)
        in bag.read_messages()
    ]
    # bag is now a plain list of (topic, msg, time) tuples
    start_time = bag[0][2].to_sec()
    def start_time_zero(extract_data):
        # shift the timestamps so that the first message is at t = 0
        buff = extract_data.T
        buff[0] = buff[0] - start_time
        return buff.T
    def quat2rpy(extract_data):
        # convert quaternions to Euler angles (roll, pitch, yaw)
        buff = np.zeros([extract_data.shape[0], 4])
        i = 0
        for (t, w, x, y, z) in extract_data:
            buff[i, 0] = t
            # scipy's Rotation expects scalar-last order [x, y, z, w]
            r = R.from_quat([x, y, z, w])
            buff[i, 1:] = r.as_euler('zyx', degrees=True)
            i = i + 1
        return buff

    position_list=[]
    des_position_list=[]
    rotation_list=[]
    des_rotation_list=[]

    for i in range(1,link_total+1):
        position = read_msg(
            bag, '/mavros_' + str(i) +'/odar/pose',
            parser=lambda t, msg: (t.to_sec(), msg.pose.position.x, msg.pose.position.y,  msg.pose.position.z)
            )
        position=start_time_zero(position)
        position_list.append(position)

        des_position = read_msg(
            bag, '/mavros_' + str(i) +'/odar/desired_pose',
            parser=lambda t, msg: (t.to_sec(), msg.pose.position.x, msg.pose.position.y,  msg.pose.position.z)
            )
        des_position = start_time_zero(des_position)
        des_position_list.append(des_position)

        rotation = read_msg(
            bag, '/mavros_' + str(i) +'/odar/pose',
            parser=lambda t, msg: (t.to_sec(), msg.pose.orientation.w, msg.pose.orientation.x, msg.pose.orientation.y, msg.pose.orientation.z)
            )
        rotation = quat2rpy(start_time_zero(rotation))
        rotation_list.append(rotation)

        des_rotation = read_msg(
            bag, '/mavros_' + str(i) +'/odar/desired_pose',
            parser=lambda t, msg: (t.to_sec(), msg.pose.orientation.w, msg.pose.orientation.x, msg.pose.orientation.y, msg.pose.orientation.z)
            )
        des_rotation = quat2rpy(start_time_zero(des_rotation))
        des_rotation_list.append(des_rotation)

    return position_list, des_position_list, rotation_list, des_rotation_list
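A hypothetical usage sketch: plot the measured against the desired x-position of the first link. The bag file name and link count are placeholders.

import matplotlib.pyplot as plt

pos, des_pos, rot, des_rot = data_collect('flight.bag', link_total=2)
plt.plot(pos[0][:, 0], pos[0][:, 1], label='measured x')         # columns: t, x, y, z
plt.plot(des_pos[0][:, 0], des_pos[0][:, 1], '--', label='desired x')
plt.xlabel('time [s]')
plt.ylabel('x [m]')
plt.legend()
plt.show()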
Example #26
0
def parse_bag(bag_file_path):
    bagfile = Bag(bag_file_path)
    msgs_topics = [
        "/drone/pose", "/gazebo/ground_truth", "/mavros/local_position/odom",
        "/mavros/px4flow/ground_distance"
    ]
    msgs_drone_pose = {
        "y": [],
        "z": [],
        "x": [],
        "roll": [],
        "pitch": [],
        "yaw": [],
        "time": [],
        "label": "EKF-SLAM"
    }
    msgs_ground_truth = {
        "y": [],
        "z": [],
        "x": [],
        "roll": [],
        "pitch": [],
        "yaw": [],
        "time": [],
        "label": "Ground Truth"
    }
    msgs_mavros_odom = {
        "y": [],
        "z": [],
        "x": [],
        "roll": [],
        "pitch": [],
        "yaw": [],
        "time": [],
        "label": "MAVROS"
    }
    msgs_mavros_laser = {"range": [], "time": [], "label": "Laser Range"}

    for topic, message, time in bagfile.read_messages(msgs_topics):
        if topic == msgs_topics[0]:
            msgs_drone_pose["x"].append(message.pose.pose.position.x)
            msgs_drone_pose["y"].append(message.pose.pose.position.y)
            msgs_drone_pose["z"].append(message.pose.pose.position.z)

            quat = message.pose.pose.orientation
            euler = quat2euler([quat.x, quat.y, quat.z, quat.w])
            msgs_drone_pose["roll"].append(euler[0])
            msgs_drone_pose["pitch"].append(euler[1])
            msgs_drone_pose["yaw"].append(euler[2])

            msgs_drone_pose["time"].append(time.to_sec())
        if topic == msgs_topics[1]:
            msgs_ground_truth["x"].append(message.pose.pose.position.x)
            msgs_ground_truth["y"].append(message.pose.pose.position.y)
            msgs_ground_truth["z"].append(message.pose.pose.position.z)

            quat = message.pose.pose.orientation
            euler = quat2euler([quat.x, quat.y, quat.z, quat.w])
            msgs_ground_truth["roll"].append(euler[0])
            msgs_ground_truth["pitch"].append(euler[1])
            msgs_ground_truth["yaw"].append(euler[2])

            msgs_ground_truth["time"].append(time.to_sec())
        if topic == msgs_topics[2]:
            msgs_mavros_odom["x"].append(message.pose.pose.position.x - 3)
            msgs_mavros_odom["y"].append(message.pose.pose.position.y)
            msgs_mavros_odom["z"].append(message.pose.pose.position.z)

            quat = message.pose.pose.orientation
            euler = quat2euler([quat.x, quat.y, quat.z, quat.w])
            msgs_mavros_odom["roll"].append(euler[0])
            msgs_mavros_odom["pitch"].append(euler[1])
            msgs_mavros_odom["yaw"].append(euler[2])

            msgs_mavros_odom["time"].append(time.to_sec())
        if topic == msgs_topics[3]:
            msgs_mavros_laser["range"].append(message.range)
            msgs_mavros_laser["time"].append(time.to_sec())

    bagfile.close()
    return {
        "ground_truth": msgs_ground_truth,
        "mavros_odom": msgs_mavros_odom,
        "mavros_laser": msgs_mavros_laser,
        "ekf2": msgs_drone_pose
    }
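A hypothetical usage sketch comparing the EKF-SLAM estimate against ground truth in the x-y plane; the bag path is a placeholder.

import matplotlib.pyplot as plt

data = parse_bag('flight.bag')
for key in ('ground_truth', 'ekf2'):
    d = data[key]
    plt.plot(d['x'], d['y'], label=d['label'])
plt.xlabel('x [m]')
plt.ylabel('y [m]')
plt.legend()
plt.show()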
Example #27
0
def create_ground_truth(path,
                        time_offset=0,
                        bagfile='.logs/*.bag',
                        topic='/StereoTUM/estimated_transform',
                        file=''):

    from rosbag import Bag
    from geometry_msgs.msg import TransformStamped
    from glob import glob as findfiles

    if not file: file = p.join(path, 'data', 'ground_truth.csv')

    first = None
    msgs = []
    timefiles = findfiles(p.join(path, 'params', 'time.yaml'))
    if not timefiles:
        raise ValueError('Can\'t find %s/params/time.yaml' % path)

    seq_stamps = (None, None)
    with open(timefiles[0]) as stream:
        time = yaml.safe_load(stream)  # safe_load avoids executing arbitrary YAML tags
        seq_stamps = (time['time']['start'], time['time']['end'])

    print('Using time offset: %fs' % time_offset)

    a = p.join(path, bagfile)
    b = findfiles(a)
    if not b:
        raise ValueError('Can\'t find the bag file %s' % a)
    bag = Bag(b[0])
    print('Found ground truth file at: %s' % b[0])
    print('Will listen to topic %s' % topic)

    for bagtopic, msg, t in bag.read_messages():

        if bagtopic != topic: continue

        now = t.to_sec() + time_offset
        if now < seq_stamps[0]: continue  # before recording started
        if first is None: first = now
        if now > seq_stamps[1]: continue  # after recording finished

        stamp = now - first

        # Check if any message with the same timestamp has already been added...
        if any(item.get('time', None) == stamp for item in msgs): continue

        # If not, add the message to the list of ground truths
        msgs.append({
            'time': stamp,
            'trans': msg.transform.translation,
            'orien': msg.transform.rotation
        })

    bag.close()

    print('Noticed %d messages, will now sort them by time...' % len(msgs))

    msgs = sorted(msgs, key=lambda x: x['time'])

    with open(file, 'w') as f:
        print('Creating new file %s' % file)
        f.write(
            'Timestamp [s]\tTranslation X [m]\tTranslation Y [m]\tTranslation Z [m]\tOrientation W\tOrientation X\tOrientation Y\tOrientation Z\n'
        )
        for msg in msgs:
            f.write('%.6f\t' % msg['time'])

            f.write('%6.4f\t' % msg['trans'].x)
            f.write('%6.4f\t' % msg['trans'].y)
            f.write('%6.4f\t' % msg['trans'].z)

            f.write('%6.4f\t' % msg['orien'].w)
            f.write('%6.4f\t' % msg['orien'].x)
            f.write('%6.4f\t' % msg['orien'].y)
            f.write('%6.4f' % msg['orien'].z)
            f.write('\n')

        # remove last newline from file
        f.seek(-1, os.SEEK_END)
        f.truncate()
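A hypothetical invocation for one recorded sequence; the sequence path is a placeholder, while the bagfile pattern and topic simply restate the defaults above.

create_ground_truth(
    '/data/sequence_01',   # sequence folder containing params/time.yaml
    time_offset=0.0,
    bagfile='.logs/*.bag',
    topic='/StereoTUM/estimated_transform',
)
# writes /data/sequence_01/data/ground_truth.csv, sorted by timestamp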