Пример #1
0
def grid_show(ims, titles=None, row=1, col=3, dpi=200, save_path=None, title_fontsize=5, show=True):
    """Draw a list of images on a ``row`` x ``col`` grid of subplots.

    Args:
        ims: sequence of images, each passed to ``plt.imshow``.
        titles: optional sequence of per-image titles; must have the same
            length as ``ims``.
        row: number of grid rows.
        col: number of grid columns; widened to ``ceil(len(ims) / row)``
            when ``row * col`` is too small to hold every image.
        dpi: resolution of the created figure.
        save_path: if given, the figure is written to this path (its parent
            directory is created first). Saving happens whether or not the
            figure is also shown.
        title_fontsize: font size used for the titles.
        show: whether to display the figure with ``plt.show()``.

    Returns:
        The created matplotlib figure.
    """
    if row * col < len(ims):
        print('_____________row*col < len(ims)___________')
        col = int(np.ceil(len(ims) / row))
    if titles is not None:
        assert len(ims) == len(titles), "{} != {}".format(len(ims), len(titles))
    fig = plt.figure(dpi=dpi, figsize=plt.figaspect(row / float(col)))
    # After the resize above row * col >= len(ims), so every image gets a
    # cell; any remaining cells stay empty.
    for k, im in enumerate(ims):
        plt.subplot(row, col, k + 1)
        plt.axis('off')
        plt.imshow(im)
        if titles is not None:
            # Title placed slightly above the axes, in axes-relative coordinates.
            plt.text(0.5, 1.08, titles[k],
                     horizontalalignment='center',
                     fontsize=title_fontsize,
                     transform=plt.gca().transAxes)

    # Save before showing, and save through the figure handle: plt.show() can
    # close the figure, after which plt.savefig would target another figure.
    if save_path is not None:
        save_dir = osp.dirname(save_path)
        if save_dir:  # a bare file name has no directory to create
            mkdir_p(save_dir)
        fig.savefig(save_path)
    if show:
        plt.show()
    return fig
Пример #2
0
def vis_image_mask_plt(im, mask, dpi=200, color=None, outfile=None, show=True):
    """Draw an image with the contours of a mask overlaid as filled polygons.

    Args:
        im: image array indexed as ``im[:, :, c]``; its channels are reversed
            (``[2, 1, 0]``) before display, i.e. it is treated as BGR.
        mask: 2-D array; its non-zero pixels are the mask foreground.
        dpi: the figure is sized ``im`` width/height divided by ``dpi`` inches
            and saved at ``dpi``, so the saved image matches ``im``'s size.
        color: optional color converted with ``color_val``; when None, the
            first entry of ``colormap(rgb=True) / 255`` is used.
        outfile: optional path; when given, the figure is saved there and
            closed afterwards.
        show: whether to display the figure with ``plt.show()``.
    """
    if color is None:
        color_list = colormap(rgb=True) / 255
        mask_color_id = 0
        base_color = color_list[mask_color_id % len(color_list), 0:3]
    else:
        base_color = color_val(color)
    # Work on a float copy: color_val may return an immutable sequence (e.g. a
    # tuple), which does not support the in-place blend below.
    w_ratio = 0.4
    color_mask = np.array(base_color, dtype=np.float64)
    # Blend the first three channels towards 1.0 (lighter version of the color).
    color_mask[:3] = color_mask[:3] * (1 - w_ratio) + w_ratio

    fig = plt.figure(frameon=False)
    fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])  # axes covering the whole figure
    ax.axis("off")
    fig.add_axes(ax)
    ax.imshow(im[:, :, [2, 1, 0]])

    # findContours treats non-zero pixels as foreground; a uint8 0/1 copy also
    # accepts bool/float masks and leaves the caller's array untouched.
    mask_u8 = (np.asarray(mask) != 0).astype(np.uint8)
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); index -2 selects the contours in both.
    contours = cv2.findContours(mask_u8, cv2.RETR_CCOMP,
                                cv2.CHAIN_APPROX_NONE)[-2]

    for contour in contours:
        polygon = Polygon(contour.reshape((-1, 2)),
                          fill=True,
                          facecolor=color_mask,
                          edgecolor="w",
                          linewidth=1.2,
                          alpha=0.5)
        ax.add_patch(polygon)

    if outfile is not None:
        out_dir = os.path.dirname(outfile)
        if out_dir:  # a bare file name has no directory to create
            mkdir_p(out_dir)
        fig.savefig(outfile, dpi=dpi)
    # Show before closing; closing first would leave nothing to display.
    if show:
        plt.show()
    if outfile is not None:
        # Release only this figure, not unrelated figures of the caller.
        plt.close(fig)
Пример #3
0
def create_logger(root_output_path, cfg, image_set, temp_flie=False):
    """Create the run's output directory tree and point the logger at it.

    Layout: ``<root_output_path>/<cfg_name>/<image_sets>``, where ``cfg_name``
    is the config file's basename up to its first ``.`` and ``<image_sets>``
    is ``image_set`` with its ``+``-separated parts joined by ``_``.

    Args:
        root_output_path: base output directory (created if missing).
        cfg: path of the config file.
        image_set: one image set name, or several joined with ``+``.
        temp_flie: when True, the log prefix is ``temp_<cfg_name>`` instead
            of ``<cfg_name>``.

    Returns:
        The innermost output directory that the logger writes to.
    """
    mkdir_p(root_output_path)
    assert osp.exists(root_output_path), "{} does not exist".format(
        root_output_path)

    cfg_name = osp.basename(cfg).split(".")[0]
    cfg_dir = osp.join(root_output_path, cfg_name)
    mkdir_p(cfg_dir)

    image_set_dir = "_".join(image_set.split("+"))
    out_dir = osp.join(cfg_dir, image_set_dir)
    mkdir_p(out_dir)

    prefix = "temp_{}".format(cfg_name) if temp_flie else cfg_name
    logger.set_logger_dir(out_dir, action='k', prefix=prefix)
    logger.setLevel(logging.INFO)
    return out_dir
Пример #4
0
    def __call__(self):
        """Load light-weight instance annotations of all images into a list
        of dicts in Detectron2 format.

        Do not load heavy data into memory in this file, since the
        annotations of all images are kept in memory.

        The result is cached as a pickle under ``self.dataset_root``. The
        cache file name contains an md5 of every setting that changes the
        produced dicts, so changing any of them yields a new cache file
        instead of reusing a stale one.

        Returns:
            list[dict]: one record per image that has at least one kept
            instance; each record carries its instances under
            ``"annotations"``.
        """
        # Cache the dataset_dicts to avoid loading masks from files.
        # filter_invalid drops instances and num_to_load truncates the list
        # before it is dumped, so both must be part of the key; otherwise a
        # later run with different settings would load a filtered/truncated
        # cache.
        hashed_file_name = hashlib.md5(
            ("".join([str(fn) for fn in self.objs]) +
             "dataset_dicts_{}_{}_{}_{}_{}_{}_{}_{}".format(
                 self.name, self.dataset_root,
                 self.with_masks, self.with_depth, self.with_xyz,
                 self.filter_invalid, self.num_to_load,
                 osp.abspath(__file__))).encode("utf-8")).hexdigest()
        cache_path = osp.join(
            self.dataset_root,
            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name))

        if osp.exists(cache_path) and self.use_cache:
            logger.info("load cached dataset dicts from {}".format(cache_path))
            return mmcv.load(cache_path)

        t_start = time.perf_counter()
        dataset_dicts = []
        self.num_instances_without_valid_segmentation = 0
        self.num_instances_without_valid_box = 0
        logger.info("loading dataset dicts: {}".format(self.name))
        # it is slow because of loading and converting masks to rle

        for scene in self.scenes:
            scene_id = int(scene)
            scene_root = osp.join(self.dataset_root, scene)

            gt_dict = mmcv.load(osp.join(scene_root, 'scene_gt.json'))
            gt_info_dict = mmcv.load(osp.join(scene_root,
                                              'scene_gt_info.json'))
            cam_dict = mmcv.load(osp.join(scene_root, 'scene_camera.json'))

            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
                int_im_id = int(str_im_id)
                # Format only the file name, then join: formatting the joined
                # path would also interpret any "{" / "}" inside scene_root.
                rgb_path = osp.join(scene_root,
                                    "rgb/{:06d}.jpg".format(int_im_id))
                assert osp.exists(rgb_path), rgb_path

                depth_path = osp.join(scene_root,
                                      "depth/{:06d}.png".format(int_im_id))
                K = np.array(cam_dict[str_im_id]['cam_K'],
                             dtype=np.float32).reshape(3, 3)
                depth_factor = 1000.0 / cam_dict[str_im_id][
                    'depth_scale']  # 10000

                record = {
                    "dataset_name": self.name,
                    'file_name': osp.relpath(rgb_path, PROJ_ROOT),
                    'depth_file': osp.relpath(depth_path, PROJ_ROOT),
                    'height': self.height,
                    'width': self.width,
                    'image_id': int_im_id,
                    "scene_im_id": "{}/{}".format(scene_id,
                                                  int_im_id),  # for evaluation
                    "cam": K,
                    "depth_factor": depth_factor,
                    "img_type": 'syn_pbr'  # NOTE: has background
                }
                insts = []
                for anno_i, anno in enumerate(gt_dict[str_im_id]):
                    obj_id = anno['obj_id']
                    if obj_id not in self.cat_ids:
                        continue
                    cur_label = self.cat2label[obj_id]  # 0-based label
                    R = np.array(anno['cam_R_m2c'],
                                 dtype='float32').reshape(3, 3)
                    # NOTE(review): presumably mm -> m; confirm dataset units.
                    t = np.array(anno['cam_t_m2c'], dtype='float32') / 1000.0
                    pose = np.hstack([R, t.reshape(3, 1)])  # 3x4 [R | t]
                    quat = mat2quat(R).astype('float32')
                    allo_q = mat2quat(egocentric_to_allocentric(pose)
                                      [:3, :3]).astype('float32')

                    # Project the translation (object origin) with K and
                    # divide by depth to get absolute pixel coordinates.
                    proj = (record["cam"] @ t.T).T
                    proj = proj[:2] / proj[2]

                    # TODO: also load 'bbox_obj' alongside 'bbox_visib'
                    bbox_visib = gt_info_dict[str_im_id][anno_i]['bbox_visib']
                    x1, y1, w, h = bbox_visib
                    if self.filter_invalid:
                        if h <= 1 or w <= 1:
                            self.num_instances_without_valid_box += 1
                            continue

                    mask_file = osp.join(
                        scene_root,
                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i))
                    mask_visib_file = osp.join(
                        scene_root, "mask_visib/{:06d}_{:06d}.png".format(
                            int_im_id, anno_i))
                    assert osp.exists(mask_file), mask_file
                    assert osp.exists(mask_visib_file), mask_visib_file
                    # load mask visib  TODO: load both mask_visib and mask_full
                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
                    # Sum of raw pixel values (not a pixel count if the mask
                    # is stored as 0/255).
                    area = mask_single.sum()
                    if area < 3:  # filter out too small or nearly invisible instances
                        self.num_instances_without_valid_segmentation += 1
                        continue
                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)

                    inst = {
                        'category_id': cur_label,  # 0-based label
                        'bbox':
                        bbox_visib,  # TODO: load both bbox_obj and bbox_visib
                        'bbox_mode': BoxMode.XYWH_ABS,
                        'pose': pose,
                        "quat": quat,
                        "trans": t,
                        "allo_quat": allo_q,
                        "centroid_2d": proj,  # absolute (cx, cy)
                        "segmentation": mask_rle,
                        "mask_full_file":
                        mask_file,  # TODO: load as mask_full, rle
                    }
                    if self.with_xyz:
                        xyz_crop_path = mask_file.replace(
                            "/mask/", "/xyz_crop/").replace(".png", ".pkl")
                        assert osp.exists(xyz_crop_path), xyz_crop_path
                        inst["xyz_crop_path"] = xyz_crop_path

                    insts.append(inst)
                if len(insts) == 0:  # filter im without anno
                    continue
                record['annotations'] = insts
                dataset_dicts.append(record)

        if self.num_instances_without_valid_segmentation > 0:
            logger.warning(
                "Filtered out {} instances without valid segmentation. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_segmentation))
        if self.num_instances_without_valid_box > 0:
            logger.warning(
                "Filtered out {} instances without valid box. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_box))
        ##########################
        if self.num_to_load > 0:
            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
            dataset_dicts = dataset_dicts[:self.num_to_load]
        logger.info("loaded {} dataset dicts, using {}s".format(
            len(dataset_dicts),
            time.perf_counter() - t_start))

        mkdir_p(osp.dirname(cache_path))
        mmcv.dump(dataset_dicts, cache_path, protocol=4)
        logger.info("Dumped dataset_dicts to {}".format(cache_path))
        return dataset_dicts
Пример #5
0
def pred_eval(config,
              predictor,
              test_data,
              imdb_test,
              vis=False,
              ignore_cache=None,
              logger=None,
              pairdb=None):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb_test: image database
    :param vis: controls visualization
    :param ignore_cache: ignore the saved cache file
    :param logger: the logger instance
    :return:
    """
    logger.info(imdb_test.result_path)
    logger.info("test iter size: {}".format(config.TEST.test_iter))
    pose_err_file = os.path.join(
        imdb_test.result_path,
        imdb_test.name + "_pose_iter{}.pkl".format(config.TEST.test_iter))
    if os.path.exists(pose_err_file) and not ignore_cache and not vis:
        with open(pose_err_file, "rb") as fid:
            if six.PY3:
                [all_rot_err, all_trans_err, all_poses_est,
                 all_poses_gt] = cPickle.load(fid, encoding="latin1")
            else:
                [all_rot_err, all_trans_err, all_poses_est,
                 all_poses_gt] = cPickle.load(fid)
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
        pose_add_plots_dir = os.path.join(imdb_test.result_path, "add_plots")
        mkdir_p(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            "arp_2d_plots")
        mkdir_p(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir)
        return

    assert vis or not test_data.shuffle
    assert config.TEST.BATCH_PAIRS == 1
    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    num_pairs = len(pairdb)
    height = 480
    width = 640

    data_time, net_time, post_time = 0.0, 0.0, 0.0

    sum_EPE_all = 0.0
    num_inst_all = 0.0
    sum_EPE_viz = 0.0
    num_inst_viz = 0.0
    sum_EPE_vizbg = 0.0
    num_inst_vizbg = 0.0
    sum_PoseErr = [
        np.zeros((len(imdb_test.classes) + 1, 2))
        for batch_idx in range(config.TEST.test_iter)
    ]

    all_rot_err = [[[] for j in range(config.TEST.test_iter)]
                   for batch_idx in range(len(imdb_test.classes))
                   ]  # num_cls x test_iter
    all_trans_err = [[[] for j in range(config.TEST.test_iter)]
                     for batch_idx in range(len(imdb_test.classes))]

    all_poses_est = [[[] for j in range(config.TEST.test_iter)]
                     for batch_idx in range(len(imdb_test.classes))]
    all_poses_gt = [[[] for j in range(config.TEST.test_iter)]
                    for batch_idx in range(len(imdb_test.classes))]

    num_inst = np.zeros(len(imdb_test.classes) + 1)

    K = config.dataset.INTRINSIC_MATRIX
    if (config.TEST.test_iter > 1 or config.TEST.VISUALIZE) and True:
        print(
            "************* start setup render_glumpy environment... ******************"
        )
        if config.dataset.dataset.startswith("ModelNet"):
            from lib.render_glumpy.render_py_light_modelnet_multi import Render_Py_Light_ModelNet_Multi

            modelnet_root = config.modelnet_root
            texture_path = os.path.join(modelnet_root, "gray_texture.png")

            model_path_list = [
                os.path.join(config.dataset.model_dir,
                             "{}.obj".format(model_name))
                for model_name in config.dataset.class_name
            ]
            render_machine = Render_Py_Light_ModelNet_Multi(
                model_path_list,
                texture_path,
                K,
                width,
                height,
                config.dataset.ZNEAR,
                config.dataset.ZFAR,
                brightness_ratios=[0.7],
            )
        else:
            render_machine = Render_Py(
                config.dataset.model_dir,
                config.dataset.class_name,
                K,
                width,
                height,
                config.dataset.ZNEAR,
                config.dataset.ZFAR,
            )

        def render(render_machine, pose, cls_idx, K=None):
            if config.dataset.dataset.startswith("ModelNet"):
                idx = 2
                # generate random light_position
                if idx % 6 == 0:
                    light_position = [1, 0, 1]
                elif idx % 6 == 1:
                    light_position = [1, 1, 1]
                elif idx % 6 == 2:
                    light_position = [0, 1, 1]
                elif idx % 6 == 3:
                    light_position = [-1, 1, 1]
                elif idx % 6 == 4:
                    light_position = [-1, 0, 1]
                elif idx % 6 == 5:
                    light_position = [0, 0, 1]
                else:
                    raise Exception("???")
                light_position = np.array(light_position) * 0.5
                # inverse yz
                light_position[0] += pose[0, 3]
                light_position[1] -= pose[1, 3]
                light_position[2] -= pose[2, 3]

                colors = np.array([1, 1, 1])  # white light
                intensity = np.random.uniform(0.9, 1.1, size=(3, ))
                colors_randk = 0
                light_intensity = colors[colors_randk] * intensity

                # randomly choose a render machine
                rm_randk = 0  # random.randint(0, len(brightness_ratios) - 1)
                rgb_gl, depth_gl = render_machine.render(
                    cls_idx,
                    pose[:3, :3],
                    pose[:3, 3],
                    light_position,
                    light_intensity,
                    brightness_k=rm_randk,
                    r_type="mat",
                )
                rgb_gl = rgb_gl.astype("uint8")
            else:
                rgb_gl, depth_gl = render_machine.render(cls_idx,
                                                         pose[:3, :3],
                                                         pose[:, 3],
                                                         r_type="mat",
                                                         K=K)
                rgb_gl = rgb_gl.astype("uint8")
            return rgb_gl, depth_gl

        print(
            "***************setup render_glumpy environment succeed ******************"
        )

    if config.TEST.PRECOMPUTED_ICP:
        print("precomputed_ICP")
        config.TEST.test_iter = 1
        all_rot_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        all_trans_err = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]

        all_poses_est = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]
        all_poses_gt = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]

        xy_trans_err = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]
        z_trans_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        for idx in range(len(pairdb)):
            pose_path = pairdb[idx]["depth_rendered"][:-10] + "-pose_icp.txt"
            pose_rendered_update = np.loadtxt(pose_path, skiprows=1)
            pose_observed = pairdb[idx]["pose_observed"]
            r_dist_est, t_dist_est = calc_rt_dist_m(pose_rendered_update,
                                                    pose_observed)
            xy_dist = np.linalg.norm(pose_rendered_update[:2, -1] -
                                     pose_observed[:2, -1])
            z_dist = np.linalg.norm(pose_rendered_update[-1, -1] -
                                    pose_observed[-1, -1])
            print(
                "{}: r_dist_est: {}, t_dist_est: {}, xy_dist: {}, z_dist: {}".
                format(idx, r_dist_est, t_dist_est, xy_dist, z_dist))
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            # store poses estimation and gt
            all_poses_est[class_id][0].append(pose_rendered_update)
            all_poses_gt[class_id][0].append(pairdb[idx]["pose_observed"])
            all_rot_err[class_id][0].append(r_dist_est)
            all_trans_err[class_id][0].append(t_dist_est)
            xy_trans_err[class_id][0].append(xy_dist)
            z_trans_err[class_id][0].append(z_dist)
        all_rot_err = np.array(all_rot_err)
        all_trans_err = np.array(all_trans_err)
        print("rot = {} +/- {}".format(np.mean(all_rot_err[class_id][0]),
                                       np.std(all_rot_err[class_id][0])))
        print("trans = {} +/- {}".format(np.mean(all_trans_err[class_id][0]),
                                         np.std(all_trans_err[class_id][0])))
        num_list = all_trans_err[class_id][0]
        print("xyz: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))
        num_list = xy_trans_err[class_id][0]
        print("xy: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))
        num_list = z_trans_err[class_id][0]
        print("z: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))

        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
        pose_add_plots_dir = os.path.join(imdb_test.result_path,
                                          "add_plots_precomputed_ICP")
        mkdir_p(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            "arp_2d_plots_precomputed_ICP")
        mkdir_p(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir)
        return

    if config.TEST.BEFORE_ICP:
        print("before_ICP")
        config.TEST.test_iter = 1
        all_rot_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        all_trans_err = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]

        all_poses_est = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]
        all_poses_gt = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]

        xy_trans_err = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]
        z_trans_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        for idx in range(len(pairdb)):
            pose_path = pairdb[idx]["depth_rendered"][:-10] + "-pose.txt"
            pose_rendered_update = np.loadtxt(pose_path, skiprows=1)
            pose_observed = pairdb[idx]["pose_observed"]
            r_dist_est, t_dist_est = calc_rt_dist_m(pose_rendered_update,
                                                    pose_observed)
            xy_dist = np.linalg.norm(pose_rendered_update[:2, -1] -
                                     pose_observed[:2, -1])
            z_dist = np.linalg.norm(pose_rendered_update[-1, -1] -
                                    pose_observed[-1, -1])
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            # store poses estimation and gt
            all_poses_est[class_id][0].append(pose_rendered_update)
            all_poses_gt[class_id][0].append(pairdb[idx]["pose_observed"])
            all_rot_err[class_id][0].append(r_dist_est)
            all_trans_err[class_id][0].append(t_dist_est)
            xy_trans_err[class_id][0].append(xy_dist)
            z_trans_err[class_id][0].append(z_dist)

        all_trans_err = np.array(all_trans_err)
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
        pose_add_plots_dir = os.path.join(imdb_test.result_path,
                                          "add_plots_before_ICP")
        mkdir_p(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            "arp_2d_plots_before_ICP")
        mkdir_p(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir)
        return

    # ------------------------------------------------------------------------------
    t_start = time.time()
    t = time.time()
    for idx, data_batch in enumerate(test_data):
        if np.sum(pairdb[idx]
                  ["pose_rendered"]) == -12:  # NO POINT VALID IN INIT POSE
            print(idx)
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            for pose_iter_idx in range(config.TEST.test_iter):
                all_poses_est[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_rendered"])
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_observed"])

                r_dist = 1000
                t_dist = 1000
                all_rot_err[class_id][pose_iter_idx].append(r_dist)
                all_trans_err[class_id][pose_iter_idx].append(t_dist)
                sum_PoseErr[pose_iter_idx][class_id, :] += np.array(
                    [r_dist, t_dist])
                sum_PoseErr[pose_iter_idx][-1, :] += np.array([r_dist, t_dist])
                # post process
            if idx % 50 == 0:
                logger.info(
                    "testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s".
                    format(
                        (idx + 1),
                        num_pairs,
                        data_time / ((idx + 1) * test_data.batch_size),
                        net_time / ((idx + 1) * test_data.batch_size),
                        post_time / ((idx + 1) * test_data.batch_size),
                    ))
            print("in test: NO POINT_VALID IN rendered")
            continue
        data_time += time.time() - t

        t = time.time()

        pose_rendered = pairdb[idx]["pose_rendered"]
        if np.sum(pose_rendered) == -12:
            print(idx)
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            num_inst[class_id] += 1
            num_inst[-1] += 1
            for pose_iter_idx in range(config.TEST.test_iter):
                all_poses_est[class_id][pose_iter_idx].append(pose_rendered)
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_observed"])

            # post process
            if idx % 50 == 0:
                logger.info(
                    "testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s".
                    format(
                        (idx + 1),
                        num_pairs,
                        data_time / ((idx + 1) * test_data.batch_size),
                        net_time / ((idx + 1) * test_data.batch_size),
                        post_time / ((idx + 1) * test_data.batch_size),
                    ))

            t = time.time()
            continue

        output_all = predictor.predict(data_batch)
        net_time += time.time() - t

        t = time.time()
        rst_iter = []
        for output in output_all:
            cur_rst = {}
            cur_rst["se3"] = np.squeeze(
                output["se3_output"].asnumpy()).astype("float32")

            if not config.TEST.FAST_TEST and config.network.PRED_FLOW:
                cur_rst["flow"] = np.squeeze(
                    output["flow_est_crop_output"].asnumpy().transpose(
                        (2, 3, 1, 0))).astype("float16")
            else:
                cur_rst["flow"] = None
            if config.network.PRED_MASK and config.TEST.UPDATE_MASK not in [
                    "init", "box_rendered"
            ]:
                mask_pred = np.squeeze(
                    output["mask_observed_pred_output"].asnumpy()).astype(
                        "float32")
                cur_rst["mask_pred"] = mask_pred

            rst_iter.append(cur_rst)

        post_time += time.time() - t
        # sample_ratio = 1  # 0.01
        for batch_idx in range(0, test_data.batch_size):
            # if config.TEST.VISUALIZE and not (r_dist>15 and t_dist>0.05):
            #     continue # 3388, 5326
            # calculate the flow error --------------------------------------------
            t = time.time()
            if config.network.PRED_FLOW and not config.TEST.FAST_TEST:
                # evaluate optical flow
                flow_gt = par_generate_gt(config, pairdb[idx])
                if config.network.PRED_FLOW:
                    all_diff = calc_EPE_one_pair(rst_iter[batch_idx], flow_gt,
                                                 "flow")
                sum_EPE_all += all_diff["epe_all"]
                num_inst_all += all_diff["num_all"]
                sum_EPE_viz += all_diff["epe_viz"]
                num_inst_viz += all_diff["num_viz"]
                sum_EPE_vizbg += all_diff["epe_vizbg"]
                num_inst_vizbg += all_diff["num_vizbg"]

            # calculate the se3 error ---------------------------------------------
            # evaluate se3 estimation
            pose_rendered = pairdb[idx]["pose_rendered"]
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            num_inst[class_id] += 1
            num_inst[-1] += 1
            post_time += time.time() - t

            # iterative refine se3 estimation --------------------------------------------------
            for pose_iter_idx in range(config.TEST.test_iter):
                t = time.time()
                pose_rendered_update = RT_transform(
                    pose_rendered,
                    rst_iter[0]["se3"][:-3],
                    rst_iter[0]["se3"][-3:],
                    config.dataset.trans_means,
                    config.dataset.trans_stds,
                    config.network.ROT_COORD,
                )

                # calculate error
                r_dist, t_dist = calc_rt_dist_m(pose_rendered_update,
                                                pairdb[idx]["pose_observed"])

                # store poses estimation and gt
                all_poses_est[class_id][pose_iter_idx].append(
                    pose_rendered_update)
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_observed"])

                all_rot_err[class_id][pose_iter_idx].append(r_dist)
                all_trans_err[class_id][pose_iter_idx].append(t_dist)
                sum_PoseErr[pose_iter_idx][class_id, :] += np.array(
                    [r_dist, t_dist])
                sum_PoseErr[pose_iter_idx][-1, :] += np.array([r_dist, t_dist])
                if config.TEST.VISUALIZE:
                    print("idx {}, iter {}: rError: {}, tError: {}".format(
                        idx + batch_idx, pose_iter_idx + 1, r_dist, t_dist))

                post_time += time.time() - t

                # # if more than one iteration
                if pose_iter_idx < (config.TEST.test_iter -
                                    1) or config.TEST.VISUALIZE:
                    t = time.time()
                    # get refined image
                    K_path = pairdb[idx]["image_observed"][:-10] + "-K.txt"
                    if os.path.exists(K_path):
                        K = np.loadtxt(K_path)
                    image_refined, depth_refined = render(
                        render_machine,
                        pose_rendered_update,
                        config.dataset.class_name.index(
                            pairdb[idx]["gt_class"]),
                        K=K,
                    )
                    image_refined = image_refined[:, :, :3]

                    # update minibatch
                    update_package = [{
                        "image_rendered": image_refined,
                        "src_pose": pose_rendered_update
                    }]
                    if config.network.INPUT_DEPTH:
                        update_package[0]["depth_rendered"] = depth_refined
                    if config.network.INPUT_MASK:
                        mask_rendered_refined = np.zeros(depth_refined.shape)
                        mask_rendered_refined[depth_refined > 0.2] = 1
                        update_package[0][
                            "mask_rendered"] = mask_rendered_refined
                        if config.network.PRED_MASK:
                            # init, box_rendered, mask_rendered, box_observed, mask_observed
                            if config.TEST.UPDATE_MASK == "box_rendered":
                                input_names = [
                                    blob_name[0]
                                    for blob_name in data_batch.provide_data[0]
                                ]
                                update_package[0][
                                    "mask_observed"] = np.squeeze(
                                        data_batch.data[0][input_names.index(
                                            "mask_rendered")].asnumpy()
                                        [batch_idx])  # noqa
                            elif config.TEST.UPDATE_MASK == "init":
                                pass
                            else:
                                raise Exception(
                                    "Unknown UPDATE_MASK type: {}".format(
                                        config.network.UPDATE_MASK))

                    pose_rendered = pose_rendered_update
                    data_batch = update_data_batch(config, data_batch,
                                                   update_package)

                    data_time += time.time() - t

                    # forward and get rst
                    if pose_iter_idx < config.TEST.test_iter - 1:
                        t = time.time()
                        output_all = predictor.predict(data_batch)
                        net_time += time.time() - t

                        t = time.time()
                        rst_iter = []
                        for output in output_all:
                            cur_rst = {}
                            if config.network.REGRESSOR_NUM == 1:
                                cur_rst["se3"] = np.squeeze(
                                    output["se3_output"].asnumpy()).astype(
                                        "float32")

                            if not config.TEST.FAST_TEST and config.network.PRED_FLOW:
                                cur_rst["flow"] = np.squeeze(
                                    output["flow_est_crop_output"].asnumpy().
                                    transpose((2, 3, 1, 0))).astype("float16")
                            else:
                                cur_rst["flow"] = None

                            if config.network.PRED_MASK and config.TEST.UPDATE_MASK not in [
                                    "init", "box_rendered"
                            ]:
                                mask_pred = np.squeeze(
                                    output["mask_observed_pred_output"].
                                    asnumpy()).astype("float32")
                                cur_rst["mask_pred"] = mask_pred

                            rst_iter.append(cur_rst)
                            post_time += time.time() - t

        # post process
        if idx % 50 == 0:
            logger.info(
                "testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s".
                format(
                    (idx + 1),
                    num_pairs,
                    data_time / ((idx + 1) * test_data.batch_size),
                    net_time / ((idx + 1) * test_data.batch_size),
                    post_time / ((idx + 1) * test_data.batch_size),
                ))

        t = time.time()

    all_rot_err = np.array(all_rot_err)
    all_trans_err = np.array(all_trans_err)

    # save inference results
    if not config.TEST.VISUALIZE:
        with open(pose_err_file, "wb") as f:
            logger.info("saving result cache to {}".format(pose_err_file))
            cPickle.dump(
                [all_rot_err, all_trans_err, all_poses_est, all_poses_gt],
                f,
                protocol=2)
            logger.info("done")

    if config.network.PRED_FLOW:
        logger.info("evaluate flow:")
        logger.info("EPE all: {}".format(sum_EPE_all / max(num_inst_all, 1.0)))
        logger.info("EPE ignore unvisible: {}".format(
            sum_EPE_vizbg / max(num_inst_vizbg, 1.0)))
        logger.info("EPE visible: {}".format(sum_EPE_viz /
                                             max(num_inst_viz, 1.0)))

    logger.info("evaluate pose:")
    imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
    # evaluate pose add
    pose_add_plots_dir = os.path.join(imdb_test.result_path, "add_plots")
    mkdir_p(pose_add_plots_dir)
    imdb_test.evaluate_pose_add(config,
                                all_poses_est,
                                all_poses_gt,
                                output_dir=pose_add_plots_dir)
    pose_arp2d_plots_dir = os.path.join(imdb_test.result_path, "arp_2d_plots")
    mkdir_p(pose_arp2d_plots_dir)
    imdb_test.evaluate_pose_arp_2d(config,
                                   all_poses_est,
                                   all_poses_gt,
                                   output_dir=pose_arp2d_plots_dir)

    logger.info("using {} seconds in total".format(time.time() - t_start))
Пример #6
0
    def __call__(self):
        """Load light-weight instance annotations of all images into a list of dicts in Detectron2 format.

        Do not load heavy data into memory in this file, since the annotations
        of all images are kept in memory.

        Each record describes one image listed in ``<dataset_root>/gt.json``
        that has at least one annotation whose ``obj_id`` is in
        ``self.cat_ids``; images without such annotations are dropped. Boxes
        are stored as XYWH absolute coordinates with a 0-based
        ``category_id`` taken from ``self.cat2label``.

        The result is cached as
        ``<dataset_root>/dataset_dicts_<name>_<md5>.pkl`` and reused when
        ``self.use_cache`` is set. The md5 covers every attribute that changes
        the result -- including ``self.num_to_load``, because the cached list
        is already truncated and must not be served to a caller that asked for
        a different number of images.

        When the dicts are (re)built this also resets the invalid-instance
        counters and, if ``self.num_to_load > 0``, clamps ``self.num_to_load``
        to the number of loaded records.

        Returns:
            list[dict]: the dataset dicts.

        Raises:
            AssertionError: if ``gt.json`` or an image listed in it is missing.
        """
        # cache the dataset_dicts to avoid loading masks from files
        hashed_file_name = hashlib.md5(
            ("".join([str(fn) for fn in self.objs]) +
             "dataset_dicts_{}_{}_{}_{}_{}_{}_{}".format(
                 self.name, self.dataset_root,
                 self.with_masks, self.with_depth, self.with_xyz,
                 # the dumped list is truncated to num_to_load, so it must be
                 # part of the cache key
                 self.num_to_load,
                 osp.abspath(__file__))).encode("utf-8")).hexdigest()
        cache_path = osp.join(
            self.dataset_root,
            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name))

        if osp.exists(cache_path) and self.use_cache:
            logger.info("load cached dataset dicts from {}".format(cache_path))
            return mmcv.load(cache_path)

        t_start = time.perf_counter()
        dataset_dicts = []
        self.num_instances_without_valid_segmentation = 0
        self.num_instances_without_valid_box = 0
        logger.info("loading dataset dicts: {}".format(self.name))

        gt_path = osp.join(self.dataset_root, "gt.json")
        assert osp.exists(gt_path), gt_path
        gt_dict = mmcv.load(gt_path)

        for str_im_id in tqdm(gt_dict):
            int_im_id = int(str_im_id)
            # Format only the file name: formatting the joined path would
            # misbehave if dataset_root contained braces.
            rgb_path = osp.join(self.dataset_root,
                                "{:06d}.png".format(int_im_id))
            assert osp.exists(rgb_path), rgb_path

            record = {
                "dataset_name": self.name,
                'file_name': osp.relpath(rgb_path, PROJ_ROOT),
                'height': self.height,
                'width': self.width,
                'image_id': int_im_id,
                "scene_im_id": "{}/{}".format(0,
                                              int_im_id),  # for evaluation
                "img_type": 'real'  # NOTE: has background
            }
            insts = []
            for anno in gt_dict[str_im_id]:
                obj_id = anno['obj_id']
                if obj_id not in self.cat_ids:
                    continue
                cur_label = self.cat2label[obj_id]  # 0-based label

                bbox_obj = anno['obj_bb']
                _, _, w, h = bbox_obj  # XYWH
                if self.filter_invalid and (h <= 1 or w <= 1):
                    self.num_instances_without_valid_box += 1
                    continue

                insts.append({
                    'category_id': cur_label,  # 0-based label
                    'bbox': bbox_obj,
                    'bbox_mode': BoxMode.XYWH_ABS,
                })
            if len(insts) == 0:  # filter im without anno
                continue
            record['annotations'] = insts
            dataset_dicts.append(record)

        if self.num_instances_without_valid_box > 0:
            logger.warning(
                "Filtered out {} instances without valid box. "
                "There might be issues in your dataset generation process.".
                format(self.num_instances_without_valid_box))
        ##########################
        if self.num_to_load > 0:
            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
            dataset_dicts = dataset_dicts[:self.num_to_load]
        logger.info("loaded {} dataset dicts, using {}s".format(
            len(dataset_dicts),
            time.perf_counter() - t_start))

        mkdir_p(osp.dirname(cache_path))
        mmcv.dump(dataset_dicts, cache_path, protocol=4)
        logger.info("Dumped dataset_dicts to {}".format(cache_path))
        return dataset_dicts
Пример #7
0
def plot_vsd_err_hist(eval_dir,
                      scene_ids,
                      obj_id,
                      dataset_name="linemod",
                      top_n=1,
                      delta=15,
                      tau=20,
                      cost="step",
                      cam_type="primesense"):
    """Plot the VSD error-vs-recall curve of one object and save it as TikZ.

    For every scene in ``scene_ids`` the VSD errors are read from
    ``<eval_dir>/error=vsd_ntop=<top_n>_delta=<delta>_tau=<tau>_cost=<cost>/errors_<scene_id:02d>.yml``.
    For linemod/hinterstoisser only the scene whose id equals ``obj_id`` is
    used. Missing error files are reported and skipped. Errors are kept
    for views where a GT instance of ``obj_id`` has ``visib_fract > 0.1``.

    One curve is drawn for the best of the top-``top_n`` estimates and, when
    ``top_n != 1``, one for the first estimate only. The AUC of each curve is
    logged and put in the legend. The figure is written to
    ``<eval_dir>/latex/vsd_err_hist_obj_<obj_id:02d>.tex`` and then closed.

    Args:
        eval_dir: evaluation output directory.
        scene_ids: iterable of scene ids to scan.
        obj_id: object id to evaluate.
        dataset_name: dataset name passed to ``get_dataset_params``.
        top_n, delta, tau, cost: VSD settings; they select the error directory.
        cam_type: camera type passed to ``get_dataset_params``.

    Returns:
        None. Returns early, without plotting, when no errors were collected.
    """
    data_params = dataset_params.get_dataset_params(dataset_name,
                                                    model_type="",
                                                    train_type="",
                                                    test_type=cam_type,
                                                    cam_type=cam_type)

    vsd_errs = []
    for scene_id in scene_ids:
        # NOTE: linemod scene_id == obj_id
        if dataset_name in ["linemod", "hinterstoisser"] and obj_id != scene_id:
            continue

        error_file_path = osp.join(
            eval_dir,
            "error=vsd_ntop=%s_delta=%s_tau=%s_cost=%s" %
            (top_n, delta, tau, cost),
            "errors_{:02d}.yml".format(scene_id),
        )

        if not osp.exists(error_file_path):
            print("WARNING: " + error_file_path + " not found")
            continue
        gts = inout.load_gt(data_params["scene_gt_mpath"].format(scene_id))
        # NOTE: visibility stats are always read for delta=15, independently
        # of the ``delta`` argument.
        visib_gts = inout.load_yaml(data_params["scene_gt_stats_mpath"].format(
            scene_id, 15))  # delta=15
        vsd_list = inout.load_yaml(error_file_path)
        for view, _ in enumerate(vsd_list):
            # assumes the errors file holds top_n consecutive entries per view,
            # in view order -- TODO confirm against the error writer
            view_vsds = vsd_list[view * top_n:(view + 1) * top_n]
            for gt, visib_gt in zip(gts[view], visib_gts[view]):
                if gt["obj_id"] == obj_id and visib_gt["visib_fract"] > 0.1:
                    vsd_errs.extend(
                        list(entry["errors"].values())[0] for entry in view_vsds)

    if len(vsd_errs) == 0:
        return
    vsd_errs = np.array(vsd_errs)
    logger.info("vsd errs: {}".format(len(vsd_errs)))

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    total_views = int(len(vsd_errs) / top_n)
    # Row v holds the top_n errors collected for view v; trailing errors that
    # do not fill a whole row are ignored.
    errs_per_view = vsd_errs[:total_views * top_n].reshape(total_views, top_n)

    fig = plt.figure()
    try:
        ax = plt.gca()
        ax.set_xlim((0.0, 1.0))
        plt.grid()
        plt.xlabel("vsd err")
        plt.ylabel("recall")
        plt.title("obj: {}, VSD Error vs Recall".format(obj_id))
        legend = []

        for n in np.unique(np.array([top_n, 1])):
            if n == 1:
                # only the first-ranked estimate of each view
                min_vsd_errs = errs_per_view[:, 0].astype(np.float64)
            else:
                # best of the top_n estimates of each view
                min_vsd_errs = errs_per_view.min(axis=1).astype(np.float64)

            min_vsd_errs_sorted = np.sort(min_vsd_errs)
            recall = np.float32(np.arange(total_views) + 1.0) / total_views

            # close the curve at (err=1, recall=1)
            min_vsd_errs_sorted = np.hstack((min_vsd_errs_sorted, np.array([1.0])))
            recall = np.hstack((recall, np.array([1.0])))

            auc_vsd = trapezoid(recall, min_vsd_errs_sorted)
            plt.plot(min_vsd_errs_sorted, recall)

            legend.append("top {0} vsd err, AUC = {1:.4f}".format(n, auc_vsd))
            logger.info("obj:{} top {} vsd err, AUC = {:.4f}".format(
                obj_id, n, auc_vsd))
        plt.legend(legend)
        out_file = osp.join(eval_dir, "latex",
                            "vsd_err_hist_obj_{:02d}.tex".format(obj_id))
        mkdir_p(osp.dirname(out_file))
        logger.info(osp.basename(out_file))
        tikz_save(out_file,
                  figurewidth="0.45\\textheight",
                  figureheight="0.45\\textheight",
                  show_info=False)
    finally:
        # release the figure so repeated per-object calls do not accumulate
        # open pyplot figures
        plt.close(fig)