Example #1
    def read_data(self, dataset_dict):
        """load image and annos random shift & scale bbox; crop, rescale."""
        cfg = self.cfg
        r_head_cfg = cfg.MODEL.CDPN.ROT_HEAD
        pnp_net_cfg = cfg.MODEL.CDPN.PNP_NET

        dataset_dict = copy.deepcopy(
            dataset_dict)  # it will be modified by code below

        dataset_name = dataset_dict["dataset_name"]

        image = read_image_cv2(dataset_dict["file_name"],
                               format=self.img_format)
        # should be consistent with the size in dataset_dict
        utils.check_image_size(dataset_dict, image)
        im_H_ori, im_W_ori = image.shape[:2]

        # currently only replace bg for train ###############################
        if self.split == "train":
            # some synthetic data already has a background; such images should be marked
            # with img_type "real" (or anything other than "syn") so bg is not replaced
            img_type = dataset_dict.get("img_type", "real")
            if img_type == "syn":
                log_first_n(logging.WARNING, "replace bg", n=10)
                assert "segmentation" in dataset_dict["inst_infos"]
                mask = cocosegm2mask(
                    dataset_dict["inst_infos"]["segmentation"], im_H_ori,
                    im_W_ori)
                image, mask_trunc = self.replace_bg(image.copy(),
                                                    mask,
                                                    return_mask=True)
            else:  # real image
                if np.random.rand() < cfg.INPUT.CHANGE_BG_PROB:
                    log_first_n(logging.WARNING, "replace bg for real", n=10)
                    assert "segmentation" in dataset_dict["inst_infos"]
                    mask = cocosegm2mask(
                        dataset_dict["inst_infos"]["segmentation"], im_H_ori,
                        im_W_ori)
                    image, mask_trunc = self.replace_bg(image.copy(),
                                                        mask,
                                                        return_mask=True)
                else:
                    mask_trunc = None

        # NOTE: maybe add or change color augment here ===================================
        if self.split == "train" and self.color_aug_prob > 0 and self.color_augmentor is not None:
            if np.random.rand() < self.color_aug_prob:
                # NOTE: applied to both real and synthetic images;
                # cfg.INPUT.COLOR_AUG_SYN_ONLY currently has no effect here
                image = self._color_aug(image, self.color_aug_type)

        # other transforms (mainly geometric ones);
        # for the 6d pose task, flipping is generally not allowed, except for some 2d keypoint methods
        image, transforms = T.apply_augmentations(self.augmentation, image)
        im_H, im_W = image_shape = image.shape[:2]  # h, w

        # NOTE: scale camera intrinsic if necessary ================================
        scale_x = im_W / im_W_ori
        scale_y = im_H / im_H_ori  # NOTE: generally scale_x should be equal to scale_y
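        # K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]: row 0 carries the x-direction
        # terms (fx, cx) and row 1 the y-direction terms (fy, cy), hence the
        # per-row scaling below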
        if "cam" in dataset_dict:
            if im_W != im_W_ori or im_H != im_H_ori:
                dataset_dict["cam"][0] *= scale_x
                dataset_dict["cam"][1] *= scale_y
            K = dataset_dict["cam"].astype("float32")
            dataset_dict["cam"] = torch.as_tensor(K)

        input_res = cfg.MODEL.CDPN.BACKBONE.INPUT_RES
        out_res = cfg.MODEL.CDPN.BACKBONE.OUTPUT_RES

        # CHW -> HWC
        coord_2d = get_2d_coord_np(im_W, im_H, low=0,
                                   high=1).transpose(1, 2, 0)
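        # coord_2d is a dense (im_H, im_W, 2) map of pixel coordinates normalized to
        # [0, 1]; it is cropped per RoI below as roi_coord_2d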

        #################################################################################
        if self.split != "train":
            # don't load annotations at test time
            test_bbox_type = cfg.TEST.TEST_BBOX_TYPE
            if test_bbox_type == "gt":
                bbox_key = "bbox"
            else:
                bbox_key = f"bbox_{test_bbox_type}"
            assert not self.flatten, "Do not use flattened dicts for test!"
            # here get batched rois
            roi_infos = {}
            # yapf: disable
            roi_keys = ["scene_im_id", "file_name", "cam", "im_H", "im_W",
                        "roi_img", "inst_id", "roi_coord_2d", "roi_cls", "score", "roi_extent",
                         bbox_key, "bbox_mode", "bbox_center", "roi_wh",
                         "scale", "resize_ratio", "model_info",
                        ]
            for _key in roi_keys:
                roi_infos[_key] = []
            # yapf: enable
            # TODO: how to handle image without detections
            #   filter those when load annotations or detections, implement a function for this
            # "annotations" means detections
            for inst_i, inst_infos in enumerate(dataset_dict["annotations"]):
                # inherent image-level infos
                roi_infos["scene_im_id"].append(dataset_dict["scene_im_id"])
                roi_infos["file_name"].append(dataset_dict["file_name"])
                roi_infos["im_H"].append(im_H)
                roi_infos["im_W"].append(im_W)
                roi_infos["cam"].append(dataset_dict["cam"].cpu().numpy())

                # roi-level infos
                roi_infos["inst_id"].append(inst_i)
                roi_infos["model_info"].append(inst_infos["model_info"])

                roi_cls = inst_infos["category_id"]
                roi_infos["roi_cls"].append(roi_cls)
                roi_infos["score"].append(inst_infos["score"])

                # extent
                roi_extent = self._get_extents(dataset_name)[roi_cls]
                roi_infos["roi_extent"].append(roi_extent)

                bbox = BoxMode.convert(inst_infos[bbox_key],
                                       inst_infos["bbox_mode"],
                                       BoxMode.XYXY_ABS)
                bbox = np.array(transforms.apply_box([bbox])[0])
                roi_infos[bbox_key].append(bbox)
                roi_infos["bbox_mode"].append(BoxMode.XYXY_ABS)
                x1, y1, x2, y2 = bbox
                bbox_center = np.array([0.5 * (x1 + x2), 0.5 * (y1 + y2)])
                bw = max(x2 - x1, 1)
                bh = max(y2 - y1, 1)
                scale = max(bh, bw) * cfg.INPUT.DZI_PAD_SCALE
                scale = min(scale, max(im_H, im_W)) * 1.0
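                # the RoI is a square crop of side `scale`: the longer bbox side padded
                # by DZI_PAD_SCALE and clamped to the image size; resize_ratio below
                # maps this crop to the network output resolution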

                roi_infos["bbox_center"].append(bbox_center.astype("float32"))
                roi_infos["scale"].append(scale)
                roi_infos["roi_wh"].append(np.array([bw, bh],
                                                    dtype=np.float32))
                roi_infos["resize_ratio"].append(out_res / scale)

                # CHW, float32 tensor
                # roi_image
                roi_img = crop_resize_by_warp_affine(
                    image,
                    bbox_center,
                    scale,
                    input_res,
                    interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)

                roi_img = self.normalize_image(cfg, roi_img)
                roi_infos["roi_img"].append(roi_img.astype("float32"))

                # roi_coord_2d
                roi_coord_2d = crop_resize_by_warp_affine(
                    coord_2d,
                    bbox_center,
                    scale,
                    out_res,
                    interpolation=cv2.INTER_LINEAR).transpose(2, 0,
                                                              1)  # HWC -> CHW
                roi_infos["roi_coord_2d"].append(
                    roi_coord_2d.astype("float32"))

            for _key in roi_keys:
                if _key in ["roi_img", "roi_coord_2d"]:
                    dataset_dict[_key] = torch.as_tensor(
                        roi_infos[_key]).contiguous()
                elif _key in ["model_info", "scene_im_id", "file_name"]:
                    # can not convert to tensor
                    dataset_dict[_key] = roi_infos[_key]
                else:
                    dataset_dict[_key] = torch.tensor(roi_infos[_key])

            return dataset_dict
        #######################################################################################
        # NOTE: currently assume flattened dicts for train
        assert self.flatten, "Only support flattened dicts for train now"
        inst_infos = dataset_dict.pop("inst_infos")
        dataset_dict["roi_cls"] = roi_cls = inst_infos["category_id"]

        # extent
        roi_extent = self._get_extents(dataset_name)[roi_cls]
        dataset_dict["roi_extent"] = torch.tensor(roi_extent,
                                                  dtype=torch.float32)

        # load xyz =======================================================
        xyz_info = mmcv.load(inst_infos["xyz_path"])
        x1, y1, x2, y2 = xyz_info["xyxy"]
        # float16 does not affect performance (classification/regression)
        xyz_crop = xyz_info["xyz_crop"]
        xyz = np.zeros((im_H, im_W, 3), dtype=np.float32)
        xyz[y1:y2 + 1, x1:x2 + 1, :] = xyz_crop
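        # xyz_crop holds the object-coordinate map only inside the tight bbox
        # [x1, x2] x [y1, y2]; it is pasted into a full-size canvas and pixels left
        # at zero are treated as background by the mask below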
        # NOTE: full mask
        mask_obj = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) |
                    (xyz[:, :, 2] != 0)).astype(np.float32)  # comparison already yields bool (np.bool is removed in recent NumPy)
        if cfg.INPUT.SMOOTH_XYZ:
            xyz = self.smooth_xyz(xyz)

        if cfg.TRAIN.VIS:
            xyz = self.smooth_xyz(xyz)

        # override bbox info using xyz_infos
        inst_infos["bbox"] = [x1, y1, x2, y2]
        inst_infos["bbox_mode"] = BoxMode.XYXY_ABS

        # USER: Implement additional transformations if you have other types of data
        # inst_infos.pop("segmentation")  # NOTE: use mask from xyz
        anno = transform_instance_annotations(inst_infos,
                                              transforms,
                                              image_shape,
                                              keypoint_hflip_indices=None)

        # augment bbox ===================================================
        bbox_xyxy = anno["bbox"]
        bbox_center, scale = self.aug_bbox(cfg, bbox_xyxy, im_H, im_W)
        bw = max(bbox_xyxy[2] - bbox_xyxy[0], 1)
        bh = max(bbox_xyxy[3] - bbox_xyxy[1], 1)

        # CHW, float32 tensor
        ## roi_image ------------------------------------
        roi_img = crop_resize_by_warp_affine(
            image,
            bbox_center,
            scale,
            input_res,
            interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)

        roi_img = self.normalize_image(cfg, roi_img)

        # roi_coord_2d ----------------------------------------------------
        roi_coord_2d = crop_resize_by_warp_affine(
            coord_2d,
            bbox_center,
            scale,
            out_res,
            interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1)

        ## roi_mask ---------------------------------------
        # (mask_trunc < mask_visib < mask_obj)
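        # mask_obj: full object mask derived from the xyz map;
        # mask_visib: visible part (transformed annotation mask intersected with mask_obj);
        # mask_trunc: mask_visib further intersected with the truncation mask from
        # background replacement, if one was produced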
        mask_visib = anno["segmentation"].astype("float32") * mask_obj
        if mask_trunc is None:
            mask_trunc = mask_visib
        else:
            mask_trunc = mask_visib * mask_trunc.astype("float32")

        if cfg.TRAIN.VIS:
            mask_xyz_interp = cv2.INTER_LINEAR
        else:
            mask_xyz_interp = cv2.INTER_NEAREST

        # maybe truncated mask (true mask for rgb)
        roi_mask_trunc = crop_resize_by_warp_affine(
            mask_trunc[:, :, None],
            bbox_center,
            scale,
            out_res,
            interpolation=mask_xyz_interp)

        # use original visible mask to calculate xyz loss (try full obj mask?)
        roi_mask_visib = crop_resize_by_warp_affine(
            mask_visib[:, :, None],
            bbox_center,
            scale,
            out_res,
            interpolation=mask_xyz_interp)

        roi_mask_obj = crop_resize_by_warp_affine(
            mask_obj[:, :, None],
            bbox_center,
            scale,
            out_res,
            interpolation=mask_xyz_interp)

        ## roi_xyz ----------------------------------------------------
        roi_xyz = crop_resize_by_warp_affine(xyz,
                                             bbox_center,
                                             scale,
                                             out_res,
                                             interpolation=mask_xyz_interp)

        # region label
        if r_head_cfg.NUM_REGIONS > 1:
            fps_points = self._get_fps_points(dataset_name)[roi_cls]
            roi_region = xyz_to_region(roi_xyz, fps_points)  # HW
            dataset_dict["roi_region"] = torch.as_tensor(
                roi_region.astype(np.int32)).contiguous()

        roi_xyz = roi_xyz.transpose(2, 0, 1)  # HWC-->CHW
        # normalize xyz to [0, 1] using extent
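        # object coordinates lie roughly in [-extent/2, +extent/2] along each axis
        # (assuming the model is centered at the origin), so dividing by the extent
        # and adding 0.5 maps them to approximately [0, 1]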
        roi_xyz[0] = roi_xyz[0] / roi_extent[0] + 0.5
        roi_xyz[1] = roi_xyz[1] / roi_extent[1] + 0.5
        roi_xyz[2] = roi_xyz[2] / roi_extent[2] + 0.5

        if ("CE" in r_head_cfg.XYZ_LOSS_TYPE) or (
                "cls" in cfg.MODEL.CDPN.NAME):  # convert target to int for cls
            # assume roi_xyz has been normalized in [0, 1]
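            # each normalized coordinate in [0, 1) is quantized to an integer bin in
            # [0, XYZ_BIN - 1]; background pixels later get the extra bin XYZ_BIN,
            # giving XYZ_BIN + 1 classes per channel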
            roi_xyz_bin = np.zeros_like(roi_xyz)
            roi_x_norm = roi_xyz[0]
            roi_x_norm[roi_x_norm < 0] = 0  # clip
            roi_x_norm[roi_x_norm > 0.999999] = 0.999999
            # [0, BIN-1]
            roi_xyz_bin[0] = np.asarray(roi_x_norm * r_head_cfg.XYZ_BIN,
                                        dtype=np.uint8)

            roi_y_norm = roi_xyz[1]
            roi_y_norm[roi_y_norm < 0] = 0
            roi_y_norm[roi_y_norm > 0.999999] = 0.999999
            roi_xyz_bin[1] = np.asarray(roi_y_norm * r_head_cfg.XYZ_BIN,
                                        dtype=np.uint8)

            roi_z_norm = roi_xyz[2]
            roi_z_norm[roi_z_norm < 0] = 0
            roi_z_norm[roi_z_norm > 0.999999] = 0.999999
            roi_xyz_bin[2] = np.asarray(roi_z_norm * r_head_cfg.XYZ_BIN,
                                        dtype=np.uint8)

            # the last bin is for bg
            roi_masks = {
                "trunc": roi_mask_trunc,
                "visib": roi_mask_visib,
                "obj": roi_mask_obj
            }
            roi_mask_xyz = roi_masks[r_head_cfg.XYZ_LOSS_MASK_GT]
            roi_xyz_bin[0][roi_mask_xyz == 0] = r_head_cfg.XYZ_BIN
            roi_xyz_bin[1][roi_mask_xyz == 0] = r_head_cfg.XYZ_BIN
            roi_xyz_bin[2][roi_mask_xyz == 0] = r_head_cfg.XYZ_BIN

            if "CE" in r_head_cfg.XYZ_LOSS_TYPE:
                dataset_dict["roi_xyz_bin"] = torch.as_tensor(
                    roi_xyz_bin.astype("uint8")).contiguous()
            if "/" in r_head_cfg.XYZ_LOSS_TYPE and len(
                    r_head_cfg.XYZ_LOSS_TYPE.split("/")[1]) > 0:
                dataset_dict["roi_xyz"] = torch.as_tensor(
                    roi_xyz.astype("float32")).contiguous()
        else:
            dataset_dict["roi_xyz"] = torch.as_tensor(
                roi_xyz.astype("float32")).contiguous()

        # pose targets ----------------------------------------------------------------------
        pose = inst_infos["pose"]
        allo_pose = egocentric_to_allocentric(pose)
        quat = inst_infos["quat"]
        allo_quat = mat2quat(allo_pose[:3, :3])

        # ====== actually not needed ==========
        if pnp_net_cfg.ROT_TYPE == "allo_quat":
            dataset_dict["allo_quat"] = torch.as_tensor(
                allo_quat.astype("float32"))
        elif pnp_net_cfg.ROT_TYPE == "ego_quat":
            dataset_dict["ego_quat"] = torch.as_tensor(quat.astype("float32"))
        # rot6d
        elif pnp_net_cfg.ROT_TYPE == "ego_rot6d":
            dataset_dict["ego_rot6d"] = torch.as_tensor(
                mat_to_ortho6d_np(pose[:3, :3].astype("float32")))
        elif pnp_net_cfg.ROT_TYPE == "allo_rot6d":
            dataset_dict["allo_rot6d"] = torch.as_tensor(
                mat_to_ortho6d_np(allo_pose[:3, :3].astype("float32")))
        # log quat
        elif pnp_net_cfg.ROT_TYPE == "ego_log_quat":
            dataset_dict["ego_log_quat"] = quaternion_lf.qlog(
                torch.as_tensor(quat.astype("float32"))[None])[0]
        elif pnp_net_cfg.ROT_TYPE == "allo_log_quat":
            dataset_dict["allo_log_quat"] = quaternion_lf.qlog(
                torch.as_tensor(allo_quat.astype("float32"))[None])[0]
        # lie vec
        elif pnp_net_cfg.ROT_TYPE == "ego_lie_vec":
            dataset_dict["ego_lie_vec"] = lie_algebra.rot_to_lie_vec(
                torch.as_tensor(pose[:3, :3].astype("float32")[None]))[0]
        elif pnp_net_cfg.ROT_TYPE == "allo_lie_vec":
            dataset_dict["allo_lie_vec"] = lie_algebra.rot_to_lie_vec(
                torch.as_tensor(allo_pose[:3, :3].astype("float32"))[None])[0]
        else:
            raise ValueError(f"Unknown rot type: {pnp_net_cfg.ROT_TYPE}")
        dataset_dict["ego_rot"] = torch.as_tensor(
            pose[:3, :3].astype("float32"))
        dataset_dict["trans"] = torch.as_tensor(
            inst_infos["trans"].astype("float32"))

        dataset_dict["roi_points"] = torch.as_tensor(
            self._get_model_points(dataset_name)[roi_cls].astype("float32"))
        dataset_dict["sym_info"] = self._get_sym_infos(dataset_name)[roi_cls]

        dataset_dict["roi_img"] = torch.as_tensor(
            roi_img.astype("float32")).contiguous()
        dataset_dict["roi_coord_2d"] = torch.as_tensor(
            roi_coord_2d.astype("float32")).contiguous()

        dataset_dict["roi_mask_trunc"] = torch.as_tensor(
            roi_mask_trunc.astype("float32")).contiguous()
        dataset_dict["roi_mask_visib"] = torch.as_tensor(
            roi_mask_visib.astype("float32")).contiguous()
        dataset_dict["roi_mask_obj"] = torch.as_tensor(
            roi_mask_obj.astype("float32")).contiguous()

        dataset_dict["bbox_center"] = torch.as_tensor(bbox_center,
                                                      dtype=torch.float32)
        dataset_dict["scale"] = scale
        dataset_dict["bbox"] = anno["bbox"]  # NOTE: original bbox
        dataset_dict["roi_wh"] = torch.as_tensor(
            np.array([bw, bh], dtype=np.float32))
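        # translation target expressed relative to the RoI: offset of the projected
        # object center from the bbox center, normalized by the bbox width/height,
        # plus the depth divided by resize_ratio (= out_res / scale)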
        dataset_dict["resize_ratio"] = resize_ratio = out_res / scale
        z_ratio = inst_infos["trans"][2] / resize_ratio
        obj_center = anno["centroid_2d"]
        delta_c = obj_center - bbox_center
        dataset_dict["trans_ratio"] = torch.as_tensor(
            [delta_c[0] / bw, delta_c[1] / bh, z_ratio]).to(torch.float32)
        return dataset_dict
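
# A minimal, self-contained sketch (not part of the code above) of how the
# "trans_ratio" target built in read_data could be inverted at inference time to
# recover the object translation from a predicted (dx, dy, z_ratio), the RoI
# parameters and the camera intrinsics K; the function name and signature are
# illustrative assumptions.
import numpy as np

def trans_from_ratio(trans_ratio, bbox_center, bw, bh, resize_ratio, K):
    """Invert trans_ratio = [dx / bw, dy / bh, tz / resize_ratio]."""
    dx, dy, z_ratio = trans_ratio
    # recover the projected object center in pixels
    cx = bbox_center[0] + dx * bw
    cy = bbox_center[1] + dy * bh
    # recover the depth
    tz = z_ratio * resize_ratio
    # back-project the center through the pinhole model
    tx = (cx - K[0, 2]) * tz / K[0, 0]
    ty = (cy - K[1, 2]) * tz / K[1, 1]
    return np.array([tx, ty, tz], dtype=np.float32)
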
Example #2
    def __call__(self):
        """
        Load light-weight instance annotations of all images into a list of dicts in Detectron2 format.
        Do not store heavy data (e.g. raw images or full masks) in the dicts,
        since the annotations of all images are kept in memory at once.
        """
        # cache the dataset_dicts to avoid loading masks from files
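        # the cache key hashes the object list, dataset name/root, the loader flags
        # and this file's absolute path, so changing any of them invalidates the cache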
        hashed_file_name = hashlib.md5((
            "".join([str(fn) for fn in self.objs]) +
            "dataset_dicts_{}_{}_{}_{}_{}".format(
                self.name, self.dataset_root, self.with_masks, self.with_depth,
                osp.abspath(__file__))).encode("utf-8")).hexdigest()
        cache_path = osp.join(
            self.dataset_root,
            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name))

        if osp.exists(cache_path) and self.use_cache:
            logger.info("load cached dataset dicts from {}".format(cache_path))
            return mmcv.load(cache_path)

        t_start = time.perf_counter()

        logger.info("loading dataset dicts: {}".format(self.name))
        self.num_instances_without_valid_segmentation = 0
        self.num_instances_without_valid_box = 0

        dataset_dicts = []
        im_id_global = 0

        if True:
            targets = mmcv.load(self.ann_file)
            scene_im_ids = [(item["scene_id"], item["im_id"])
                            for item in targets]
            scene_im_ids = sorted(list(set(scene_im_ids)))

            # load infos for each scene
            gt_dicts = {}
            gt_info_dicts = {}
            cam_dicts = {}
            for scene_id, im_id in scene_im_ids:
                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
                if scene_id not in gt_dicts:
                    gt_dicts[scene_id] = mmcv.load(
                        osp.join(scene_root, 'scene_gt.json'))
                if scene_id not in gt_info_dicts:
                    gt_info_dicts[scene_id] = mmcv.load(
                        osp.join(scene_root,
                                 'scene_gt_info.json'))  # bbox_obj, bbox_visib
                if scene_id not in cam_dicts:
                    cam_dicts[scene_id] = mmcv.load(
                        osp.join(scene_root, "scene_camera.json"))

            for scene_id, im_id in tqdm(scene_im_ids):
                str_im_id = str(im_id)
                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
                rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(im_id)
                assert osp.exists(rgb_path), rgb_path

                depth_path = osp.join(scene_root,
                                      "depth/{:06d}.png".format(im_id))

                scene_id = int(rgb_path.split('/')[-3])

                cam = np.array(cam_dicts[scene_id][str_im_id]['cam_K'],
                               dtype=np.float32).reshape(3, 3)
                depth_factor = 1000. / cam_dicts[scene_id][str_im_id][
                    'depth_scale']
                record = {
                    "dataset_name": self.name,
                    'file_name': osp.relpath(rgb_path, PROJ_ROOT),
                    'depth_file': osp.relpath(depth_path, PROJ_ROOT),
                    "depth_factor": depth_factor,
                    'height': self.height,
                    'width': self.width,
                    'image_id':
                    im_id_global,  # unique image_id in the dataset, for coco evaluation
                    "scene_im_id": "{}/{}".format(scene_id,
                                                  im_id),  # for evaluation
                    "cam": cam,
                    "img_type": 'real'
                }
                im_id_global += 1
                insts = []
                for anno_i, anno in enumerate(gt_dicts[scene_id][str_im_id]):
                    obj_id = anno['obj_id']
                    if ref.tudl.id2obj[obj_id] not in self.select_objs:
                        continue
                    cur_label = self.cat2label[obj_id]  # 0-based label
                    R = np.array(anno['cam_R_m2c'],
                                 dtype='float32').reshape(3, 3)
                    t = np.array(anno['cam_t_m2c'], dtype='float32') / 1000.0
                    pose = np.hstack([R, t.reshape(3, 1)])
                    quat = mat2quat(R).astype('float32')
                    allo_q = mat2quat(egocentric_to_allocentric(pose)
                                      [:3, :3]).astype('float32')

                    proj = (record["cam"] @ t.T).T
                    proj = proj[:2] / proj[2]
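                    # proj is now the absolute 2D object center (cx, cy) in pixels,
                    # obtained by projecting t with the camera matrix and dehomogenizing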

                    bbox_visib = gt_info_dicts[scene_id][str_im_id][anno_i][
                        'bbox_visib']
                    bbox_obj = gt_info_dicts[scene_id][str_im_id][anno_i][
                        'bbox_obj']
                    x1, y1, w, h = bbox_visib
                    if self.filter_invalid:
                        if h <= 1 or w <= 1:
                            self.num_instances_without_valid_box += 1
                            continue

                    mask_file = osp.join(
                        scene_root,
                        "mask/{:06d}_{:06d}.png".format(im_id, anno_i))
                    mask_visib_file = osp.join(
                        scene_root,
                        "mask_visib/{:06d}_{:06d}.png".format(im_id, anno_i))
                    assert osp.exists(mask_file), mask_file
                    assert osp.exists(mask_visib_file), mask_visib_file
                    # load mask visib  TODO: load both mask_visib and mask_full
                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
                    area = mask_single.sum()
                    if area < 3:  # filter out too small or nearly invisible instances
                        self.num_instances_without_valid_segmentation += 1
                        continue
                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
                    inst = {
                        'category_id': cur_label,  # 0-based label
                        'bbox':
                        bbox_visib,  # TODO: load both bbox_obj and bbox_visib
                        'bbox_mode': BoxMode.XYWH_ABS,
                        'pose': pose,
                        "quat": quat,
                        "trans": t,
                        "allo_quat": allo_q,
                        "centroid_2d": proj,  # absolute (cx, cy)
                        "segmentation": mask_rle,
                        "mask_full_file":
                        mask_file,  # TODO: load as mask_full, rle
                    }

                    insts.append(inst)
                if len(insts) == 0:  # filter im without anno
                    continue
                record['annotations'] = insts
                dataset_dicts.append(record)

        if self.num_instances_without_valid_segmentation > 0:
            logger.warning(
                "Filtered out {} instances without valid segmentation. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_segmentation))
        if self.num_instances_without_valid_box > 0:
            logger.warning(
                "Filtered out {} instances without valid box. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_box))
        ##########################################################################
        if self.num_to_load > 0:
            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
            dataset_dicts = dataset_dicts[:self.num_to_load]
        logger.info("loaded dataset dicts, num_images: {}, using {}s".format(
            len(dataset_dicts),
            time.perf_counter() - t_start))

        mmcv.dump(dataset_dicts, cache_path, protocol=4)
        logger.info("Dumped dataset_dicts to {}".format(cache_path))
        return dataset_dicts
Example #3
    def __call__(self):
        """
        Load light-weight instance annotations of all images into a list of dicts in Detectron2 format.
        Do not store heavy data (e.g. raw images or full masks) in the dicts,
        since the annotations of all images are kept in memory at once.
        """
        # cache the dataset_dicts to avoid loading masks from files
        hashed_file_name = hashlib.md5(
            ("".join([str(fn) for fn in self.objs]) +
             "dataset_dicts_{}_{}_{}_{}_{}_{}".format(
                 self.name, self.dataset_root,
                 self.with_masks, self.with_depth, self.with_xyz,
                 osp.abspath(__file__))).encode("utf-8")).hexdigest()
        cache_path = osp.join(
            self.dataset_root,
            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name))

        if osp.exists(cache_path) and self.use_cache:
            logger.info("load cached dataset dicts from {}".format(cache_path))
            return mmcv.load(cache_path)

        t_start = time.perf_counter()
        dataset_dicts = []
        self.num_instances_without_valid_segmentation = 0
        self.num_instances_without_valid_box = 0
        logger.info("loading dataset dicts: {}".format(self.name))
        # it is slow because of loading and converting masks to rle

        for scene in self.scenes:
            scene_id = int(scene)
            scene_root = osp.join(self.dataset_root, scene)

            gt_dict = mmcv.load(osp.join(scene_root, 'scene_gt.json'))
            gt_info_dict = mmcv.load(osp.join(scene_root,
                                              'scene_gt_info.json'))
            cam_dict = mmcv.load(osp.join(scene_root, 'scene_camera.json'))

            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
                int_im_id = int(str_im_id)
                rgb_path = osp.join(scene_root,
                                    "rgb/{:06d}.jpg").format(int_im_id)
                assert osp.exists(rgb_path), rgb_path

                depth_path = osp.join(scene_root,
                                      "depth/{:06d}.png".format(int_im_id))
                K = np.array(cam_dict[str_im_id]['cam_K'],
                             dtype=np.float32).reshape(3, 3)
                depth_factor = 1000.0 / cam_dict[str_im_id][
                    'depth_scale']  # e.g. 10000 when depth_scale == 0.1

                record = {
                    "dataset_name": self.name,
                    'file_name': osp.relpath(rgb_path, PROJ_ROOT),
                    'depth_file': osp.relpath(depth_path, PROJ_ROOT),
                    'height': self.height,
                    'width': self.width,
                    'image_id': int_im_id,
                    "scene_im_id": "{}/{}".format(scene_id,
                                                  int_im_id),  # for evaluation
                    "cam": K,
                    "depth_factor": depth_factor,
                    "img_type": 'syn_pbr'  # NOTE: has background
                }
                insts = []
                for anno_i, anno in enumerate(gt_dict[str_im_id]):
                    obj_id = anno['obj_id']
                    if obj_id not in self.cat_ids:
                        continue
                    cur_label = self.cat2label[obj_id]  # 0-based label
                    R = np.array(anno['cam_R_m2c'],
                                 dtype='float32').reshape(3, 3)
                    t = np.array(anno['cam_t_m2c'], dtype='float32') / 1000.0
                    pose = np.hstack([R, t.reshape(3, 1)])
                    quat = mat2quat(R).astype('float32')
                    allo_q = mat2quat(egocentric_to_allocentric(pose)
                                      [:3, :3]).astype('float32')

                    proj = (record["cam"] @ t.T).T
                    proj = proj[:2] / proj[2]

                    bbox_visib = gt_info_dict[str_im_id][anno_i]['bbox_visib']
                    bbox_obj = gt_info_dict[str_im_id][anno_i]['bbox_obj']
                    x1, y1, w, h = bbox_visib
                    if self.filter_invalid:
                        if h <= 1 or w <= 1:
                            self.num_instances_without_valid_box += 1
                            continue

                    mask_file = osp.join(
                        scene_root,
                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i))
                    mask_visib_file = osp.join(
                        scene_root, "mask_visib/{:06d}_{:06d}.png".format(
                            int_im_id, anno_i))
                    assert osp.exists(mask_file), mask_file
                    assert osp.exists(mask_visib_file), mask_visib_file
                    # load mask visib  TODO: load both mask_visib and mask_full
                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
                    area = mask_single.sum()
                    if area < 3:  # filter out too small or nearly invisible instances
                        self.num_instances_without_valid_segmentation += 1
                        continue
                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)

                    inst = {
                        'category_id': cur_label,  # 0-based label
                        'bbox':
                        bbox_visib,  # TODO: load both bbox_obj and bbox_visib
                        'bbox_mode': BoxMode.XYWH_ABS,
                        'pose': pose,
                        "quat": quat,
                        "trans": t,
                        "allo_quat": allo_q,
                        "centroid_2d": proj,  # absolute (cx, cy)
                        "segmentation": mask_rle,
                        "mask_full_file":
                        mask_file,  # TODO: load as mask_full, rle
                    }
                    if self.with_xyz:
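                        # only the path to the precomputed object-coordinate (xyz) crop
                        # is stored, keeping the dataset dicts light-weight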
                        xyz_crop_path = mask_file.replace(
                            "/mask/", "/xyz_crop/").replace(".png", ".pkl")
                        assert osp.exists(xyz_crop_path), xyz_crop_path
                        inst["xyz_crop_path"] = xyz_crop_path

                    insts.append(inst)
                if len(insts) == 0:  # filter im without anno
                    continue
                record['annotations'] = insts
                dataset_dicts.append(record)

        if self.num_instances_without_valid_segmentation > 0:
            logger.warning(
                "Filtered out {} instances without valid segmentation. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_segmentation))
        if self.num_instances_without_valid_box > 0:
            logger.warning(
                "Filtered out {} instances without valid box. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_box))
        ##########################
        if self.num_to_load > 0:
            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
            dataset_dicts = dataset_dicts[:self.num_to_load]
        logger.info("loaded {} dataset dicts, using {}s".format(
            len(dataset_dicts),
            time.perf_counter() - t_start))

        mkdir_p(osp.dirname(cache_path))
        mmcv.dump(dataset_dicts, cache_path, protocol=4)
        logger.info("Dumped dataset_dicts to {}".format(cache_path))
        return dataset_dicts
Example #4
    def _load_from_idx_file(self, idx_file, image_root):
        """
        idx_file: the scene/image ids
        image_root/scene contains:
            scene_gt.json
            scene_gt_info.json
            scene_camera.json
        """
        scene_gt_dicts = {}
        scene_gt_info_dicts = {}
        scene_cam_dicts = {}
        scene_im_ids = []  # store tuples of (scene_id, im_id)
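        # each line of idx_file is "<scene_id>/<im_id>", e.g. "000001/23";
        # zero padding is optional since both parts are parsed with int()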
        with open(idx_file, 'r') as f:
            for line in f:
                line_split = line.strip('\r\n').split('/')
                scene_id = int(line_split[0])
                im_id = int(line_split[1])
                scene_im_ids.append((scene_id, im_id))
                if scene_id not in scene_gt_dicts:
                    scene_gt_file = osp.join(image_root, f'{scene_id:06d}/scene_gt.json')
                    assert osp.exists(scene_gt_file), scene_gt_file
                    scene_gt_dicts[scene_id] = mmcv.load(scene_gt_file)

                if scene_id not in scene_gt_info_dicts:
                    scene_gt_info_file = osp.join(image_root, f'{scene_id:06d}/scene_gt_info.json')
                    assert osp.exists(scene_gt_info_file), scene_gt_info_file
                    scene_gt_info_dicts[scene_id] = mmcv.load(scene_gt_info_file)

                if scene_id not in scene_cam_dicts:
                    scene_cam_file = osp.join(image_root, f'{scene_id:06d}/scene_camera.json')
                    assert osp.exists(scene_cam_file), scene_cam_file
                    scene_cam_dicts[scene_id] = mmcv.load(scene_cam_file)
        ######################################################
        scene_im_ids = sorted(scene_im_ids)  # sort to make it reproducible
        dataset_dicts = []

        num_instances_without_valid_segmentation = 0
        num_instances_without_valid_box = 0

        for (scene_id, im_id) in tqdm(scene_im_ids):
            rgb_path = osp.join(image_root, f'{scene_id:06d}/rgb/{im_id:06d}.png')
            assert osp.exists(rgb_path), rgb_path
            # for ycbv/tless, load cam K from image infos
            cam_anno = np.array(scene_cam_dicts[scene_id][str(im_id)]["cam_K"], dtype="float32").reshape(3, 3)
            # dprint(record['cam'])
            if '/train_synt/' in rgb_path:
                img_type = 'syn'
            else:
                img_type = 'real'
            record = {
                "dataset_name": self.name,
                'file_name': osp.relpath(rgb_path, PROJ_ROOT),
                'height': self.height,
                'width': self.width,
                'image_id': self._unique_im_id,
                "scene_im_id": "{}/{}".format(scene_id, im_id),  # for evaluation
                "cam": cam_anno,  # self.cam,
                "img_type": img_type
            }

            if self.with_depth:
                depth_file = osp.join(image_root, f'{scene_id:06d}/depth/{im_id:06d}.png')
                assert osp.exists(depth_file), depth_file
                record["depth_file"] = osp.relpath(depth_file, PROJ_ROOT)

            insts = []
            anno_dict_list = scene_gt_dicts[scene_id][str(im_id)]
            info_dict_list = scene_gt_info_dicts[scene_id][str(im_id)]
            for anno_i, anno in enumerate(anno_dict_list):
                info = info_dict_list[anno_i]
                obj_id = anno['obj_id']
                if obj_id not in self.cat_ids:
                    continue
                # 0-based label now
                cur_label = self.cat2label[obj_id]
                ################ pose ###########################
                R = np.array(anno['cam_R_m2c'], dtype='float32').reshape(3, 3)
                trans = np.array(anno['cam_t_m2c'], dtype='float32') / 1000.0  # mm->m
                pose = np.hstack([R, trans.reshape(3, 1)])
                quat = mat2quat(pose[:3, :3])
                allo_q = mat2quat(egocentric_to_allocentric(pose)[:3, :3])

                ############# bbox ############################
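                # bbox_obj is (x1, y1, w, h); convert it to (x1, y1, x2, y2) and clip
                # to the image bounds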
                if True:
                    bbox = info['bbox_obj']
                    x1, y1, w, h = bbox
                    x2 = x1 + w
                    y2 = y1 + h
                    x1 = max(min(x1, self.width), 0)
                    y1 = max(min(y1, self.height), 0)
                    x2 = max(min(x2, self.width), 0)
                    y2 = max(min(y2, self.height), 0)
                    bbox = [x1, y1, x2, y2]
                if self.filter_invalid:
                    bw = bbox[2] - bbox[0]
                    bh = bbox[3] - bbox[1]
                    if bh <= 1 or bw <= 1:
                        num_instances_without_valid_box += 1
                        continue

                ############## mask #######################
                if self.with_masks:  # either list[list[float]] or dict(RLE)
                    mask_visib_file = osp.join(image_root, f'{scene_id:06d}/mask_visib/{im_id:06d}_{anno_i:06d}.png')
                    assert osp.exists(mask_visib_file), mask_visib_file
                    mask = mmcv.imread(mask_visib_file, 'unchanged')
                    if mask.sum() < 1 and self.filter_invalid:
                        num_instances_without_valid_segmentation += 1
                        continue
                    mask_rle = binary_mask_to_rle(mask)

                    mask_full_file = osp.join(image_root, f'{scene_id:06d}/mask/{im_id:06d}_{anno_i:06d}.png')
                    assert osp.exists(mask_full_file), mask_full_file

                proj = (self.cam @ trans.T).T  # NOTE: uses the fixed self.cam, not the per-image cam_anno
                proj = proj[:2] / proj[2]

                inst = {
                    'category_id': cur_label,  # 0-based label
                    'bbox': bbox,  # TODO: load both bbox_obj and bbox_visib
                    'bbox_mode': BoxMode.XYXY_ABS,
                    "quat": quat,
                    "trans": trans,
                    "allo_quat": allo_q,
                    "centroid_2d": proj,  # absolute (cx, cy)
                }
                # only attach mask fields when masks were loaded; otherwise
                # mask_rle / mask_full_file would be undefined here
                if self.with_masks:
                    inst["segmentation"] = mask_rle
                    inst["mask_full_file"] = mask_full_file  # TODO: load as mask_full, rle

                insts.append(inst)

            if len(insts) == 0:  # and self.filter_invalid:
                continue
            record["annotations"] = insts
            dataset_dicts.append(record)
            self._unique_im_id += 1

        if num_instances_without_valid_segmentation > 0:
            logger.warn("Filtered out {} instances without valid segmentation. "
                        "There might be issues in your dataset generation process.".format(
                            num_instances_without_valid_segmentation))
        if num_instances_without_valid_box > 0:
            logger.warning(
                "Filtered out {} instances without valid box. "
                "There might be issues in your dataset generation process.".format(num_instances_without_valid_box))
        return dataset_dicts
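
# A minimal sketch (not part of the code above) of how dataset dicts produced by
# loaders like the ones in these examples are typically exposed through
# detectron2's DatasetCatalog so dataloaders can fetch them by name; the dataset
# name and loader instance below are hypothetical.
from detectron2.data import DatasetCatalog

def register_pose_dataset(name, loader):
    # `loader` is any zero-argument callable returning a list of dataset dicts,
    # e.g. an instance of one of the __call__-style classes above
    if name not in DatasetCatalog.list():
        DatasetCatalog.register(name, loader)

# hypothetical usage:
#   register_pose_dataset("my_bop_train_pbr", MyPbrDatasetLoader(...))
#   dataset_dicts = DatasetCatalog.get("my_bop_train_pbr")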