コード例 #1
0
def cal_view_pred_pose(model, data, epoch=0, obj_id=-1):
    """Run the model on one batch and save an image with predicted poses drawn.

    Only the first sample of the batch is visualised. For every class id in
    ``cls_ids[0]`` that also appears in the predicted class ids, the object's
    mesh points are transformed by the predicted pose, projected with the
    dataset's intrinsic matrix and drawn on the RGB image. The result is
    written to ``<config.log_eval_dir>/pose_vis/<epoch>.jpg``.

    Args:
        model: Network called as ``model(cld_rgb_nrm, rgb, choose)``; it is
            switched to eval mode.
        data: Iterable of 11 tensors, in the order unpacked below. Every
            item is moved to CUDA.
        epoch: Used as the output file name. When it is 0 the output
            directory is printed.
        obj_id: Object id passed to ``cal_frame_poses_lm`` and used for the
            mesh/colour lookup on non-"ycb" datasets. For "ycb" it is
            overwritten with each visualised class id.
    """
    model.eval()
    with torch.set_grad_enabled(False):
        cu_dt = [item.to("cuda", non_blocking=True) for item in data]
        # Only some of these are used here; the full unpacking documents the
        # expected batch layout and fails fast on a wrong item count.
        rgb, pcld, cld_rgb_nrm, choose, kp_targ_ofst, ctr_targ_ofst, \
            cls_ids, rts, labels, kp_3ds, ctr_3ds = cu_dt

        pred_kp_of, pred_rgbd_seg, pred_ctr_of = model(cld_rgb_nrm, rgb,
                                                       choose)
        # Per-point class = index of the highest segmentation score.
        _, classes_rgbd = torch.max(pred_rgbd_seg, -1)

        is_ycb = args.dataset == "ycb"
        if is_ycb:
            pred_cls_ids, pred_pose_lst = cal_frame_poses(
                pcld[0], classes_rgbd[0], pred_ctr_of[0], pred_kp_of[0], True,
                config.n_objects, True)
        else:
            pred_pose_lst = cal_frame_poses_lm(pcld[0], classes_rgbd[0],
                                               pred_ctr_of[0], pred_kp_of[0],
                                               True, config.n_objects, False,
                                               obj_id)
            # Single-object case: one pose, matched against class id 1.
            pred_cls_ids = np.array([[1]])

        # First image of the batch, channel-first -> channel-last.
        np_rgb = rgb.cpu().numpy().astype("uint8")[0].transpose(1, 2, 0).copy()
        if is_ycb:
            # Reverse the channel order for the "ycb" images.
            np_rgb = np_rgb[:, :, ::-1].copy()

        # The intrinsics depend only on the dataset, so look them up once.
        K = config.intrinsic_matrix["ycb_K1" if is_ycb else "linemod"]

        for cls_id in cls_ids[0].cpu().numpy():
            idx = np.where(pred_cls_ids == cls_id)[0]
            if len(idx) == 0:
                # No pose was predicted for this class.
                continue
            pose = pred_pose_lst[idx[0]]
            if is_ycb:
                obj_id = int(cls_id[0])
            mesh_pts = bs_utils.get_pointxyz(obj_id,
                                             ds_type=args.dataset).copy()
            # Apply the pose: first three columns rotate, fourth translates.
            mesh_pts = np.dot(mesh_pts, pose[:, :3].T) + pose[:, 3]
            mesh_p2ds = bs_utils.project_p3d(mesh_pts, 1.0, K)
            color = bs_utils.get_label_color(obj_id, n_obj=22, mode=1)
            np_rgb = bs_utils.draw_p2ds(np_rgb, mesh_p2ds, color=color)

        vis_dir = os.path.join(config.log_eval_dir, "pose_vis")
        ensure_fd(vis_dir)
        f_pth = os.path.join(vis_dir, "{}.jpg".format(epoch))
        cv2.imwrite(f_pth, np_rgb)
    if epoch == 0:
        print("\n\nResults saved in {}".format(vis_dir))
コード例 #2
0
def cal_view_pred_pose(model, data):
    """Run the network on one batch and estimate object poses for sample 0.

    Args:
        model: Network called as ``model(cld_rgb_nrm, rgb, choose)``; it is
            put into eval mode first.
        data: Iterable of exactly three tensors ``(rgb, cld_rgb_nrm, choose)``.
            Each is made contiguous and moved to CUDA.

    Returns:
        A 4-tuple ``(classes_rgbd, pred_pose_lst, cld_rgb_nrm_0, pred_kps_lst)``
        where ``classes_rgbd`` and ``cld_rgb_nrm_0`` are NumPy arrays and the
        two lists come from ``cal_frame_poses`` / ``cal_frame_poses_lm``.
        If anything inside the prediction step raises, the error is printed
        and ``None`` is returned instead.
    """
    model.eval()
    try:
        with torch.set_grad_enabled(False):
            # Move the batch to the GPU.
            on_device = [
                tensor.contiguous().to("cuda", non_blocking=True)
                for tensor in data
            ]
            rgb, cld_rgb_nrm, choose = on_device

            # Forward pass; the class of each point is the arg-max of the
            # segmentation scores along the last axis.
            pred_kp_of, pred_rgbd_seg, pred_ctr_of = model(
                cld_rgb_nrm, rgb, choose)
            classes_rgbd = torch.max(pred_rgbd_seg, -1)[1]

            # First three channels of the first sample's cloud.
            points = cld_rgb_nrm[0][:, :3]

            if args.dataset in ("ycb", "openDR"):
                # Both datasets go through the same multi-object routine.
                _, pred_pose_lst, pred_kps_lst = cal_frame_poses(
                    points, classes_rgbd[0], pred_ctr_of[0], pred_kp_of[0],
                    True, config.n_objects, True, args.dataset)
            else:
                # NOTE(review): the object id is hard-coded to 16 here.
                pred_pose_lst, pred_kps_lst = cal_frame_poses_lm(
                    points, classes_rgbd[0], pred_ctr_of[0], pred_kp_of[0],
                    True, config.n_objects, False, 16)

            print('Prediction Complete...')
            return (classes_rgbd.cpu().numpy(), pred_pose_lst,
                    cld_rgb_nrm[0].cpu().numpy(), pred_kps_lst)
    except Exception as inst:
        # Best-effort: report the failure and the line in this function
        # where it surfaced, then hand back None to the caller.
        failure_tb = sys.exc_info()[2]
        print('exception: ' + str(inst) + ' in ' + str(failure_tb.tb_lineno))
        return None
コード例 #3
0
ファイル: api.py プロジェクト: DarkGeekMS/PVN3D
 def get_poses(self, save_results=True):
     """Run inference on the stored inputs and return the predicted poses.

     Reads ``self.model``, ``self.cld_rgb_nrm``, ``self.rgb``, ``self.choose``
     and ``self.cld`` for the prediction. When ``save_results`` is true it
     also reads ``self.cls_id_lst``, ``self.bs_utils`` and ``self.config`` to
     draw each predicted object's projected mesh points on the RGB image
     and writes it to ``<config.log_eval_dir>/pose_vis/out.jpg``.

     Args:
         save_results: Whether to render and save the visualisation image.

     Returns:
         ``(pred_cls_ids, pred_pose_lst)`` exactly as returned by
         ``cal_frame_poses`` for the first sample of the batch.
     """
     # model to eval mode
     self.model.eval()
     with torch.set_grad_enabled(False):
         # network forward pass
         pred_kp_of, pred_rgbd_seg, pred_ctr_of = self.model(
             self.cld_rgb_nrm, self.rgb, self.choose)
         # per-point class = index of the highest segmentation score
         _, classes_rgbd = torch.max(pred_rgbd_seg, -1)
         # calculate poses by voting, clustering and linear fitting
         pred_cls_ids, pred_pose_lst = cal_frame_poses(
             self.cld[0], classes_rgbd[0], pred_ctr_of[0], pred_kp_of[0],
             True, self.config.n_objects, True)
         # visualize predicted poses
         if save_results:
             # first image of the batch, channel-first -> channel-last,
             # then reverse the channel order
             np_rgb = self.rgb.cpu().numpy().astype("uint8")[0].transpose(
                 1, 2, 0).copy()
             np_rgb = np_rgb[:, :, ::-1].copy()
             # intrinsics are the same for every object; look them up once
             K = self.config.intrinsic_matrix["ycb_K1"]
             # loop over each class id
             for cls_id in self.cls_id_lst[0].cpu().numpy():
                 idx = np.where(pred_cls_ids == cls_id)[0]
                 if len(idx) == 0:
                     # no pose predicted for this class
                     continue
                 pose = pred_pose_lst[idx[0]]
                 obj_id = int(cls_id)
                 mesh_pts = self.bs_utils.get_pointxyz(
                     obj_id, ds_type='ycb').copy()
                 # first three pose columns rotate, the fourth translates
                 mesh_pts = np.dot(mesh_pts, pose[:, :3].T) + pose[:, 3]
                 mesh_p2ds = self.bs_utils.project_p3d(mesh_pts, 1.0, K)
                 color = self.bs_utils.get_label_color(obj_id,
                                                       n_obj=22,
                                                       mode=1)
                 np_rgb = self.bs_utils.draw_p2ds(np_rgb,
                                                  mesh_p2ds,
                                                  color=color)
             # save output visualization
             vis_dir = os.path.join(self.config.log_eval_dir, "pose_vis")
             # Create the directory in-process rather than shelling out to
             # `mkdir -p`: portable, and safe for paths containing spaces
             # or shell metacharacters.
             os.makedirs(vis_dir, exist_ok=True)
             f_pth = os.path.join(vis_dir, "out.jpg")
             cv2.imwrite(f_pth, np_rgb)
     # return prediction
     return pred_cls_ids, pred_pose_lst
コード例 #4
0
ファイル: demo_ycb_test.py プロジェクト: kaixin-bai/PVN3D
def _show_cloud_with_mesh(cld_rgb_nrm, mesh_pts):
    """Open an Open3D viewer with the input cloud and one posed object mesh.

    The cloud is taken from the first sample of ``cld_rgb_nrm``: columns 0-2
    are used as point positions and columns 3-5, divided by 255 and with the
    first and third channel swapped, as point colours. ``mesh_pts`` is shown
    as a second, uniformly coloured cloud. The call blocks until the viewer
    window is closed.
    """
    # One host transfer is enough; nothing below writes into this array,
    # so no defensive deep copy of the tensor is needed.
    cloud = cld_rgb_nrm.cpu().numpy()[0]
    np_xyz = cloud[:, :3]
    np_rgb = cloud[:, 3:6] / 255.
    # Swap the first and third colour channel.
    np_rgb[:, [0, 2]] = np_rgb[:, [2, 0]]

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np_xyz)
    pcd.colors = o3d.utility.Vector3dVector(np_rgb)

    obj_model = o3d.geometry.PointCloud()
    obj_model.points = o3d.utility.Vector3dVector(mesh_pts)
    obj_model.paint_uniform_color(color=[0, 0, 1])

    o3d.visualization.draw_geometries([pcd, obj_model])


def cal_view_pred_pose(model, data, epoch=0, obj_id=-1):
    """Run the model on one batch and interactively show the predicted poses.

    Only the first sample of the batch is visualised. For every class id in
    ``cls_ids[0]`` that also appears in the predicted class ids, the object's
    mesh points are transformed by the predicted pose, shown together with
    the input cloud in a 3D viewer, then projected with the dataset's
    intrinsic matrix and drawn on the RGB image. The annotated image is
    finally displayed with ``imshow`` and the function waits for a key press.

    Args:
        model: Network called as ``model(cld_rgb_nrm, rgb, choose)``; it is
            switched to eval mode.
        data: Iterable of 11 tensors, in the order unpacked below. Every
            item is moved to CUDA.
        epoch: Used to build the (currently unused) output file name. When
            it is 0 the output directory is printed.
        obj_id: Object id passed to ``cal_frame_poses_lm`` and used for the
            mesh/colour lookup on non-"ycb" datasets. For "ycb" it is
            overwritten with each visualised class id.
    """
    model.eval()
    with torch.set_grad_enabled(False):
        cu_dt = [item.to("cuda", non_blocking=True) for item in data]
        # Shapes noted by the original author:
        # rgb: [1, 3, 480, 640]    pcld: [1, 122881, 3]
        rgb, pcld, cld_rgb_nrm, choose, kp_targ_ofst, ctr_targ_ofst, \
            cls_ids, rts, labels, kp_3ds, ctr_3ds = cu_dt

        pred_kp_of, pred_rgbd_seg, pred_ctr_of = model(
            cld_rgb_nrm, rgb, choose
        )
        # Per-point class = index of the highest segmentation score.
        _, classes_rgbd = torch.max(pred_rgbd_seg, -1)

        is_ycb = args.dataset == "ycb"
        if is_ycb:
            # Original author's note: classes_rgbd[0] should really be the
            # mask here.
            pred_cls_ids, pred_pose_lst = cal_frame_poses(
                pcld[0], classes_rgbd[0], pred_ctr_of[0], pred_kp_of[0], True,
                config.n_objects, True
            )
        else:
            pred_pose_lst = cal_frame_poses_lm(
                pcld[0], classes_rgbd[0], pred_ctr_of[0], pred_kp_of[0], True,
                config.n_objects, False, obj_id
            )
            # Single-object case: one pose, matched against class id 1.
            pred_cls_ids = np.array([[1]])

        # First image of the batch, channel-first -> channel-last.
        np_rgb = rgb.cpu().numpy().astype("uint8")[0].transpose(1, 2, 0).copy()
        if is_ycb:
            # Reverse the channel order for the "ycb" images.
            np_rgb = np_rgb[:, :, ::-1].copy()

        # The intrinsics depend only on the dataset, so look them up once.
        K = config.intrinsic_matrix["ycb_K1" if is_ycb else "linemod"]

        for cls_id in cls_ids[0].cpu().numpy():
            idx = np.where(pred_cls_ids == cls_id)[0]
            if len(idx) == 0:
                # No pose was predicted for this class.
                continue
            pose = pred_pose_lst[idx[0]]
            if is_ycb:
                obj_id = int(cls_id[0])
            mesh_pts = bs_utils.get_pointxyz(obj_id, ds_type=args.dataset).copy()
            # Apply the pose: first three columns rotate, fourth translates.
            mesh_pts = np.dot(mesh_pts, pose[:, :3].T) + pose[:, 3]

            # 3D visualization (blocks until the viewer is closed).
            _show_cloud_with_mesh(cld_rgb_nrm, mesh_pts)

            mesh_p2ds = bs_utils.project_p3d(mesh_pts, 1.0, K)
            color = bs_utils.get_label_color(obj_id, n_obj=22, mode=1)
            np_rgb = bs_utils.draw_p2ds(np_rgb, mesh_p2ds, color=color)

        vis_dir = os.path.join(config.log_eval_dir, "pose_vis")
        ensure_fd(vis_dir)
        f_pth = os.path.join(vis_dir, "{}.jpg".format(epoch))

        # 2D visualization: shown on screen and held until a key is pressed.
        # To save to disk instead, write np_rgb to f_pth with cv2.imwrite.
        imshow("projected_pose_rgb", np_rgb)
        waitKey(0)

    if epoch == 0:
        # NOTE(review): this variant only displays the image; nothing is
        # written to vis_dir unless the imwrite alternative is enabled.
        print("\n\nResults saved in {}".format(vis_dir))