예제 #1
0
 def load_landmark_data(self, landmark_file):
     """Load landmark points from a log file into ``self.landmark_data``.

     Every record returned by ``get_log_table(landmark_file)`` is loaded
     into a ``BevPointData`` and stored under the key ``"<type>_<id>"``.

     On any failure a message is printed and ``self.landmark_data`` stays
     ``None``, so callers never observe a partially filled dictionary.

     Args:
         landmark_file: Passed unchanged to ``get_log_table``.
     """
     self.landmark_data = None
     try:
         # Build into a local dict so that an error half-way through the
         # file cannot leave incomplete data on the instance.
         landmarks = {}
         for record in get_log_table(landmark_file):
             new_point = BevPointData()
             new_point.load_table(record.table)
             landmark_id = new_point.type + "_" + str(new_point.id)
             landmarks[landmark_id] = new_point
     except Exception:
         # ``Exception`` (not a bare except) keeps Ctrl-C / SystemExit
         # working while still tolerating any read or parse error.
         print("no valid landmark data")
     else:
         self.landmark_data = landmarks
예제 #2
0
 def try_load_cur_labeling(self):
     """Redraw the labels already saved for the current image, if any.

     Reads ``self.cur_label_file`` with ``get_log_table``. For every
     record a cross marker captioned ``"<id>:<type>"`` is drawn on
     ``self.cur_img`` at the record's ``x``/``y``, and
     ``self.cur_point_id`` is set to one past that record's id (so after
     the loop it follows the last record read).

     This is best effort: any error while reading or drawing is ignored
     and whatever was drawn up to that point is kept.
     """
     try:
         for record in get_log_table(self.cur_label_file):
             fields = record.table
             x, y = fields["x"], fields["y"]
             point_id = fields["id"]
             type_str = fields["type"]
             CvMarker.draw_cross_marker(
                 self.cur_img, x, y,
                 str(point_id) + ":" + type_str)
             self.cur_point_id = point_id + 1
     except Exception:
         # Deliberately silent, but ``Exception`` rather than a bare
         # except so KeyboardInterrupt / SystemExit still propagate.
         pass
예제 #3
0
    def __compute_marked_pose(self, imgs_dir):
        """Estimate a vehicle pose per label file and write them to disk.

        For every ``.txt`` file found by ``get_walkfilelist(imgs_dir, ".txt")``
        the records named ``"Point"`` are loaded as ``BevPointData``.  When a
        file holds at least two points, the first two are used to solve the
        4x4 linear system

            wx = vy * a + vx * b + tx
            wy = vx * a - vy * b + ty

        for ``(a, b, tx, ty)``.  ``(a, b)`` is normalised to unit length, the
        yaw is ``arccos(a)`` with the sign of ``b``, and ``(tx, ty)`` becomes
        the pose position.  The pose time is ``int(img_id)`` of the first
        point.

        Files that cannot be read, parsed or solved (e.g. singular system)
        are skipped.  If at least one pose was computed, all poses are
        written, one ``str(pose)`` per line, to ``gt_pose_marked.txt`` in
        ``os.path.split(imgs_dir)[0]``.

        Args:
            imgs_dir: Directory that is searched for label files.
        """
        label_lst, _ = get_walkfilelist(imgs_dir, ".txt")

        veh_pose_lst = []
        for label_file in label_lst:
            try:
                load_points = []
                for record in get_log_table(label_file):
                    if record.name == "Point":
                        new_point = BevPointData()
                        new_point.load_table(record.table)
                        load_points.append(new_point)

                if len(load_points) < 2:
                    # Two landmarks are required to fix position and yaw.
                    continue

                # compute veh pose in world with two landmark
                p0, p1 = load_points[0], load_points[1]
                vx0, vy0 = p0.veh_x, p0.veh_y
                wx0, wy0 = p0.world_x, p0.world_y
                vx1, vy1 = p1.veh_x, p1.veh_y
                wx1, wy1 = p1.world_x, p1.world_y

                h = np.array([[vy0, vx0, 1.0, 0.0], [vx0, -vy0, 0.0, 1.0],
                              [vy1, vx1, 1.0, 0.0], [vx1, -vy1, 0.0, 1.0]])
                H = np.matrix(h)
                Hinv = LA.inv(H)
                WP = np.matrix([wx0, wy0, wx1, wy1]).T
                WV = Hinv * WP

                n = np.sqrt(WV[0, 0]**2 + WV[1, 0]**2)
                if n == 0:
                    # Degenerate labels (no recoverable direction): dividing
                    # would produce NaN poses in the output file.
                    continue
                WV[0, 0] /= n
                WV[1, 0] /= n

                veh_pose = PoseData()
                veh_pose.time = int(p0.img_id)
                veh_pose.Xm = WV[2, 0]
                veh_pose.Ym = WV[3, 0]
                veh_pose.Yawr = np.arccos(WV[0, 0])
                if WV[1, 0] < 0:
                    veh_pose.Yawr *= -1
                veh_pose_lst.append(veh_pose)
            except Exception:
                # Best effort per file: skip anything unreadable or
                # unsolvable, but let KeyboardInterrupt / SystemExit through.
                pass

        if veh_pose_lst:
            print("%d pose load" % len(veh_pose_lst))
            gt_pose_file = os.path.join(
                os.path.split(imgs_dir)[0], "gt_pose_marked.txt")
            with open(gt_pose_file, "w") as f:
                for pose in veh_pose_lst:
                    f.write(str(pose) + "\n")
예제 #4
0
    def load_loc_data(self, veh_loc_file):
        """Load vehicle poses from a log file into ``self.loc_data``.

        Every record returned by ``get_log_table(veh_loc_file)`` is turned
        into a ``PoseData`` (fields ``time``, ``Xm``, ``Ym``, ``Yawr``) and
        stored in a dictionary keyed by its ``time`` value.

        On any failure ``self.loc_data`` is ``None`` and a message is
        printed.

        Args:
            veh_loc_file: Passed unchanged to ``get_log_table``.
        """
        self.loc_data = None
        try:
            poses = {}
            for record in get_log_table(veh_loc_file):
                fields = record.table
                new_pose = PoseData()
                new_pose.time = fields["time"]
                new_pose.Xm = fields["Xm"]
                new_pose.Ym = fields["Ym"]
                new_pose.Yawr = fields["Yawr"]
                poses[new_pose.time] = new_pose
            self.loc_data = poses
            print("read pose data, save veh coordinate and world coordinate")

        except Exception:
            # ``Exception`` instead of a bare except so that Ctrl-C and
            # SystemExit are not mistaken for a missing pose file.
            self.loc_data = None
            print("no valid pose data, only save veh coordinate")
예제 #5
0
    def display_poses(self, imgs_dir):
        """Plot the positions stored in ``gt_pose_marked.txt``.

        The file is looked up in ``os.path.split(imgs_dir)[0]`` and read with
        ``get_log_table``; the ``Xm``/``Ym`` values of every record are
        plotted as dots in a new matplotlib figure.

        This is best effort: if matplotlib is unavailable or the file cannot
        be read, a message is printed and nothing is shown.

        Args:
            imgs_dir: Image directory; only its directory part is used to
                locate the pose file.
        """
        try:
            from matplotlib import pyplot as plt
            pose_file = os.path.join(
                os.path.split(imgs_dir)[0], "gt_pose_marked.txt")
            X, Y = [], []
            for pose in get_log_table(pose_file):
                pose_data = pose.table
                X.append(pose_data["Xm"])
                Y.append(pose_data["Ym"])
            # Create the figure only once the data is known to be good, so a
            # failed load does not leave an empty figure behind that would
            # pop up on a later plt.show().
            plt.figure()
            plt.plot(X, Y, ".")
            plt.show()

        except Exception as err:
            print("failed to display poses: %s" % (err,))
예제 #6
0
 def __merge_landmarks(self, imgs_dir):
     """Collect all labelled points under ``imgs_dir`` into one file.

     Every ``.txt`` file found by ``get_walkfilelist(imgs_dir, ".txt")`` is
     read with ``get_log_table``; each record named ``"Point"`` is loaded
     into a ``BevPointData``.  A file that fails while being read is
     abandoned at that point (points already loaded from it are kept).

     If at least one point was found, all of them are written, one
     ``str(point)`` per line, to ``landmarks.txt`` in
     ``os.path.split(imgs_dir)[0]``.

     Args:
         imgs_dir: Directory that is searched for label files.
     """
     label_lst, _ = get_walkfilelist(imgs_dir, ".txt")
     load_points = []
     for label_file in label_lst:
         try:
             for record in get_log_table(label_file):
                 if record.name == "Point":
                     new_point = BevPointData()
                     new_point.load_table(record.table)
                     load_points.append(new_point)
         except Exception:
             # Best effort per file; ``Exception`` keeps Ctrl-C and
             # SystemExit from being swallowed.
             pass
     if load_points:
         print("%d landmarks found" % len(load_points))
         landmark_file = os.path.join(
             os.path.split(imgs_dir)[0], "landmarks.txt")
         with open(landmark_file, "w") as f:
             for lp in load_points:
                 f.write(str(lp) + "\n")
예제 #7
0
    def display_landmarks(self, imgs_dir):
        """Plot the landmarks stored in ``landmarks.txt`` with their names.

        The file is looked up in ``os.path.split(imgs_dir)[0]`` and read with
        ``get_log_table``.  Every record is drawn as a dot at
        ``(world_x, world_y)`` and annotated with ``"<type>_<id>"``.

        This is best effort: if matplotlib is unavailable or the file cannot
        be read, a message is printed and nothing is shown.

        Args:
            imgs_dir: Image directory; only its directory part is used to
                locate the landmark file.
        """
        try:
            from matplotlib import pyplot as plt
            font = {'size': 8}
            landmark_file = os.path.join(
                os.path.split(imgs_dir)[0], "landmarks.txt")
            X, Y, labels = [], [], []
            for landmark in get_log_table(landmark_file):
                lm_data = landmark.table
                labels.append(lm_data["type"] + "_" + str(lm_data["id"]))
                X.append(lm_data["world_x"])
                Y.append(lm_data["world_y"])
            # Draw only after the whole file parsed, so a bad record does not
            # leave a half-annotated figure open for a later plt.show().
            plt.figure()
            for x, y, s in zip(X, Y, labels):
                plt.text(x, y, s, fontdict=font)
            plt.plot(X, Y, ".")
            plt.show()

        except Exception as err:
            print("failed to display landmarks: %s" % (err,))