Example #1
    def display_stable_poses(self):
        """ Display stable poses """
        if self.database is None:
            print 'You must open a database first'
            print
            return True

        if self.dataset is None:
            print 'You must open a dataset first'
            print
            return True

        print 'Available objects:'
        for key in self.dataset.object_keys:
            print key
        print

        invalid_obj = True
        while invalid_obj:
            # get object name
            obj_name = raw_input('Enter object key: ')
            tokens = obj_name.split()
            if len(tokens) > 1:
                print 'Please provide only a single input'
                continue

            if obj_name.lower() == 'q':
                return True

            if obj_name not in self.dataset.object_keys:
                print 'Key %s not in database' %(obj_name)
                continue

            print 'Displaying', obj_name
            obj = self.dataset[obj_name]
            stable_poses = self.dataset.stable_poses(obj_name)

            if USE_ALAN:
                for stable_pose in stable_poses:
                    print 'Stable pose %s with p=%.3f' %(stable_pose.id, stable_pose.p)
                    vis.figure()
                    vis.mesh_stable_pose(obj.mesh, stable_pose,
                                         color=(0.5, 0.5, 0.5), style='surface')
                    vis.show()
            else:
                for stable_pose in stable_poses:
                    print 'Stable pose %s with p=%.3f' %(stable_pose.id, stable_pose.p)
                    mlab.figure(bgcolor=(1,1,1), size=(1000,1000))
                    vis.plot_stable_pose(obj.mesh, stable_pose,
                                         color=(0.5, 0.5, 0.5), style='surface',
                                         d=self.config['table_extent'])
                    mlab.show()

            invalid_obj = False

        return True
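
A small, hypothetical helper in the same spirit: instead of paging through every stable pose, pick the single most probable one. This is only a sketch, and it assumes nothing beyond the float attribute p printed in the loop above.

def most_probable_stable_pose(stable_poses):
    # Return the stable pose with the highest probability p, or None if the
    # list is empty (uses the same `p` attribute printed above).
    if not stable_poses:
        return None
    return max(stable_poses, key=lambda sp: sp.p)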
Example #2
    def display_object(self):
        """ Display an object """
        if self.database is None:
            print 'You must open a database first'
            print
            return True

        if self.dataset is None:
            print 'You must open a dataset first'
            print
            return True

        print 'Available objects:'
        for key in self.dataset.object_keys:
            print key
        print

        invalid_obj = True
        while invalid_obj:
            # get object name
            obj_name = raw_input('Enter object key: ')
            tokens = obj_name.split()
            if len(tokens) > 1:
                print 'Please provide only a single input'
                continue

            if obj_name.lower() == 'q':
                return True

            if obj_name not in self.dataset.object_keys:
                print 'Key %s not in database' %(obj_name)
                continue

            print 'Displaying', obj_name
            obj = self.dataset[obj_name]

            if USE_ALAN:
                vis.figure()
                vis.mesh(obj.mesh, color=(0.5, 0.5, 0.5), style='surface')
                vis.show()
            else:
                mlab.figure(bgcolor=(1,1,1), size=(1000,1000))
                vis.plot_mesh(obj.mesh, color=(0.5, 0.5, 0.5), style='surface')
                mlab.show()

            invalid_obj = False

        return True
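
The prompt-and-validate loop above is repeated almost verbatim in several of these examples. The hypothetical helper below factors it out; it is a sketch only, matching the Python 2 raw_input usage in the examples, and returns None when the user quits with 'q'.

def prompt_for_key(prompt, valid_keys):
    # Repeatedly prompt until a single, known key is entered; 'q' quits.
    while True:
        name = raw_input(prompt)
        if len(name.split()) > 1:
            print('Please provide only a single input')
            continue
        if name.lower() == 'q':
            return None
        if name not in valid_keys:
            print('Key %s not recognized' % name)
            continue
        return name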
Example #3
    # get images
    depth_images, ir_intrinsics = load_depth_images(cfg)
    depth_im = Image.median_images(depth_images)
    point_cloud_cam = ir_intrinsics.deproject(depth_im)  # for debug only
    T_camera_world = RigidTransform.load(
        os.path.join(cfg['calib_dir'],
                     '%s_registration.tf' % (ir_intrinsics.frame)))
    point_cloud_world = T_camera_world * point_cloud_cam

    # create registration solver
    registration_solver = KnownObjectStablePoseTabletopRegistrationSolver(
        object_key, stp_id, database.datasets[0], cfg, args.output_path)
    registration_result = registration_solver.register(depth_im,
                                                       ir_intrinsics,
                                                       T_camera_world,
                                                       debug=cfg['debug'])

    # visualize setup
    T_camera_world = RigidTransform.load(
        os.path.join(cfg['calib_dir'],
                     '%s_registration.tf' % (ir_intrinsics.frame)))
    T_camera_obj = registration_result.T_camera_obj
    T_obj_world = T_camera_world * T_camera_obj.inverse()
    vis.figure()
    vis.mesh(obj.mesh, T_obj_world, color=(1, 1, 1))
    vis.points(point_cloud_world, color=(0, 1, 0), subsample=20)
    vis.table(dim=0.5)
    vis.view(focalpoint=(0, 0, 0))
    vis.show()
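
The pose arithmetic above relies on RigidTransform composition with frame checking; the import from autolab_core below is an assumption (older versions of this stack used a differently named package), and the rotations and translations are placeholders. T_a_b denotes a transform from frame a to frame b.

import numpy as np
from autolab_core import RigidTransform

# placeholder camera-to-world and camera-to-object transforms
T_camera_world = RigidTransform(rotation=np.eye(3),
                                translation=np.array([0.0, 0.0, 1.0]),
                                from_frame='camera', to_frame='world')
T_camera_obj = RigidTransform(rotation=np.eye(3),
                              translation=np.array([0.0, 0.1, 0.5]),
                              from_frame='camera', to_frame='obj')

# object pose in the world frame, as computed in the example above
T_obj_world = T_camera_world * T_camera_obj.inverse()
assert T_obj_world.from_frame == 'obj'
assert T_obj_world.to_frame == 'world'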
Example #4
    # saving indexed results
    fig, axarr = plt.subplots(2, 5)
    axarr[0, 2].imshow(cnn_query_params.query_im.data)
    axarr[0, 2].set_title("Query Image")
    for i in range(5):
        axarr[1, i].imshow(nearest_images[i].image)
        axarr[1, i].set_title('{:.3f}'.format(dists[i]))

    for i in range(2):
        for j in range(5):
            axarr[i, j].axis('off')

    fig.suptitle("Query Image and Indexed Images")

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    fig.savefig(os.path.join(args.output_path, 'queried_images.pdf'),
                dpi=400,
                format='pdf')

    # display filtered point cloud with normals
    if cfg['vis_point_cloud']:
        vis.figure()
        vis.points(cnn_query_params.point_normal_cloud.points, subsample=2)
        vis.normals(cnn_query_params.point_normal_cloud.normals,
                    cnn_query_params.point_normal_cloud.points,
                    subsample=2,
                    color=(0, 0, 1))
        vis.show()
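
The plotting block above can be wrapped into a small reusable function. The sketch below uses only standard matplotlib and takes plain image arrays and distances, so none of the project-specific query objects are assumed.

import os
import matplotlib.pyplot as plt

def save_query_grid(query_im, neighbor_ims, dists, out_dir,
                    fname='queried_images.pdf'):
    # Top row: query image in the middle column; bottom row: the k nearest
    # neighbors annotated with their distances.
    k = len(neighbor_ims)
    fig, axarr = plt.subplots(2, k, squeeze=False)
    axarr[0, k // 2].imshow(query_im)
    axarr[0, k // 2].set_title('Query Image')
    for i in range(k):
        axarr[1, i].imshow(neighbor_ims[i])
        axarr[1, i].set_title('{:.3f}'.format(dists[i]))
    for ax in axarr.ravel():
        ax.axis('off')
    fig.suptitle('Query Image and Indexed Images')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    fig.savefig(os.path.join(out_dir, fname), dpi=400, format='pdf')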
Example #5
    def _perform_grasp(self, target_gripper_pose):
        # approach pose
        t_target_approach = np.array([0, 0, cfg['control']['approach_dist']])
        T_target_approach = RigidTransform(translation=t_target_approach,
                                           from_frame='gripper',
                                           to_frame='approach')
        approach_gripper_pose = target_gripper_pose * T_target_approach.inverse()

        # lift pose
        t_lift = np.array([0, 0, cfg['control']['lift_height']])
        T_lift = RigidTransform(translation=t_lift,
                                from_frame='world',
                                to_frame='world')
        lift_gripper_pose = T_lift * target_gripper_pose

        # visualize setup
        vis.figure()
        vis.mesh(self.obj.mesh, self.T_obj_world, color=(1, 1, 1))
        vis.points(self.point_cloud_world, color=(0, 1, 0), subsample=20)
        vis.pose(self.T_camera_world, alpha=0.1)
        vis.pose(self.T_base_world, alpha=0.1)
        vis.table(dim=0.2)
        vis.pose(target_gripper_pose, alpha=0.1)
        vis.pose(approach_gripper_pose, alpha=0.1)
        vis.pose(lift_gripper_pose, alpha=0.1)
        vis.view(focalpoint=(0, 0, 0))
        vis.show()
        exit(0)  # NOTE: exits after visualization; the grasp execution below never runs

        # execute the grasp on the YuMi
        self.robot.right.goto_pose(approach_gripper_pose)
        self.robot.right.goto_pose(target_gripper_pose)
        self.robot.right.close_gripper(wait_for_res=True)
        self.robot.right.goto_pose(lift_gripper_pose)
        sleep(5)

        # TODO: Take picture and save

        self.robot.right.open_gripper()
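
The approach and lift poses above are built with two different offset conventions, which is easy to miss: right-multiplying by a gripper-frame transform offsets along the gripper's own approach axis, while left-multiplying by a world-frame transform offsets along world z. A minimal sketch with placeholder values (the autolab_core import is an assumption):

import numpy as np
from autolab_core import RigidTransform

# placeholder grasp target: gripper pose in the world frame
target_gripper_pose = RigidTransform(rotation=np.eye(3),
                                     translation=np.array([0.4, 0.0, 0.1]),
                                     from_frame='gripper', to_frame='world')

# back off 5 cm along the gripper z-axis (gripper-frame offset)
T_target_approach = RigidTransform(translation=np.array([0.0, 0.0, 0.05]),
                                   from_frame='gripper', to_frame='approach')
approach_gripper_pose = target_gripper_pose * T_target_approach.inverse()

# raise 10 cm along the world z-axis (world-frame offset)
T_lift = RigidTransform(translation=np.array([0.0, 0.0, 0.10]),
                        from_frame='world', to_frame='world')
lift_gripper_pose = T_lift * target_gripper_pose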
Example #6
    def display_grasps(self):
        """ Display grasps for an object """
        if self.database is None:
            print 'You must open a database first'
            print
            return True

        if self.dataset is None:
            print 'You must open a dataset first'
            print
            return True

        # list grippers 
        print
        print 'Available grippers:'
        grippers = os.listdir(self.config['gripper_dir'])
        for gripper_name in grippers:
            print gripper_name
        print

        # set gripper
        invalid_gr = True
        while invalid_gr:
            # get gripper name
            gripper_name = raw_input('Enter gripper name: ')
            tokens = gripper_name.split()
            if len(tokens) > 1:
                print 'Please provide only a single input'
                continue

            if gripper_name.lower() == 'q':
                return True

            if gripper_name not in grippers:
                print 'Gripper %s not recognized' %(gripper_name)
                continue

            gripper = gr.RobotGripper.load(gripper_name)
            print 'Loaded gripper', gripper.name
            invalid_gr = False

        # list objects
        print 'Available objects:'
        for key in self.dataset.object_keys:
            print key
        print

        invalid_obj = True
        while invalid_obj:
            # get object name
            obj_name = raw_input('Enter object key: ')
            tokens = obj_name.split()
            if len(tokens) > 1:
                print 'Please provide only a single input'
                continue

            if obj_name.lower() == 'q':
                return True

            if obj_name not in self.dataset.object_keys:
                print 'Key %s not in database' %(obj_name)
                continue

            # list metrics
            print
            print 'Available metrics:'
            metrics = self.dataset.available_metrics(obj_name, gripper=gripper.name)
            for metric_name in metrics:
                print metric_name
            print

            # set metric
            invalid_mt = True
            while invalid_mt:
                # get metric name
                metric_name = raw_input('Enter metric name: ')
                tokens = metric_name.split()
                if len(tokens) > 1:
                    print 'Please provide only a single input'
                    continue
                    
                if metric_name.lower() == 'q':
                    return True

                if metric_name not in metrics:
                    print 'Metric %s not recognized' %(metric_name)
                    continue
                
                print 'Using metric %s' %(metric_name)
                invalid_mt = False

            print 'Displaying grasps for gripper %s on object %s' %(gripper.name, obj_name)
            obj = self.dataset[obj_name] 
            grasps, metrics = self.dataset.sorted_grasps(obj_name, metric_name,
                                                         gripper=gripper.name)
                 
            if len(grasps) == 0:
                print 'No grasps for gripper %s on object %s' %(gripper.name, obj_name)
                return True
                         
            low = np.min(metrics)
            high = np.max(metrics)
            q_to_c = lambda quality: self.config['quality_scale'] * (quality - low) / (high - low)
      
            if USE_ALAN:
                raise ValueError('ALAN does not yet support grasp display')
            else:
                if self.config['show_gripper']:
                    i = 0
                    for grasp, metric in zip(grasps, metrics):
                        print 'Grasp %d %s=%.5f' %(grasp.grasp_id, metric_name, metric)
                        mlab.figure(bgcolor=(1,1,1), size=(1000,1000))
                        vis.plot_mesh(obj.mesh, color=(0.5, 0.5, 0.5), style='surface')
                        color = plt.get_cmap('hsv')(q_to_c(metric))[:-1]
                        vis.plot_gripper(grasp, gripper=gripper,
                                         color=color)
                        vis.plot_grasp(grasp, grasp_axis_color=color,
                                       endpoint_color=color)
                        mlab.show()
                        i += 1
                        if i >= self.config['max_plot_gripper']:
                            break
                else:
                    mlab.figure(bgcolor=(1,1,1), size=(1000,1000))
                    vis.plot_mesh(obj.mesh, color=(0.5, 0.5, 0.5), style='surface')
                    for grasp, metric in zip(grasps, metrics):
                        print 'Grasp %d %s=%.5f' %(grasp.grasp_id, metric_name, metric)
                        color = plt.get_cmap('hsv')(q_to_c(metric))[:-1]
                        vis.plot_grasp(grasp, grasp_axis_color=color,
                                       endpoint_color=color)
                    mlab.show()
            invalid_obj = False

        return True
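
The q_to_c lambda above min-max normalizes grasp quality into a colormap index. The hypothetical helper below shows the same mapping with a guard for the degenerate case where all metrics are equal; the 0.3 default stands in for config['quality_scale'].

import numpy as np
import matplotlib.pyplot as plt

def metric_colors(metrics, quality_scale=0.3, cmap_name='hsv'):
    # Min-max normalize the metrics, scale them, and look up RGB colors,
    # dropping the alpha channel as in the example above.
    metrics = np.asarray(metrics, dtype=float)
    low, high = metrics.min(), metrics.max()
    if high > low:
        normalized = quality_scale * (metrics - low) / (high - low)
    else:
        normalized = np.zeros_like(metrics)
    cmap = plt.get_cmap(cmap_name)
    return [cmap(q)[:-1] for q in normalized]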
Example #7
    def _find_best_transformation(self,
                                  cnn_query_params,
                                  candidate_rendered_images,
                                  T_camera_table,
                                  dataset,
                                  config,
                                  debug=False):
        """ Finds the best transformation from the candidate set using Point to Plane Iterated closest point """
        T_camera_table = copy.copy(T_camera_table)
        T_camera_table._to_frame = 'table'  # for ease of use

        # read params
        icp_sample_size = config['icp_sample_size']
        icp_relative_point_plane_cost = config['icp_relative_point_plane_cost']
        icp_regularization_lambda = config['icp_regularization_lambda']
        feature_matcher_dist_thresh = config['feature_matcher_dist_thresh']
        feature_matcher_norm_thresh = config['feature_matcher_norm_thresh']
        num_registration_iters = config['num_registration_iters']
        compute_total_cost = config['compute_total_registration_cost']
        threshold_cost = config['threshold_cost']

        # register from nearest images
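        # (Frame convention used throughout: T_a_b maps points from frame a
        # to frame b; 'stp' is the neighbor's stable-pose frame, 'table' the
        # table frame, and 'virtual_camera' the frame of the cropped query
        # camera, cnn_query_params.cropped_ir_intrinsics.)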
        registration_results = []
        min_cost = np.inf
        best_reg = None
        best_T_virtual_camera_obj = None
        best_index = -1
        for i, neighbor_image in enumerate(candidate_rendered_images):
            logging.info('Registering to neighbor %d' % (i))

            # load object mesh
            database_start = time.time()
            mesh = dataset.mesh(neighbor_image.obj_key)
            database_stop = time.time()

            # form transforms
            preproc_start = time.time()
            T_virtual_camera_stp = neighbor_image.stp_to_camera_transform().inverse()
            T_virtual_camera_stp._from_frame = cnn_query_params.point_normal_cloud.frame
            T_obj_stp = neighbor_image.object_to_stp_transform()
            source_mesh_x0_obj = Point(neighbor_image.stable_pose.x0,
                                       frame='obj')
            source_mesh_x0_stp = T_obj_stp * source_mesh_x0_obj
            block1_stop = time.time()
            logging.debug('Preproc block 1 took %.2f sec' %
                          (block1_stop - preproc_start))

            # get source object points in table basis
            logging.info('Transforming source mesh')
            mesh_stp = mesh.transform(T_obj_stp)
            z = mesh_stp.min_coords()[2]
            obj_center_table = np.array([0, 0, -z])

            T_obj_table = RigidTransform(rotation=T_obj_stp.rotation,
                                         translation=obj_center_table,
                                         from_frame='obj',
                                         to_frame='table')
            mesh_table = mesh.transform(T_obj_table)
            source_points_table = PointCloud(np.array(mesh_table.vertices()).T,
                                             frame=T_obj_table.to_frame)
            block2_stop = time.time()
            logging.debug('Preproc block 2 took %.2f sec' %
                          (block2_stop - block1_stop))

            # read target points and normals
            target_points_normals_virtual_camera = cnn_query_params.point_normal_cloud
            if target_points_normals_virtual_camera.num_points == 0:
                logging.info('Found zero target points, skipping')
                registration_results.append(RegistrationResult(
                    RigidTransform(from_frame=target_points_normals_virtual_camera.frame,
                                   to_frame='obj'),
                    np.inf))
                continue
            block3_stop = time.time()
            logging.debug('Preproc block 3 took %.2f sec' %
                          (block3_stop - block2_stop))

            # match table normals
            logging.info('Matching table normals')
            target_points_stp = T_virtual_camera_stp * target_points_normals_virtual_camera.points
            target_normals_stp = T_virtual_camera_stp * target_points_normals_virtual_camera.normals

            T_stp_table = self._table_to_stp_transform(T_virtual_camera_stp,
                                                       T_camera_table, config)
            target_points_table = T_stp_table * target_points_stp
            target_normals_table = T_stp_table * target_normals_stp
            block4_stop = time.time()
            logging.debug('Preproc block 4 took %.2f sec' %
                          (block4_stop - block3_stop))

            # render depth image of source points
            logging.info('Rendering virtual depth')

            # transform mesh to virtual camera basis
            T_virtual_camera_table = T_stp_table * T_virtual_camera_stp
            T_table_virtual_camera = T_virtual_camera_table.inverse()
            mesh_virtual_camera = mesh_table.transform(T_table_virtual_camera)
            block5_stop = time.time()
            logging.debug('Preproc block 5 took %.2f sec' %
                          (block5_stop - block4_stop))

            # render virtual depth image
            T_camera_virtual_camera = cnn_query_params.T_camera_virtual_camera
            virtual_camera_intr = cnn_query_params.cropped_ir_intrinsics
            depth_im = mesh_virtual_camera.project_depth(virtual_camera_intr)
            source_depth_im = DepthImage(depth_im, virtual_camera_intr.frame)
            block6_stop = time.time()
            logging.debug('Preproc block 6 took %.2f sec' %
                          (block6_stop - block5_stop))

            # project points
            source_points_normals_virtual_camera = source_depth_im.point_normal_cloud(
                virtual_camera_intr)
            source_points_normals_virtual_camera.remove_zero_points()
            source_points_table = T_virtual_camera_table * source_points_normals_virtual_camera.points
            source_normals_table = T_virtual_camera_table * source_points_normals_virtual_camera.normals
            block7_stop = time.time()
            logging.debug('Preproc block 7 took %.2f sec' %
                          (block7_stop - block6_stop))

            # align the lowest and closest points to the camera
            logging.info('Aligning lowest and closest')
            table_center_camera = Point(T_camera_table.inverse().translation,
                                        frame=T_camera_table.from_frame)
            table_x0_table = T_virtual_camera_table * T_camera_virtual_camera * table_center_camera
            block8_stop = time.time()
            logging.debug('Preproc block 8 took %.2f sec' %
                          (block8_stop - block7_stop))

            # align points closest to the camera
            camera_optical_axis_table = -T_virtual_camera_table.rotation[:, 2]
            source_ip = source_points_table.data.T.dot(
                camera_optical_axis_table)
            closest_ind = np.where(source_ip == np.max(source_ip))[0]
            source_x0_closest_table = source_points_table[closest_ind[0]]

            max_z_ind = np.where(source_points_table.z_coords == np.max(
                source_points_table.z_coords))[0][0]
            source_x0_highest_table = source_points_table[max_z_ind]  # highest point in the table frame

            target_ip = target_points_table.data.T.dot(
                camera_optical_axis_table)
            closest_ind = np.where(target_ip == np.max(target_ip))[0]
            target_x0_closest_table = target_points_table[closest_ind[0]]

            max_z_ind = np.where(target_points_table.z_coords == np.max(
                target_points_table.z_coords))[0][0]
            target_x0_highest_table = target_points_table[max_z_ind]  # highest point in the table frame

            t_table_t_table_s = source_x0_closest_table.data - target_x0_closest_table.data
            t_table_t_table_s[2] = (source_x0_highest_table.data[2] -
                                    target_x0_highest_table.data[2])

            T_table_t_table_s = RigidTransform(translation=t_table_t_table_s,
                                               from_frame='table',
                                               to_frame='table')
            target_points_table = T_table_t_table_s * target_points_table
            target_normals_table = T_table_t_table_s * target_normals_table
            T_virtual_camera_table = T_table_t_table_s * T_virtual_camera_table
            preproc_stop = time.time()
            block9_stop = time.time()
            logging.debug('Preproc block 9 took %.2f sec' %
                          (block9_stop - block8_stop))

            # display the points relative to one another
            if debug:
                logging.info('Pre-registration alignment')
                vis.figure()
                vis.mesh(mesh_table)
                vis.points(source_points_table, color=(1, 0, 0))
                vis.points(target_points_table, color=(0, 1, 0))
                vis.points(source_x0_closest_table,
                           color=(0, 0, 1),
                           scale=0.02)
                vis.points(target_x0_closest_table,
                           color=(0, 1, 1),
                           scale=0.02)
                vis.table(dim=0.15)
                vis.show()

            # point to plane ICP solver
            icp_start = time.time()
            ppis = PointToPlaneICPSolver(sample_size=icp_sample_size,
                                         gamma=icp_relative_point_plane_cost,
                                         mu=icp_regularization_lambda)
            ppfm = PointToPlaneFeatureMatcher(
                dist_thresh=feature_matcher_dist_thresh,
                norm_thresh=feature_matcher_norm_thresh)
            registration = ppis.register_2d(
                source_points_table,
                target_points_table,
                source_normals_table,
                target_normals_table,
                ppfm,
                num_iterations=num_registration_iters,
                compute_total_cost=compute_total_cost,
                vis=debug)

            registration_results.append(registration)
            icp_stop = time.time()

            logging.info('Neighbor %d registration cost %f' %
                         (i, registration.cost))
            logging.info('Neighbor %d timings' % (i))
            logging.info('Database read took %.2f sec' %
                         (database_stop - database_start))
            logging.info('Preproc took %.2f sec' %
                         (preproc_stop - preproc_start))
            logging.info('ICP took %.2f sec' % (icp_stop - icp_start))

            if debug:
                logging.info('Post-registration alignment')
                vis.figure()
                vis.points(registration.T_source_target * source_points_table,
                           color=(1, 0, 0))
                vis.points(target_points_table, color=(0, 1, 0))
                vis.table(dim=0.15)
                vis.show()

            if registration.cost < min_cost:
                min_cost = registration.cost
                best_reg = registration
                best_T_table_s_table_t = registration.T_source_target
                best_T_virtual_camera_obj = T_obj_table.inverse().dot(
                    best_T_table_s_table_t.inverse()).dot(
                        T_virtual_camera_table)
                best_index = i

            if min_cost < threshold_cost:
                logging.info(
                    'Satisfactory registration found. Terminating early.')
                break

        if debug:
            logging.info('Best alignment')
            vis.figure()
            vis.mesh(mesh)
            vis.points(best_T_virtual_camera_obj *
                       target_points_normals_virtual_camera.points,
                       color=(1, 0, 0))
            vis.show()

        # compute best transformation from object to camera basis
        return best_T_virtual_camera_obj, registration_results, best_index
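
One small readability note on the alignment step: the repeated np.where(x == np.max(x))[0] pattern above is just an argmax. A NumPy-only sketch of the "point closest to the camera" selection, with points as an N x 3 array and the axis as a unit 3-vector (both placeholders):

import numpy as np

def closest_point_to_camera(points, toward_camera_axis):
    # Project each point onto an axis pointing back toward the camera (the
    # negated optical axis above) and return the point with the largest
    # projection, i.e. the point nearest the camera.
    ip = points.dot(toward_camera_axis)  # shape (N,)
    return points[np.argmax(ip)]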