コード例 #1
0
ファイル: linemod_dataset.py プロジェクト: Guangyun-Xu/PVN3D
    def __init__(self, dataset_name, cls_type="duck"):
        """Build the LineMOD dataset index for one object class.

        Args:
            dataset_name: 'train' selects real + rendered + fused samples
                with noise augmentation; any other value selects the test
                split (preprocessed pickle if available, else test.txt).
            cls_type: LineMOD object name, e.g. "duck".
        """
        self.config = Config(dataset_name='linemod', cls_type=cls_type)
        self.bs_utils = Basic_Utils()
        self.dataset_name = dataset_name
        # Per-pixel coordinate maps for a 640x480 sensor
        # (xmap[r, c] = r, ymap[r, c] = c).
        self.xmap = np.array([[j for i in range(640)] for j in range(480)])
        self.ymap = np.array([[i for i in range(640)] for j in range(480)])

        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        # ImageNet normalization constants (fixed: the std's last entry was
        # mistyped as 0.224; the standard value is 0.225).
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        self.obj_dict = self.config.lm_obj_dict

        self.cls_type = cls_type
        self.cls_id = self.obj_dict[cls_type]
        print("cls_id in lm_dataset.py", self.cls_id)
        self.root = os.path.join(self.config.lm_root, 'Linemod_preprocessed')
        self.cls_root = os.path.join(self.root, "data/%02d/" % self.cls_id)
        self.rng = np.random
        # Use a context manager so the handle is closed, and safe_load:
        # yaml.load without an explicit Loader is unsafe and is rejected
        # by PyYAML >= 6.
        with open(os.path.join(self.cls_root, 'gt.yml'), "r") as meta_file:
            self.meta_lst = yaml.safe_load(meta_file)
        if dataset_name == 'train':
            self.add_noise = True
            real_img_pth = os.path.join(self.cls_root, "train.txt")
            self.real_lst = self.bs_utils.read_lines(real_img_pth)

            rnd_img_pth = os.path.join(
                self.root, "renders/{}/file_list.txt".format(cls_type))
            self.rnd_lst = self.bs_utils.read_lines(rnd_img_pth)

            fuse_img_pth = os.path.join(
                self.root, "fuse/{}/file_list.txt".format(cls_type))
            try:
                self.fuse_lst = self.bs_utils.read_lines(fuse_img_pth)
            except Exception:  # no fuse dataset: fall back to renders
                self.fuse_lst = self.rnd_lst
            self.all_lst = self.real_lst + self.rnd_lst + self.fuse_lst
        else:
            self.add_noise = False
            self.pp_data = None
            if os.path.exists(self.config.preprocessed_testset_pth
                              ) and self.config.use_preprocess:
                print('Loading valtestset.')
                with open(self.config.preprocessed_testset_pth, 'rb') as f:
                    self.pp_data = pkl.load(f)
                self.all_lst = [i for i in range(len(self.pp_data))]
                print('Finish loading valtestset.')
            else:
                tst_img_pth = os.path.join(self.cls_root, "test.txt")
                self.tst_lst = self.bs_utils.read_lines(tst_img_pth)
                self.all_lst = self.tst_lst
        print("{}_dataset_size: ".format(dataset_name), len(self.all_lst))
コード例 #2
0
ファイル: BOP_dataset.py プロジェクト: Guangyun-Xu/PVN3D
    def __init__(self, data_list_path, cls_id):
        """Wrap a BOP-format data list for one object class.

        Args:
            data_list_path: path to the sample list file.
            cls_id: BOP object id of the target class.
        """
        self.cls_id = cls_id
        # NOTE(review): Config receives data_list_path as BOTH the train
        # and the test list here — confirm this is intentional.
        self.config = Config(data_list_path, data_list_path)
        self.bs_utils = Basic_Utils()

        # Sampling parameters taken from the shared config.
        self.voxel_size = self.config.voxel_size
        self.n_sample_points = self.config.n_sample_points
        self.add_noise = True

        # Sample list, one entry per line of the list file.
        self.data_list = self.bs_utils.read_lines(data_list_path)

        # Color-jitter augmentation for RGB input.
        self.trans_color = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
コード例 #3
0
ファイル: common_BOP.py プロジェクト: Guangyun-Xu/PVN3D
class Config(object):
    """Training/testing configuration for the BOP LM-O dataset.

    Resolves the train/test list paths relative to this file, creates the
    log/checkpoint directories, and exposes training hyper-parameters.
    """

    def __init__(
        self,
        train_list_dir='./datasets/BOP/BOP_Dataset/LM-O/train_pbr/dataList_8.txt',
        test_list_dir='./datasets/BOP/BOP_Dataset/LM-O/train_pbr/dataList_8.txt'
    ):
        path = os.path.dirname(__file__)
        # Strip only a literal "./" prefix before joining.  The previous
        # lstrip("./") stripped ANY leading run of '.' and '/' characters,
        # which would also corrupt paths such as "../foo" or "//srv/foo".
        if train_list_dir.startswith("./"):
            train_list_dir = train_list_dir[2:]
        train_list_dir = os.path.join(path, train_list_dir)
        if test_list_dir.startswith("./"):
            test_list_dir = test_list_dir[2:]
        test_list_dir = os.path.join(path, test_list_dir)
        self.bs_utils = Basic_Utils()

        self.dataset_name = 'bop'
        self.exp_dir = os.path.dirname(__file__)
        self.exp_name = os.path.basename(self.exp_dir)
        self.resnet_ptr_mdl_p = os.path.abspath(
            os.path.join(self.exp_dir, 'lib/ResNet_pretrained_mdl'))
        ensure_fd(self.resnet_ptr_mdl_p)

        # Log folders (created on demand).
        self.log_dir = os.path.abspath(
            os.path.join(self.exp_dir, 'train_log', self.dataset_name))
        ensure_fd(self.log_dir)
        self.log_model_dir = os.path.join(self.log_dir, 'checkpoints')
        ensure_fd(self.log_model_dir)
        self.log_eval_dir = os.path.join(self.log_dir, 'eval_results')
        ensure_fd(self.log_eval_dir)

        # Sample lists and their sizes.
        self.train_list = self.bs_utils.read_lines(train_list_dir)
        self.test_list = self.bs_utils.read_lines(test_list_dir)
        self.n_train_frame = len(self.train_list)
        self.n_test_frame = len(self.test_list)

        # Training schedule.
        self.n_total_epoch = 10
        self.mini_batch_size = 2
        self.num_mini_batch_per_epoch = self.n_train_frame
        self.val_mini_batch_size = 2
        self.val_num_mini_batch_per_epoch = self.n_test_frame
        self.test_mini_batch_size = 4

        self.voxel_size = 0.01  # voxel edge length in meters
        self.n_sample_points = 20000

        # Dataset information: BOP object ids present in LM-O.
        self.LMO_obj_list = [1, 5, 6, 8, 9, 10, 11, 12]
コード例 #4
0
ファイル: lm_o_dataset.py プロジェクト: Guangyun-Xu/PVN3D
    def __init__(self, data_list_path, cls_id):
        """Index the LM-O/BOP sample list for one object class.

        Args:
            data_list_path: list file; each line describes one sample.
            cls_id: BOP object id of the target class.

        (Large blocks of commented-out LineMOD-era code were removed; see
        LM_Dataset for the original train/test split handling.)
        """
        self.bs_utils = Basic_Utils()
        self.n_sample_points = 20000  # points kept per cloud after sampling

        # Per-pixel coordinate maps for a 640x480 sensor.
        self.xmap = np.array([[j for i in range(640)] for j in range(480)])
        self.ymap = np.array([[i for i in range(640)] for j in range(480)])

        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        # ImageNet normalization constants (fixed: the std's last entry was
        # mistyped as 0.224; the standard value is 0.225).
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        self.cls_id = cls_id
        print("load {} data by LM_O_Dataset.py".format(self.cls_id))
        self.rng = np.random
        self.dataList = self.bs_utils.read_lines(data_list_path)  # list
        print("{}_dataset_size: ".format(data_list_path), len(self.dataList))
コード例 #5
0
ファイル: lm_o_dataset.py プロジェクト: Guangyun-Xu/PVN3D
class LM_O_Dataset():
    """LineMOD-Occlusion (BOP format) dataset for PVN3D.

    Each line of the data list names one sample:
    "<scene_folder> <rgb_file> <depth_file> <visible_mask_file>".
    `get_item` converts a sample into the tensors consumed by the network.
    """

    def __init__(self, data_list_path, cls_id):
        """Index the sample list for one object class (BOP obj id)."""
        self.bs_utils = Basic_Utils()
        self.n_sample_points = 20000  # points kept per cloud after sampling

        # Per-pixel coordinate maps for a 640x480 sensor.
        self.xmap = np.array([[j for i in range(640)] for j in range(480)])
        self.ymap = np.array([[i for i in range(640)] for j in range(480)])

        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        # ImageNet normalization constants (fixed: the std's last entry was
        # mistyped as 0.224; the standard value is 0.225).
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        self.cls_id = cls_id
        print("load {} data by LM_O_Dataset.py".format(self.cls_id))
        self.rng = np.random
        self.dataList = self.bs_utils.read_lines(data_list_path)  # list
        print("{}_dataset_size: ".format(data_list_path), len(self.dataList))

    def get_meta_data(self, folderPath, sceneId, obj_id):
        """Return the GT entry (pose and bbox) of obj_id in scene sceneId.

        Returns None if the object does not appear in the scene.
        """
        metaFilePath = os.path.join(folderPath, "scene_gt.json")
        with open(metaFilePath, 'r') as f1:
            meta_data = json.load(f1)
            # scene_gt.json keys carry no leading zeros ("000012" -> "12").
            sceneId = sceneId.lstrip('0')
            if sceneId == '':
                sceneId = '0'
            for objMeta in meta_data[sceneId]:
                if objMeta['obj_id'] == int(obj_id):
                    return objMeta

    def get_cam_parameter(self, folderPath, sceneId):
        """Return the per-scene camera record (intrinsics, depth scale)."""
        sceneInfoPath = os.path.join(folderPath, "scene_camera.json")
        with open(sceneInfoPath, 'r') as f2:
            sceneInfo = json.load(f2)
            sceneId = sceneId.lstrip('0')
            if sceneId == '':
                sceneId = '0'
            return sceneInfo[sceneId]

    def real_syn_gen(self, real_ratio=1.0):
        """Pick a sample path from the real / fuse / render lists.

        NOTE(review): self.real_lst / self.fuse_lst / self.rnd_lst are never
        initialized by this class's __init__ (leftover from LM_Dataset), so
        calling this raises AttributeError — confirm before use.
        """
        if self.rng.rand() < real_ratio:  # self.rng = np.random
            n_imgs = len(self.real_lst)  # number of real samples
            idx = self.rng.randint(0, n_imgs)
            return self.real_lst[idx]
        fuse_ratio = 0.4
        if self.rng.rand() < fuse_ratio:
            idx = self.rng.randint(0, len(self.fuse_lst))
            return self.fuse_lst[idx]
        idx = self.rng.randint(0, len(self.rnd_lst))
        return self.rnd_lst[idx]

    def real_gen(self):
        """Return a uniformly random entry from the real-image list.

        NOTE(review): self.real_lst is never set in this class — see
        real_syn_gen.
        """
        idx = self.rng.randint(0, len(self.real_lst))
        return self.real_lst[idx]

    def rand_range(self, rng, lo, hi):
        """Draw a uniform sample from [lo, hi)."""
        return rng.rand() * (hi - lo) + lo

    def gaussian_noise(self, rng, img, sigma):
        """Add Gaussian noise of the given sigma to an image, clip to uint8."""
        img = img + rng.randn(*img.shape) * sigma
        img = np.clip(img, 0, 255).astype('uint8')
        return img

    def linear_motion_blur(self, img, angle, length):
        """Convolve the image with a linear motion-blur kernel.

        :param angle: blur direction in degrees
        :param length: blur length in pixels
        """
        rad = np.deg2rad(angle)
        dx = np.cos(rad)
        dy = np.sin(rad)
        a = int(max(list(map(abs, (dx, dy)))) * length * 2)
        if a <= 0:
            return img
        kern = np.zeros((a, a))
        cx, cy = a // 2, a // 2
        dx, dy = list(map(int, (dx * length + cx, dy * length + cy)))
        cv2.line(kern, (cx, cy), (dx, dy), 1.0)
        s = kern.sum()
        if s == 0:
            # Degenerate line: fall back to an identity kernel.
            kern[cx, cy] = 1.0
        else:
            kern /= s
        return cv2.filter2D(img, -1, kern)

    def rgb_add_noise(self, img):
        """Apply random HSV jitter, motion blur and Gaussian blur."""
        rng = self.rng
        # HSV augmentation: scale saturation and value channels.
        if rng.rand() > 0:
            hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.uint16)
            hsv_img[:, :, 1] = hsv_img[:, :, 1] * self.rand_range(
                rng, 1 - 0.25, 1 + .25)
            hsv_img[:, :, 2] = hsv_img[:, :, 2] * self.rand_range(
                rng, 1 - .15, 1 + .15)
            hsv_img[:, :, 1] = np.clip(hsv_img[:, :, 1], 0, 255)
            hsv_img[:, :, 2] = np.clip(hsv_img[:, :, 2], 0, 255)
            img = cv2.cvtColor(hsv_img.astype(np.uint8), cv2.COLOR_HSV2BGR)

        if rng.rand() > 0.8:  # motion blur, ~20% of the time
            r_angle = int(rng.rand() * 360)
            r_len = int(rng.rand() * 15) + 1
            img = self.linear_motion_blur(img, r_angle, r_len)

        if rng.rand() > 0.8:  # Gaussian blur, ~20% of the time
            if rng.rand() > 0.2:
                img = cv2.GaussianBlur(img, (3, 3), rng.rand())
            else:
                img = cv2.GaussianBlur(img, (5, 5), rng.rand())

        return np.clip(img, 0, 255).astype(np.uint8)

    def get_normal(self, cld):
        """Estimate per-point normals with Open3D (50-NN search).

        NOTE(review): uses the legacy pre-0.8 Open3D API
        (o3d.Vector3dVector / o3d.geometry.estimate_normals); newer
        releases moved these to o3d.utility / PointCloud methods. Kept
        as-is to match the installed Open3D version.
        """
        cloud = o3d.geometry.PointCloud()
        cld = cld.astype(np.float32)
        cloud.points = o3d.Vector3dVector(cld)
        o3d.geometry.estimate_normals(
            cloud, search_param=o3d.geometry.KDTreeSearchParamKNN(50))
        normal = np.asarray(cloud.normals)

        return normal

    def add_real_back(self, rgb, labels, dpt, dpt_msk):
        """Composite a random real background behind the rendered object.

        NOTE(review): self.cls_root is never set in this class (leftover
        from LM_Dataset) — confirm before use.  The imshow call is a debug
        visualization.
        """
        real_item = self.real_gen()
        with Image.open(
                os.path.join(self.cls_root,
                             "depth/{}.png".format(real_item))) as di:
            real_dpt = np.array(di)
        with Image.open(
                os.path.join(self.cls_root,
                             "mask/{}.png".format(real_item))) as li:
            bk_label = np.array(li)
        bk_label = (bk_label <= 0).astype(rgb.dtype)
        if len(bk_label.shape) < 3:
            bk_label_3c = np.repeat(bk_label[:, :, None], 3, 2)
        else:
            bk_label_3c = bk_label
            bk_label = bk_label[:, :, 0]
        with Image.open(
                os.path.join(self.cls_root,
                             "rgb/{}.png".format(real_item))) as ri:
            back = np.array(ri)[:, :, :3] * bk_label_3c
            back = back[:, :, ::-1].copy()
        dpt_back = real_dpt.astype(np.float32) * bk_label.astype(np.float32)

        msk_back = (labels <= 0).astype(rgb.dtype)
        msk_back = np.repeat(msk_back[:, :, None], 3, 2)
        imshow("msk_back", msk_back)
        rgb = rgb * (msk_back == 0).astype(rgb.dtype) + back * msk_back

        dpt = dpt * (dpt_msk > 0).astype(dpt.dtype) + \
              dpt_back * (dpt_msk <= 0).astype(dpt.dtype)
        return rgb, dpt

    def get_item(self, item_name):
        """Load one sample and convert it to network-ready tensors.

        Args:
            item_name: one data-list line:
                "<scene_folder> <rgb> <depth> <visible_mask>".

        Returns:
            (rgb, cld, cld_rgb_nrm, choose, cls_ids, labels[, cam_scale])
            tensors, or None when the scene has too few valid depth points.
        """
        words = item_name.split()
        folderName = words[0]
        rgbName = words[1]
        sceneId = rgbName[:-4]  # strip the file extension
        depthName = words[2]
        segName = words[3]
        depthPath = os.path.join(folderName, "depth/{}".format(depthName))
        rgbPath = os.path.join(folderName, "rgb/{}".format(rgbName))
        segPath = os.path.join(folderName, "mask_visib/{}".format(segName))

        with Image.open(depthPath) as di:
            dpt = np.array(di)
        with Image.open(segPath) as li:
            labels = np.array(li)  # visible-object mask
            labels = (labels > 0).astype("uint8")
        with Image.open(rgbPath) as ri:
            rgb = np.array(ri)[:, :, :3]

        # meta: the target object's pose and bbox in this scene.
        meta = self.get_meta_data(folderName, sceneId, self.cls_id)

        R = np.resize(np.array(meta['cam_R_m2c']), (3, 3))
        T = np.array(meta['cam_t_m2c']) / 1000.0  # mm -> m
        RT = np.concatenate((R, T[:, None]), axis=1)
        rnd_typ = 'real'
        camParameter = self.get_cam_parameter(folderName, sceneId)
        K = np.resize(np.array(camParameter['cam_K']), (3, 3))
        # BOP depth is stored in 0.1 mm units; divide by 10000 for meters.
        cam_scale = 10000.0

        rgb = rgb[:, :, ::-1].copy()  # swap R and B channels
        msk_dp = dpt > 1e-6
        if len(labels.shape) > 2:
            labels = labels[:, :, 0]  # collapse to a single channel
        rgb_labels = labels.copy()

        rgb = np.transpose(rgb, (2, 0, 1))  # hwc2chw
        # choose: flat indices of depth pixels with valid (non-zero) depth.
        cld, choose = self.bs_utils.dpt_2_cld(dpt, cam_scale, K)

        labels = labels.flatten()[choose]  # labels : mask
        rgb_lst = []
        for ic in range(rgb.shape[0]):
            rgb_lst.append(rgb[ic].flatten()[choose].astype(np.float32))
        rgb_pt = np.transpose(np.array(rgb_lst), (1, 0)).copy()

        choose = np.array([choose])
        choose_2 = np.array([i for i in range(len(choose[0, :]))])

        # Too few valid points to form a sample.  (Typo fixed: "faw".)
        if len(choose_2) < 400:
            print("too few points :{}".format(depthPath))
            return None

        if len(choose_2) > self.n_sample_points:
            # Random down-sampling via a shuffled 0/1 mask.
            c_mask = np.zeros(len(choose_2), dtype=int)
            c_mask[:self.n_sample_points] = 1
            np.random.shuffle(c_mask)
            choose_2 = choose_2[c_mask.nonzero()]
        else:
            # Pad by wrapping so every sample has exactly n_sample_points.
            choose_2 = np.pad(choose_2,
                              (0, self.n_sample_points - len(choose_2)),
                              'wrap')

        cld_rgb = np.concatenate((cld, rgb_pt), axis=1)
        cld_rgb = cld_rgb[choose_2, :]
        cld = cld[choose_2, :]  # down-sampled cloud

        normal = self.get_normal(cld)[:, :3]
        normal[np.isnan(normal)] = 0.0

        cld_rgb_nrm = np.concatenate((cld_rgb, normal), axis=1)
        choose = choose[:, choose_2]  # original-image indices after sampling
        labels = labels[choose_2].astype(np.int32)

        # Only one foreground class here; keypoint/center targets are
        # currently unused placeholders.
        RTs = np.zeros((2, 3, 4))
        kp3ds = np.zeros((2, 5, 3))
        ctr3ds = np.zeros((2, 3))
        cls_ids = np.zeros((2, 1))
        kp_targ_ofst = np.zeros((10000, 5, 3))
        ctr_targ_ofst = np.zeros((10000, 3))
        for i, cls_id in enumerate([1]):
            RTs[i] = RT
            r = RT[:, :3]
            t = RT[:, 3]
            cls_ids[i, :] = np.array([1])

        # NOTE(review): DEBUG is expected to be a module-level flag.
        if DEBUG:
            return torch.from_numpy(rgb.astype(np.float32)), \
                   torch.from_numpy(cld.astype(np.float32)), \
                   torch.from_numpy(cld_rgb_nrm.astype(np.float32)), \
                   torch.LongTensor(choose.astype(np.int32)), \
                   torch.LongTensor(cls_ids.astype(np.int32)), \
                   torch.LongTensor(labels.astype(np.int32)), \
                   torch.from_numpy(np.array(cam_scale).astype(np.float32))

        # choose: indices into the original depth image after sampling.
        return torch.from_numpy(rgb.astype(np.float32)), \
               torch.from_numpy(cld.astype(np.float32)), \
               torch.from_numpy(cld_rgb_nrm.astype(np.float32)), \
               torch.LongTensor(choose.astype(np.int32)), \
               torch.LongTensor(cls_ids.astype(np.int32)), \
               torch.LongTensor(labels.astype(np.int32))

    def __len__(self):
        """Number of samples in the data list."""
        return len(self.dataList)

    def __getitem__(self, idx):
        """Return the training sample at *idx* (None if it was skipped)."""
        item_name = self.dataList[idx]
        data = self.get_item(item_name)
        return data
コード例 #6
0
ファイル: linemod_dataset.py プロジェクト: Guangyun-Xu/PVN3D
class LM_Dataset():
    def __init__(self, dataset_name, cls_type="duck"):
        """Build the LineMOD dataset index for one object class.

        Args:
            dataset_name: 'train' selects real + rendered + fused samples
                with noise augmentation; any other value selects the test
                split (preprocessed pickle if available, else test.txt).
            cls_type: LineMOD object name, e.g. "duck".
        """
        self.config = Config(dataset_name='linemod', cls_type=cls_type)
        self.bs_utils = Basic_Utils()
        self.dataset_name = dataset_name
        # Per-pixel coordinate maps for a 640x480 sensor
        # (xmap[r, c] = r, ymap[r, c] = c).
        self.xmap = np.array([[j for i in range(640)] for j in range(480)])
        self.ymap = np.array([[i for i in range(640)] for j in range(480)])

        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        # ImageNet normalization constants (fixed: the std's last entry was
        # mistyped as 0.224; the standard value is 0.225).
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        self.obj_dict = self.config.lm_obj_dict

        self.cls_type = cls_type
        self.cls_id = self.obj_dict[cls_type]
        print("cls_id in lm_dataset.py", self.cls_id)
        self.root = os.path.join(self.config.lm_root, 'Linemod_preprocessed')
        self.cls_root = os.path.join(self.root, "data/%02d/" % self.cls_id)
        self.rng = np.random
        # Use a context manager so the handle is closed, and safe_load:
        # yaml.load without an explicit Loader is unsafe and is rejected
        # by PyYAML >= 6.
        with open(os.path.join(self.cls_root, 'gt.yml'), "r") as meta_file:
            self.meta_lst = yaml.safe_load(meta_file)
        if dataset_name == 'train':
            self.add_noise = True
            real_img_pth = os.path.join(self.cls_root, "train.txt")
            self.real_lst = self.bs_utils.read_lines(real_img_pth)

            rnd_img_pth = os.path.join(
                self.root, "renders/{}/file_list.txt".format(cls_type))
            self.rnd_lst = self.bs_utils.read_lines(rnd_img_pth)

            fuse_img_pth = os.path.join(
                self.root, "fuse/{}/file_list.txt".format(cls_type))
            try:
                self.fuse_lst = self.bs_utils.read_lines(fuse_img_pth)
            except Exception:  # no fuse dataset: fall back to renders
                self.fuse_lst = self.rnd_lst
            self.all_lst = self.real_lst + self.rnd_lst + self.fuse_lst
        else:
            self.add_noise = False
            self.pp_data = None
            if os.path.exists(self.config.preprocessed_testset_pth
                              ) and self.config.use_preprocess:
                print('Loading valtestset.')
                with open(self.config.preprocessed_testset_pth, 'rb') as f:
                    self.pp_data = pkl.load(f)
                self.all_lst = [i for i in range(len(self.pp_data))]
                print('Finish loading valtestset.')
            else:
                tst_img_pth = os.path.join(self.cls_root, "test.txt")
                self.tst_lst = self.bs_utils.read_lines(tst_img_pth)
                self.all_lst = self.tst_lst
        print("{}_dataset_size: ".format(dataset_name), len(self.all_lst))

    def real_syn_gen(self, real_ratio=1.0):
        if self.rng.rand() < real_ratio:  # self.rng = np.random
            n_imgs = len(self.real_lst)  # 真实数据的数量
            idx = self.rng.randint(0, n_imgs)  # 将idx设置为0-n_imgs之间的整数
            pth = self.real_lst[idx]
            return pth
        else:
            fuse_ratio = 0.4
            if self.rng.rand() < fuse_ratio:
                idx = self.rng.randint(0, len(self.fuse_lst))
                pth = self.fuse_lst[idx]
            else:
                idx = self.rng.randint(0, len(self.rnd_lst))
                pth = self.rnd_lst[idx]
            return pth

    def real_gen(self):
        idx = self.rng.randint(0, len(self.real_lst))
        item = self.real_lst[idx]
        return item

    def rand_range(self, rng, lo, hi):
        return rng.rand() * (hi - lo) + lo

    def gaussian_noise(self, rng, img, sigma):
        """add gaussian noise of given sigma to image"""
        img = img + rng.randn(*img.shape) * sigma
        img = np.clip(img, 0, 255).astype('uint8')
        return img

    def linear_motion_blur(self, img, angle, length):
        """Convolve *img* with a linear motion-blur kernel.

        :param angle: blur direction in degrees
        :param length: blur length in pixels
        """
        theta = np.deg2rad(angle)
        dx, dy = np.cos(theta), np.sin(theta)
        size = int(max(abs(dx), abs(dy)) * length * 2)
        if size <= 0:
            return img
        kern = np.zeros((size, size))
        cx = cy = size // 2
        # End point of the blur line, anchored at the kernel center.
        ex, ey = int(dx * length + cx), int(dy * length + cy)
        cv2.line(kern, (cx, cy), (ex, ey), 1.0)
        total = kern.sum()
        if total == 0:
            kern[cx, cy] = 1.0  # degenerate line: identity kernel
        else:
            kern /= total
        return cv2.filter2D(img, -1, kern)

    def rgb_add_noise(self, img):
        """Apply random HSV jitter, motion blur and Gaussian blur to *img*."""
        rng = self.rng
        # HSV augmentation: scale saturation and value channels.
        if rng.rand() > 0:
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.uint16)
            hsv[:, :, 1] = hsv[:, :, 1] * self.rand_range(rng, 0.75, 1.25)
            hsv[:, :, 2] = hsv[:, :, 2] * self.rand_range(rng, 0.85, 1.15)
            hsv[:, :, 1] = np.clip(hsv[:, :, 1], 0, 255)
            hsv[:, :, 2] = np.clip(hsv[:, :, 2], 0, 255)
            img = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

        # Linear motion blur, ~20% of the time.
        if rng.rand() > 0.8:
            blur_angle = int(rng.rand() * 360)
            blur_len = int(rng.rand() * 15) + 1
            img = self.linear_motion_blur(img, blur_angle, blur_len)

        # Gaussian blur, ~20% of the time (usually 3x3, sometimes 5x5).
        if rng.rand() > 0.8:
            ksize = (3, 3) if rng.rand() > 0.2 else (5, 5)
            img = cv2.GaussianBlur(img, ksize, rng.rand())

        return np.clip(img, 0, 255).astype(np.uint8)

    # def get_normal(self, cld):  # 改成Open3D的
    #     cloud = pcl.
    #     cld = cld.astype(np.float32)
    #     cloud.from_array(cld)
    #     ne = cloud.make_NormalEstimation()
    #     kdtree = cloud.make_kdtree()
    #     ne.set_SearchMethod(kdtree)
    #     ne.set_KSearch(50)
    #     n = ne.compute()
    #     n = n.to_array()
    #     return n

    def get_normal(self, cld):
        cldShape = cld.shape
        normal = np.random.random(cldShape)
        return normal
        # cloud = o3d.geometry.PointCloud()
        # cld = cld.astype(np.float32)
        # cloud.points = o3d.utility.Vector3dVector(cld)
        # o3d.geometry.estimate_normals(cloud,
        #                              search_param=o3d.geometry.KDTreeSearchParamKNN(50))
        # print(cloud.normals)

    def add_real_back(self, rgb, labels, dpt, dpt_msk):
        """Composite a random real background behind a rendered sample.

        Args:
            rgb: BGR image of the rendered object (HWC).
            labels: object mask; pixels <= 0 are background.
            dpt: depth map of the rendered sample.
            dpt_msk: valid-depth mask for `dpt`.

        Returns:
            (rgb, dpt) with background pixels replaced by a randomly
            chosen real image/depth pair from the real-image list.
        """
        real_item = self.real_gen()
        with Image.open(
                os.path.join(self.cls_root,
                             "depth/{}.png".format(real_item))) as di:
            real_dpt = np.array(di)
        with Image.open(
                os.path.join(self.cls_root,
                             "mask/{}.png".format(real_item))) as li:
            bk_label = np.array(li)
        # bk_label: 1 where the real image is background (mask <= 0).
        bk_label = (bk_label <= 0).astype(rgb.dtype)
        if len(bk_label.shape) < 3:
            bk_label_3c = np.repeat(bk_label[:, :, None], 3, 2)
        else:
            bk_label_3c = bk_label
            bk_label = bk_label[:, :, 0]
        with Image.open(
                os.path.join(self.cls_root,
                             "rgb/{}.png".format(real_item))) as ri:
            # Keep only the background region and convert RGB -> BGR.
            back = np.array(ri)[:, :, :3] * bk_label_3c
            back = back[:, :, ::-1].copy()
        dpt_back = real_dpt.astype(np.float32) * bk_label.astype(np.float32)

        # Paste the real background wherever the rendered labels are empty.
        msk_back = (labels <= 0).astype(rgb.dtype)
        msk_back = np.repeat(msk_back[:, :, None], 3, 2)
        # NOTE(review): debug visualization; blocks or no-ops depending on
        # how `imshow` is defined at module level — confirm for training.
        imshow("msk_back", msk_back)
        rgb = rgb * (msk_back == 0).astype(rgb.dtype) + back * msk_back

        # Same composition for depth, gated by the valid-depth mask.
        dpt = dpt * (dpt_msk > 0).astype(dpt.dtype) + \
              dpt_back * (dpt_msk <= 0).astype(dpt.dtype)
        return rgb, dpt

    def get_item(self, item_name):
        """Load and preprocess a single LineMOD training/testing sample.

        ``item_name`` is either a path to a preprocessed ``.pkl`` file
        (rendered / fused data) or a frame id used to locate the
        depth/mask/rgb PNGs under ``self.cls_root``.

        Returns a tuple of tensors (rgb, cloud, cloud+rgb+normal, choose,
        keypoint offsets, center offsets, class ids, poses, labels,
        keypoint 3D, center 3D) — plus K and cam_scale when DEBUG is on —
        or ``None`` when the sample is unusable (too few foreground points
        or a loading error), so the caller can resample.
        """
        try:
            if "pkl" in item_name:
                # BUG FIX: the pickle file was opened without ever being
                # closed; use a context manager instead.
                with open(item_name, "rb") as pkl_f:
                    data = pkl.load(pkl_f)
                dpt = data['depth']
                rgb = data['rgb']
                labels = data['mask']
                K = data['K']
                RT = data['RT']
                rnd_typ = data['rnd_typ']
                if rnd_typ == "fuse":
                    # Fused scenes contain several objects; keep only ours.
                    labels = (labels == self.cls_id).astype("uint8")
                else:
                    labels = (labels > 0).astype("uint8")
                cam_scale = 1.0
            else:
                with Image.open(
                        os.path.join(self.cls_root,
                                     "depth/{}.png".format(item_name))) as di:
                    dpt = np.array(di)
                with Image.open(
                        os.path.join(self.cls_root,
                                     "mask/{}.png".format(item_name))) as li:
                    labels = np.array(li)
                    labels = (labels > 0).astype("uint8")
                with Image.open(
                        os.path.join(self.cls_root,
                                     "rgb/{}.png".format(item_name))) as ri:
                    if self.add_noise:
                        ri = self.trancolor(ri)
                    rgb = np.array(ri)[:, :, :3]
                # meta holds the object pose and bbox for this frame.
                meta = self.meta_lst[int(item_name)]
                if self.cls_id == 2:
                    # Frames of object 2 may list several objects; pick ours.
                    for i in range(0, len(meta)):
                        if meta[i]['obj_id'] == 2:
                            meta = meta[i]
                            break
                else:
                    meta = meta[0]
                R = np.resize(np.array(meta['cam_R_m2c']), (3, 3))
                T = np.array(meta['cam_t_m2c']) / 1000.0
                RT = np.concatenate((R, T[:, None]), axis=1)
                rnd_typ = 'real'
                K = self.config.intrinsic_matrix["linemod"]
                cam_scale = 1000.0  # depth is stored in mm; convert to meters
            rgb = rgb[:, :, ::-1].copy()  # swap R and B channels (BGR<->RGB)
            msk_dp = dpt > 1e-6
            if len(labels.shape) > 2:
                labels = labels[:, :, 0]  # collapse mask to a single channel
            rgb_labels = labels.copy()

            if self.add_noise and rnd_typ == 'render':
                # Rendered frames get extra photometric noise and a real
                # background pasted behind the object.
                rgb = self.rgb_add_noise(rgb)
                rgb_labels = labels.copy()
                rgb, dpt = self.add_real_back(rgb, rgb_labels, dpt, msk_dp)
                if self.rng.rand() > 0.8:
                    rgb = self.rgb_add_noise(rgb)

            rgb = np.transpose(rgb, (2, 0, 1))  # hwc2chw
            # choose: indices of the non-zero pixels in the depth map.
            cld, choose = self.bs_utils.dpt_2_cld(dpt, cam_scale, K)

            labels = labels.flatten()[choose]  # per-point mask labels
            rgb_lst = []
            for ic in range(rgb.shape[0]):
                rgb_lst.append(rgb[ic].flatten()[choose].astype(np.float32))
            # Per-point RGB values, shape [npts, 3].
            rgb_pt = np.transpose(np.array(rgb_lst), (1, 0)).copy()

            choose = np.array([choose])
            choose_2 = np.arange(choose.shape[1])

            if len(choose_2) < 400:
                # Too few scene points — signal the caller to resample.
                return None
            if len(choose_2) > self.config.n_sample_points:
                # Randomly subsample down to n_sample_points via a shuffled
                # 0/1 mask.
                c_mask = np.zeros(len(choose_2), dtype=int)
                c_mask[:self.config.n_sample_points] = 1
                np.random.shuffle(c_mask)
                choose_2 = choose_2[c_mask.nonzero()]
            else:
                # Pad by wrapping so every sample has the same size.
                choose_2 = np.pad(
                    choose_2, (0, self.config.n_sample_points - len(choose_2)),
                    'wrap')

            cld_rgb = np.concatenate((cld, rgb_pt), axis=1)
            cld_rgb = cld_rgb[choose_2, :]
            cld = cld[choose_2, :]

            normal = self.get_normal(cld)[:, :3]
            normal[np.isnan(normal)] = 0.0

            cld_rgb_nrm = np.concatenate((cld_rgb, normal), axis=1)
            # Depth-image indices of the points kept after subsampling.
            choose = choose[:, choose_2]
            labels = labels[choose_2].astype(np.int32)

            RTs = np.zeros((self.config.n_objects, 3, 4))
            kp3ds = np.zeros(
                (self.config.n_objects, self.config.n_keypoints, 3))
            ctr3ds = np.zeros((self.config.n_objects, 3))
            cls_ids = np.zeros((self.config.n_objects, 1))
            kp_targ_ofst = np.zeros(
                (self.config.n_sample_points, self.config.n_keypoints, 3))
            ctr_targ_ofst = np.zeros((self.config.n_sample_points, 3))
            for i, cls_id in enumerate([1]):
                RTs[i] = RT
                r = RT[:, :3]
                t = RT[:, 3]

                # Object center transformed into camera coordinates.
                ctr = self.bs_utils.get_ctr(self.cls_type,
                                            ds_type="linemod")[:, None]
                ctr = np.dot(ctr.T, r.T) + t
                ctr3ds[i, :] = ctr[0]
                msk_idx = np.where(labels == cls_id)[0]

                # Per-point offsets towards the object center
                # (foreground points only).
                target_offset = np.array(np.add(cld, -1.0 * ctr3ds[i, :]))
                ctr_targ_ofst[msk_idx, :] = target_offset[msk_idx, :]
                cls_ids[i, :] = np.array([1])

                if self.config.n_keypoints == 8:
                    kp_type = 'farthest'
                else:
                    kp_type = 'farthest{}'.format(self.config.n_keypoints)
                kps = self.bs_utils.get_kps(self.cls_type,
                                            kp_type=kp_type,
                                            ds_type='linemod')
                kps = np.dot(kps, r.T) + t
                kp3ds[i] = kps

                # Per-point offsets towards each keypoint
                # (foreground points only).
                target = []
                for kp in kps:
                    target.append(np.add(cld, -1.0 * kp))
                target_offset = np.array(target).transpose(
                    1, 0, 2)  # [npts, nkps, c]
                kp_targ_ofst[msk_idx, :, :] = target_offset[msk_idx, :, :]

            # rgb, pcld, cld_rgb_nrm, choose, kp_targ_ofst, ctr_targ_ofst, cls_ids, RTs, labels, kp_3ds, ctr_3ds
            if DEBUG:
                return torch.from_numpy(rgb.astype(np.float32)), \
                       torch.from_numpy(cld.astype(np.float32)), \
                       torch.from_numpy(cld_rgb_nrm.astype(np.float32)), \
                       torch.LongTensor(choose.astype(np.int32)), \
                       torch.from_numpy(kp_targ_ofst.astype(np.float32)), \
                       torch.from_numpy(ctr_targ_ofst.astype(np.float32)), \
                       torch.LongTensor(cls_ids.astype(np.int32)), \
                       torch.from_numpy(RTs.astype(np.float32)), \
                       torch.LongTensor(labels.astype(np.int32)), \
                       torch.from_numpy(kp3ds.astype(np.float32)), \
                       torch.from_numpy(ctr3ds.astype(np.float32)), \
                       torch.from_numpy(K.astype(np.float32)), \
                       torch.from_numpy(np.array(cam_scale).astype(np.float32))

            # choose: indices into the original depth map of the points kept
            # after subsampling.
            return torch.from_numpy(rgb.astype(np.float32)), \
                   torch.from_numpy(cld.astype(np.float32)), \
                   torch.from_numpy(cld_rgb_nrm.astype(np.float32)), \
                   torch.LongTensor(choose.astype(np.int32)), \
                   torch.from_numpy(kp_targ_ofst.astype(np.float32)), \
                   torch.from_numpy(ctr_targ_ofst.astype(np.float32)), \
                   torch.LongTensor(cls_ids.astype(np.int32)), \
                   torch.from_numpy(RTs.astype(np.float32)), \
                   torch.LongTensor(labels.astype(np.int32)), \
                   torch.from_numpy(kp3ds.astype(np.float32)), \
                   torch.from_numpy(ctr3ds.astype(np.float32)),
        except Exception:
            # BUG FIX: was a bare ``except:`` that also swallowed
            # SystemExit/KeyboardInterrupt. Any loading error still yields
            # None so the caller can resample.
            return None

    def __len__(self):
        """Total number of samples available to this dataset split."""
        sample_count = len(self.all_lst)
        return sample_count

    # 接收一个索引,然后返回用于训练的数据和标签
    def __getitem__(self, idx):  # 调用函数实例时传入
        if self.dataset_name == 'train':
            item_name = self.real_syn_gen()  # 物品的名称
            data = self.get_item(item_name)  # 获得物品的数据
            while data is None:  # 如果没有成功获得物品的数据,循环执行上述两步,直到获得数据
                item_name = self.real_syn_gen()
                data = self.get_item(item_name)
            return data
        else:
            if self.pp_data is None or not self.config.use_preprocess:
                item_name = self.all_lst[idx]
                return self.get_item(item_name)
            else:
                data = self.pp_data[idx]
                return data
コード例 #7
0
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import concurrent.futures
import numpy as np
import pickle as pkl
from pvn3d.common import Config
from pvn3d.lib.utils.basic_utils import Basic_Utils
from pvn3d.lib.utils.meanshift_pytorch import MeanShiftTorch


# Module-level configuration and helpers for the YCB and LineMOD variants.
config = Config(dataset_name='ycb')
# NOTE(review): both Basic_Utils instances are constructed identically with
# no arguments (elsewhere in this file Basic_Utils is built with a config
# argument) — confirm the intended constructor signature.
bs_utils = Basic_Utils()
config_lm = Config(dataset_name="linemod")
bs_utils_lm = Basic_Utils()
cls_lst = config.ycb_cls_lst  # ordered list of YCB object class names


class VotingType:
    """Integer constants naming the supported keypoint/voting schemes.

    BB8* variants vote for bounding-box corners; Farthest* variants vote
    for N farthest-point-sampled keypoints. (Note: value 4 is unused.)
    """

    BB8 = 0
    BB8C = 1
    BB8S = 2
    VanPts = 3
    Farthest = 5
    Farthest4 = 6
    Farthest12 = 7
    Farthest16 = 8
    Farthest20 = 9
コード例 #8
0
ファイル: BOP_dataset.py プロジェクト: Guangyun-Xu/PVN3D
class BOPDataset():
    """Dataset for BOP-format scenes.

    Loads RGB-D frames listed in a data-list file, builds a
    voxel-downsampled point cloud with per-point colors and normals, and
    returns training tensors for a single object class.
    """

    def __init__(self, data_list_path, cls_id):
        # Helper objects.
        # NOTE(review): Config receives data_list_path for both positional
        # arguments — confirm this matches Config's
        # (dataset_name, cls_type) signature.
        self.config = Config(data_list_path, data_list_path)
        self.bs_utils = Basic_Utils()
        self.cls_id = cls_id
        # Sampling parameters.
        self.voxel_size = self.config.voxel_size
        self.n_sample_points = self.config.n_sample_points
        self.add_noise = True

        # Data: one "<folder> <rgb> <depth> <mask>" record per line.
        self.data_list = self.bs_utils.read_lines(data_list_path)

        # Photometric augmentation applied to RGB frames.
        self.trans_color = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)

    def get_cam_parameter(self, folderPath, sceneId):
        """Return the camera record for *sceneId* from scene_camera.json.

        Leading zeros are stripped from *sceneId* because the JSON keys
        are plain integers serialized as strings ("3", not "0003"); an
        all-zero id maps to key "0".
        """
        sceneInfoPath = os.path.join(folderPath, "scene_camera.json")
        with open(sceneInfoPath, 'r') as f2:
            sceneInfo = json.load(f2)
            sceneId = sceneId.lstrip('0')
            if sceneId == '':
                sceneId = '0'
            sceneDate = sceneInfo[sceneId]
            return sceneDate

    def get_normal(self, cld):
        """Estimate per-point normals for an Nx3 array; returns an Nx3 array.

        Uses the legacy (pre-0.10) Open3D module-level API.
        """
        cloud = o3d.geometry.PointCloud()
        cld = cld.astype(np.float32)
        cloud.points = o3d.Vector3dVector(cld)
        o3d.geometry.estimate_normals(
            cloud,
            search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01,
                                                              max_nn=100))
        normal = np.asarray(cloud.normals)
        return normal

    def get_point_normal(self, cld):
        """Like get_normal, but returns the Open3D cloud with normals set."""
        cloud = o3d.geometry.PointCloud()
        cld = cld.astype(np.float32)
        cloud.points = o3d.Vector3dVector(cld)
        o3d.geometry.estimate_normals(
            cloud, search_param=o3d.geometry.KDTreeSearchParamKNN(50))
        return cloud

    def pcd_down_sample(self, src, voxel_size, n_sample_points, original_idx):
        """Voxel-downsample *src* to exactly n_sample_points points.

        *original_idx* (one int per point) is stashed in the cloud's color
        channels so the original pixel indices survive Open3D's filtering.
        Returns the downsampled cloud, or None when not enough points
        remain to reach n_sample_points.
        """
        original_idx = np.array([original_idx, original_idx,
                                 original_idx]).transpose()
        color_idx = original_idx.tolist()
        src.colors = o3d.Vector3dVector(color_idx)

        max_bound = src.get_max_bound()
        min_bound = src.get_min_bound()
        # Keep only points within a plausible working range of the camera.
        max_distance = 1.5  # m
        min_distance = 0.3  # m
        src_tree = o3d.KDTreeFlann(src)
        _, max_idx, _ = src_tree.search_radius_vector_3d([0, 0, 0],
                                                         max_distance)
        _, min_idx, _ = src_tree.search_radius_vector_3d([0, 0, 0],
                                                         min_distance)
        src = o3d.select_down_sample(src, max_idx)
        if len(min_idx) > 0:
            # Drop points closer than min_distance (invert=True selection).
            src_tree = o3d.KDTreeFlann(src)
            _, min_idx, _ = src_tree.search_radius_vector_3d([0, 0, 0],
                                                             min_distance)
            src = o3d.select_down_sample(src, min_idx, True)

        _, point_idx = o3d.voxel_down_sample_and_trace(src,
                                                       voxel_size=voxel_size,
                                                       min_bound=min_bound,
                                                       max_bound=max_bound)

        max_idx_in_row = np.max(point_idx, axis=1)
        pcd_down = o3d.select_down_sample(src, max_idx_in_row)

        # If too many points survive, randomly drop the excess; if too few,
        # top up from the points the voxel filter discarded.
        n_points = len(max_idx_in_row)

        # NOTE(review): this reseeds the *global* NumPy RNG on every call,
        # making all downstream randomness deterministic — confirm intended.
        np.random.seed(666)
        if n_points > n_sample_points:
            n_minus = n_points - n_sample_points
            n_pcd_dow = n_points
            minus_idx = np.random.choice(n_pcd_dow, n_minus, replace=False)
            pcd_down = o3d.select_down_sample(pcd_down, minus_idx, True)
            return pcd_down
        elif n_points < n_sample_points:
            n_add = n_sample_points - n_points
            n_cuted_points = len(src.points)
            n_unsample_points = n_cuted_points - max_idx_in_row.shape[0]
            if n_add < n_unsample_points:
                # Select among the points the voxel filter left out.
                unsample_points = o3d.select_down_sample(
                    src, max_idx_in_row, True)

                add_idx = np.random.choice(n_unsample_points,
                                           n_add,
                                           replace=False)
                add_points = o3d.select_down_sample(unsample_points, add_idx)
                pcd_down += add_points
                return pcd_down
            else:
                return None

        else:
            return pcd_down

    def get_item(self, item_name):
        """Load and preprocess one BOP sample described by *item_name*.

        *item_name* is a whitespace-separated record:
        "<folder> <rgb file> <depth file> <mask file>".
        Returns (rgb, cloud, cloud+rgb+normal, choose, cls_ids, labels)
        tensors, or None when too few foreground points remain.
        """
        root_path = os.path.dirname(__file__)
        words = item_name.split()
        folderName = words[0].lstrip("./")
        folderName = os.path.join(root_path, folderName)
        rgbName = words[1]
        sceneId = rgbName[:-4]
        depthName = words[2]
        segName = words[3]
        depthPath = os.path.join(folderName, "depth/{}".format(depthName))
        rgbPath = os.path.join(folderName, "rgb/{}".format(rgbName))
        segPath = os.path.join(folderName, "mask_visib/{}".format(segName))

        # Load depth, mask and RGB images.
        with Image.open(depthPath) as di:
            dpt = np.array(di)
        with Image.open(segPath) as li:
            labels = np.array(li)  # labels: visibility mask
            labels = (labels > 0).astype("uint8")
        with Image.open(rgbPath) as ri:
            if self.add_noise:
                ri = self.trans_color(ri)
            rgb = np.array(ri)[:, :, :3]

        # RGB preprocessing.
        rgb = rgb[:, :, ::-1].copy()  # swap R and B channels (BGR<->RGB)
        rgb = np.transpose(rgb, (2, 0, 1))  # hwc2chw

        # Point-cloud preprocessing.
        cam_parameter = self.get_cam_parameter(folderName, sceneId)
        K = np.resize(np.array(cam_parameter['cam_K']), (3, 3))
        depth_scale = cam_parameter['depth_scale']
        # BOP depth is stored in units of depth_scale mm (e.g. 0.1 mm);
        # cam_scale converts the raw values to meters.
        cam_scale = 1000 / depth_scale

        # choose: indices of the non-zero pixels in the depth map.
        cld, choose = self.bs_utils.dpt_2_cld(dpt, cam_scale, K)
        # Re-ranked indices 0..N-1 to be carried through downsampling.
        choose_rerank = np.array([i for i in range(choose.shape[0])])
        cld_normal_o3d = self.get_point_normal(cld)
        cld_normal_down = self.pcd_down_sample(cld_normal_o3d, self.voxel_size,
                                               self.n_sample_points,
                                               choose_rerank)
        if cld_normal_down:
            # BUG FIX: ``np.int`` was removed in NumPy 1.24; use the
            # builtin ``int`` (identical behavior on older NumPy).
            original_idx_down = np.asarray(cld_normal_down.colors).astype(
                int)
            original_idx_down = original_idx_down[:, 0].tolist()
            cld_down = np.asarray(cld_normal_down.points)

            # Build cloud + RGB (+ normal) features.
            rgb_lst = []
            for ic in range(rgb.shape[0]):
                # Extract the pixel colors corresponding to cloud points.
                rgb_lst.append(rgb[ic].flatten()[choose].astype(
                    np.float32))
            rgb_pt = np.transpose(np.array(rgb_lst), (1, 0)).copy()

            cld_rgb = np.concatenate((cld, rgb_pt), axis=1)
            cld_rgb = cld_rgb[original_idx_down, :]

            normal = np.asarray(cld_normal_down.normals)  # estimated normals
            normal[np.isnan(normal)] = 0.0
            cld_rgb_normal = np.concatenate((cld_rgb, normal), axis=1)

            # choose preprocessing.
            choose = np.array([choose])
            choose_dow = choose[:, original_idx_down]

            # cls_id placeholder (purpose unclear in the original code).
            cls_ids = np.zeros((2, 1))

            # Label preprocessing.
            if len(labels.shape) > 2:
                labels = labels[:, :, 0]  # collapse to a single channel
            labels = labels.flatten()[choose][0]  # per-point mask labels
            labels = labels[original_idx_down].astype(np.int32)
            n_target = labels.nonzero()
            n_target = len(n_target[0])
            if n_target < 5:
                # Too few foreground points — caller should resample.
                return None

            # choose: indices into the original depth map of the points
            # kept after downsampling.
            return torch.from_numpy(rgb.astype(np.float32)), \
                   torch.from_numpy(cld_down.astype(np.float32)), \
                   torch.from_numpy(cld_rgb_normal.astype(np.float32)), \
                   torch.LongTensor(choose_dow.astype(np.int32)), \
                   torch.LongTensor(cls_ids.astype(np.int32)), \
                   torch.LongTensor(labels.astype(np.int32))

        else:
            return None

    def __len__(self):
        """Number of records in the data list."""
        return len(self.data_list)

    # Given an index, return the data and labels used for training.
    def __getitem__(self, idx):
        """Fetch one sample, resampling a random index when loading fails."""
        item_name = self.data_list[idx]
        data = self.get_item(item_name)
        while data is None:
            print("to few points:{}".format(idx))
            idx = np.random.randint(0, len(self.data_list))
            item_name = self.data_list[idx]
            print("replaced by :{}".format(idx))
            data = self.get_item(item_name)
        return data
コード例 #9
0
    help="Target dataset, ycb or linemod. (linemod as default).")
parser.add_argument(
    "-cls",
    type=str,
    default="ape",
    help="Target object to eval in LineMOD dataset. (ape, benchvise, cam, can,"
    + "cat, driller, duck, eggbox, glue, holepuncher, iron, lamp, phone)")
args = parser.parse_args()

# BUG FIX: the original line here read ``args.dataset == "linemod"`` — a
# comparison whose result was discarded, i.e. a no-op. It has been removed;
# if the intent was to force the LineMOD dataset regardless of the CLI
# argument, it should have been the assignment ``args.dataset = "linemod"``.

# Build the dataset-specific configuration (cls_type only applies to LineMOD).
if args.dataset == "ycb":
    config = Config(dataset_name=args.dataset)
else:
    config = Config(dataset_name=args.dataset, cls_type=args.cls)
bs_utils = Basic_Utils(config)


def ensure_fd(fd):
    """Create directory *fd* (including parents) if it does not exist.

    BUG FIX: the original shelled out via ``os.system('mkdir -p ...')``,
    which is non-portable and vulnerable to shell injection through the
    path string; ``os.makedirs(..., exist_ok=True)`` is the safe
    equivalent.
    """
    os.makedirs(fd, exist_ok=True)


def checkpoint_state(model=None,
                     optimizer=None,
                     best_prec=None,
                     epoch=None,
                     it=None):
    optim_state = optimizer.state_dict() if optimizer is not None else None
    if model is not None:
        if isinstance(model, torch.nn.DataParallel):