Example #1
0
def stat_se3(pairdb, config):
    """Compute the mean and std of SE(3) deltas between rendered and real poses.

    :param pairdb: sequence of pair records, each providing 'pose_rendered'
        and 'pose_real'
    :param config: config exposing network.SE3_TYPE ("EULER" -> 3 rotation
        dims, otherwise quaternion -> 4) and network.ROT_COORD
    :return: (se3_mean, se3_std), per-dimension statistics over all pairs
    """
    num_pair = len(pairdb)
    tic()
    # rotation part is 3-dim for Euler angles, 4-dim otherwise (quaternion);
    # the last 3 columns always hold the translation
    rot_dim = 3 if config.network.SE3_TYPE == "EULER" else 4
    se3_i2r_tensor = np.zeros([num_pair, rot_dim + 3])
    se3_i2r_dist = np.zeros([num_pair, 2])
    for idx, pair in enumerate(pairdb):
        rot_i2r, trans_i2r = calc_RT_delta(pair['pose_rendered'],
                                           pair['pose_real'],
                                           np.zeros(3),
                                           np.ones(3),
                                           config.network.ROT_COORD,
                                           config.network.SE3_TYPE)
        se3_i2r_tensor[idx, :-3] = rot_i2r
        se3_i2r_tensor[idx, -3:] = trans_i2r

        # scalar rotation / translation distances for the max report below
        se3_i2r_dist[idx, 0], se3_i2r_dist[idx, 1] = calc_rt_dist_m(
            pair['pose_rendered'], pair['pose_real'])

    print("stat finished, using {} seconds".format(toc()))
    se3_mean = np.mean(se3_i2r_tensor, axis=0)
    se3_std = np.std(se3_i2r_tensor, axis=0)
    print('mean: {}, \nstd: {}'.format(se3_mean, se3_std))
    print("R_max: {}, T_max: {}".format(np.max(se3_i2r_dist[:, 0]),
                                        np.max(se3_i2r_dist[:, 1])))
    return se3_mean, se3_std
Example #2
0
def main():
    """Generate noisy rendered poses around each observed pose, per class.

    For every class in ``sel_classes``, reads the observed pose of each image
    in the NDtrain observed set, records it as quaternion + translation, then
    draws ``num_rendered_per_observed`` perturbed poses per image by adding
    Gaussian noise to the Euler angles and translation. A draw is rejected
    and redrawn until its rotation error is within ``angle_max`` and the
    projected object center lies at least 48 px inside the 640x480 image.
    Accepted poses are written to one text file per class under ``pose_dir``.
    """
    for cls_name in tqdm(sel_classes):
        print(cls_name)
        rd_stat = []
        td_stat = []
        pose_observed = []
        pose_rendered = []

        observed_set_file = os.path.join(
            observed_set_dir, "NDtrain_observed_{}.txt".format(cls_name)
        )
        with open(observed_set_file) as f:
            image_list = [x.strip() for x in f.readlines()]

        for data in image_list:
            pose_observed_path = os.path.join(
                gt_observed_dir, cls_name, data + "-pose.txt"
            )
            src_pose_m = np.loadtxt(pose_observed_path, skiprows=1)

            src_euler = np.squeeze(mat2euler(src_pose_m[:, :3]))
            src_quat = euler2quat(src_euler[0], src_euler[1], src_euler[2]).reshape(
                1, -1
            )
            src_trans = src_pose_m[:, 3]
            pose_observed.append(np.hstack((src_quat, src_trans.reshape(1, 3))))

            for rendered_idx in range(num_rendered_per_observed):
                # Rejection-sample one perturbed pose. (The sampling body was
                # previously duplicated before and inside the retry loop; it
                # is now a single while-True loop with identical draw order.)
                count = 0  # number of rejected draws so far
                while True:
                    tgt_euler = src_euler + np.random.normal(0, angle_std / 180 * pi, 3)
                    x_error = np.random.normal(0, x_std, 1)[0]
                    y_error = np.random.normal(0, y_std, 1)[0]
                    z_error = np.random.normal(0, z_std, 1)[0]
                    tgt_trans = src_trans + np.array([x_error, y_error, z_error])
                    tgt_pose_m = np.hstack(
                        (
                            euler2mat(tgt_euler[0], tgt_euler[1], tgt_euler[2]),
                            tgt_trans.reshape((3, 1)),
                        )
                    )
                    r_dist, t_dist = calc_rt_dist_m(tgt_pose_m, src_pose_m)
                    # project the translation to get the object center in image coords
                    transform = np.matmul(K, tgt_trans.reshape(3, 1))
                    center_x = transform[0] / transform[2]
                    center_y = transform[1] / transform[2]
                    if count == 100:
                        # warn once after 100 rejections for this sample
                        print(rendered_idx)
                        print(
                            "{}: {}, {}, {}, {}".format(
                                data, r_dist, t_dist, center_x, center_y
                            )
                        )
                        print(
                            "count: {}, image_path: {}, rendered_idx: {}".format(
                                count,
                                pose_observed_path.replace("pose.txt", "color.png"),
                                rendered_idx,
                            )
                        )
                    # accept when rotation error and projected center are in range
                    if r_dist <= angle_max and (
                        48 < center_x < (640 - 48) and 48 < center_y < (480 - 48)
                    ):
                        break
                    count += 1

                tgt_quat = euler2quat(tgt_euler[0], tgt_euler[1], tgt_euler[2]).reshape(
                    1, -1
                )
                pose_rendered.append(np.hstack((tgt_quat, tgt_trans.reshape(1, 3))))
                rd_stat.append(r_dist)
                td_stat.append(t_dist)
        rd_stat = np.array(rd_stat)
        td_stat = np.array(td_stat)
        print("r dist: {} +/- {}".format(np.mean(rd_stat), np.std(rd_stat)))
        print("t dist: {} +/- {}".format(np.mean(td_stat), np.std(td_stat)))

        output_file_name = os.path.join(
            pose_dir,
            "LM6d_occ_dsm_{}_NDtrain_rendered_pose_{}.txt".format(version, cls_name),
        )
        with open(output_file_name, "w") as text_file:
            for x in pose_rendered:
                text_file.write("{}\n".format(" ".join(map(str, np.squeeze(x)))))
    print("{} finished".format(__file__))
Example #3
0
    def evaluate_pose(self, config, all_poses_est, all_poses_gt):
        """Evaluate estimated poses against ground truth and log accuracies.

        For every class and refinement iteration, computes per-pose rotation
        and translation errors via ``calc_rt_dist_m``, then logs the fraction
        of poses whose errors fall under a sweep of thresholds (rotation
        1..10, translation 0.01..0.10), per class and averaged over classes.

        :param config: config providing ``TEST.test_iter`` (refinement iters)
        :param all_poses_est: all_poses_est[cls_idx][iter_i] -> list of
            estimated poses for that class/iteration
        :param all_poses_gt: all_poses_gt[cls_idx][0] -> list of ground-truth
            poses (index 0 is reused for every iteration)
        """
        # evaluate and display
        logger.info("evaluating pose")
        rot_thresh_list = np.arange(1, 11, 1)  # rotation thresholds (presumably degrees -- TODO confirm)
        trans_thresh_list = np.arange(0.01, 0.11, 0.01)  # translation thresholds (presumably meters -- TODO confirm)
        num_metric = len(rot_thresh_list)
        num_iter = config.TEST.test_iter
        # accuracy[cls_idx, iter_i, thresh_idx] = fraction of poses under threshold
        rot_acc = np.zeros((self.num_classes, num_iter, num_metric))
        trans_acc = np.zeros((self.num_classes, num_iter, num_metric))
        space_acc = np.zeros((self.num_classes, num_iter, num_metric))

        num_valid_class = 0
        for cls_idx, cls_name in enumerate(self.classes):
            # skip classes with no estimates or no ground truth
            if not (all_poses_est[cls_idx][0] and all_poses_gt[cls_idx][0]):
                continue
            num_valid_class += 1
            for iter_i in range(num_iter):
                curr_poses_gt = all_poses_gt[cls_idx][0]  # gt shared across iterations
                num = len(curr_poses_gt)
                curr_poses_est = all_poses_est[cls_idx][iter_i]

                cur_rot_rst = np.zeros((num, 1))
                cur_trans_rst = np.zeros((num, 1))

                for j in range(num):
                    r_dist_est, t_dist_est = calc_rt_dist_m(
                        curr_poses_est[j], curr_poses_gt[j])
                    # eggbox is treated as symmetric under a 180-degree rotation
                    # about z: when the raw error exceeds 90, re-score against
                    # the flipped estimate and keep that error instead
                    if cls_name == "eggbox" and r_dist_est > 90:
                        # RT_z = diag(-1, -1, 1) with zero translation: pi about z
                        RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0],
                                         [0, 0, 1, 0]])
                        curr_pose_est_sym = se3_mul(curr_poses_est[j], RT_z)
                        r_dist_est, t_dist_est = calc_rt_dist_m(
                            curr_pose_est_sym, curr_poses_gt[j])
                    cur_rot_rst[j, 0] = r_dist_est
                    cur_trans_rst[j, 0] = t_dist_est

                for thresh_idx in range(num_metric):
                    rot_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        cur_rot_rst < rot_thresh_list[thresh_idx])
                    trans_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        cur_trans_rst < trans_thresh_list[thresh_idx])
                    # "space" accuracy: rotation AND translation both under threshold
                    space_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        np.logical_and(
                            cur_rot_rst < rot_thresh_list[thresh_idx],
                            cur_trans_rst < trans_thresh_list[thresh_idx]))

            # per-class report: mean over all thresholds (shown as [-1, -1]),
            # then the selected threshold indices in show_list
            show_list = [1, 4, 9]
            logger.info("------------ {} -----------".format(cls_name))
            logger.info("{:>24}: {:>7}, {:>7}, {:>7}".format(
                "[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"))
            for iter_i in range(num_iter):
                logger.info("** iter {} **".format(iter_i + 1))
                logger.info("{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                    "average_accuracy",
                    "[{:>2}, {:>4}]".format(-1, -1),
                    np.mean(rot_acc[cls_idx, iter_i, :]) * 100,
                    np.mean(trans_acc[cls_idx, iter_i, :]) * 100,
                    np.mean(space_acc[cls_idx, iter_i, :]) * 100,
                ))
                for i, show_idx in enumerate(show_list):
                    logger.info(
                        "{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                            "average_accuracy",
                            "[{:>2}, {:>4}]".format(
                                rot_thresh_list[show_idx],
                                trans_thresh_list[show_idx]),
                            rot_acc[cls_idx, iter_i, show_idx] * 100,
                            trans_acc[cls_idx, iter_i, show_idx] * 100,
                            space_acc[cls_idx, iter_i, show_idx] * 100,
                        ))
        print(" ")
        # overall performance: same report averaged over all valid classes
        for iter_i in range(num_iter):
            show_list = [1, 4, 9]
            logger.info(
                "---------- performance over {} classes -----------".format(
                    num_valid_class))
            logger.info("** iter {} **".format(iter_i + 1))
            logger.info("{:>24}: {:>7}, {:>7}, {:>7}".format(
                "[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"))
            logger.info("{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                "average_accuracy",
                "[{:>2}, {:>4}]".format(-1, -1),
                np.sum(rot_acc[:, iter_i, :]) /
                (num_valid_class * num_metric) * 100,
                np.sum(trans_acc[:, iter_i, :]) /
                (num_valid_class * num_metric) * 100,
                np.sum(space_acc[:, iter_i, :]) /
                (num_valid_class * num_metric) * 100,
            ))
            for i, show_idx in enumerate(show_list):
                logger.info("{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                    "average_accuracy",
                    "[{:>2}, {:>4}]".format(rot_thresh_list[show_idx],
                                            trans_thresh_list[show_idx]),
                    np.sum(rot_acc[:, iter_i, show_idx]) / num_valid_class *
                    100,
                    np.sum(trans_acc[:, iter_i, show_idx]) / num_valid_class *
                    100,
                    np.sum(space_acc[:, iter_i, show_idx]) / num_valid_class *
                    100,
                ))
            print(" ")
Example #4
0
    def get_batch_parallel(self):
        """Build the next mini-batch using the worker pool.

        Slices the current window of pairs across devices according to
        ``work_load_list``, dispatches one async job per sample to compute its
        data/label dicts, then stacks the per-sample results into one
        data/label dict per device and stores them as MXNet NDArrays in
        ``self.data`` / ``self.label``.
        """
        cur_from = self.cur
        cur_to = min(cur_from + self.batch_size, self.size)
        pairdb = [self.pairdb[self.index[i]] for i in range(cur_from, cur_to)]

        # decide multi-device slice: one contiguous slice of the batch per context
        work_load_list = self.work_load_list
        ctx = self.ctx
        if work_load_list is None:
            work_load_list = [1] * len(ctx)
        assert isinstance(work_load_list, list) and len(work_load_list) == len(
            ctx), "Invalid settings for work load. "
        slices = _split_input_slice(self.batch_size, work_load_list)

        # dispatch one async job per sample (finer-grained than per-slice)
        multiprocess_results = []
        for islice in slices:
            for i in range(islice.start, islice.stop):
                multiprocess_results.append(
                    self.pool.apply_async(get_data_pair_train_batch,
                                          ([pairdb[i]], self.config)))

        rst = [
            multiprocess_result.get()
            for multiprocess_result in multiprocess_results
        ]

        batch_per_gpu = int(self.batch_size / len(ctx))
        # hoisted out of the per-GPU loop: these lists are loop-invariant
        sample_data_list = [_["data"] for _ in rst]
        sample_label_list = [_["label"] for _ in rst]
        data_list = []
        label_list = []
        for i in range(len(ctx)):
            # stack the batch_per_gpu samples belonging to device i, key by key
            batch_data = {}
            batch_label = {}
            for key in sample_data_list[0]:
                batch_data[key] = my_tensor_vstack([
                    sample_data_list[j][key]
                    for j in range(i * batch_per_gpu, (i + 1) * batch_per_gpu)
                ])
            for key in sample_label_list[0]:
                batch_label[key] = my_tensor_vstack([
                    sample_label_list[j][key]
                    for j in range(i * batch_per_gpu, (i + 1) * batch_per_gpu)
                ])
            data_list.append(batch_data)
            label_list.append(batch_label)
        # convert to NDArrays in the order declared by data_name / label_name
        self.data = [[mx.nd.array(data_on_i[key]) for key in self.data_name]
                     for data_on_i in data_list]
        self.label = [[
            mx.nd.array(label_on_i[key]) for key in self.label_name
        ] for label_on_i in label_list]
Example #5
0
    def evaluate_pose(self, config, all_poses_est, all_poses_gt, logger):
        """Evaluate estimated poses against ground truth and log accuracies.

        Like the plain evaluation, but optionally appends two pseudo-iterations
        when ``config.TEST.AVERAGE_ITERS`` is set and ``test_iter >= 4``:
        index ``test_iter`` scores the element-wise average of the last 2 real
        iterations, and index ``test_iter + 1`` the average of the last 4.

        :param config: config providing ``TEST.test_iter`` and
            ``TEST.AVERAGE_ITERS``
        :param all_poses_est: all_poses_est[cls_idx][iter_i] -> list of
            estimated poses for that class/iteration
        :param all_poses_gt: all_poses_gt[cls_idx][0] -> list of ground-truth
            poses (index 0 is reused for every iteration)
        :param logger: logger passed through to ``print_and_log``
        """
        # evaluate and display
        print_and_log('evaluating pose', logger)
        rot_thresh_list = np.arange(1, 11, 1)  # rotation thresholds (presumably degrees -- TODO confirm)
        trans_thresh_list = np.arange(0.01, 0.11, 0.01)  # translation thresholds (presumably meters -- TODO confirm)
        num_metric = len(rot_thresh_list)
        if config.TEST.AVERAGE_ITERS and config.TEST.test_iter >= 4:
            # two extra slots for the averaged pseudo-iterations
            print_and_log('average last 2 and 4 iters', logger)
            num_iter = config.TEST.test_iter + 2
        else:
            num_iter = config.TEST.test_iter
        # accuracy[cls_idx, iter_i, thresh_idx] = fraction of poses under threshold
        rot_acc = np.zeros((self.num_classes, num_iter, num_metric))
        trans_acc = np.zeros((self.num_classes, num_iter, num_metric))
        space_acc = np.zeros((self.num_classes, num_iter, num_metric))

        num_valid_class = 0
        for cls_idx, cls_name in enumerate(self.classes):
            # skip classes with no estimates or no ground truth
            if not (all_poses_est[cls_idx][0] and all_poses_gt[cls_idx][0]):
                continue
            num_valid_class += 1
            for iter_i in range(num_iter):
                curr_poses_gt = all_poses_gt[cls_idx][0]  # gt shared across iterations
                num = len(curr_poses_gt)
                if config.TEST.AVERAGE_ITERS and config.TEST.test_iter >= 4:
                    if iter_i == num_iter - 2:
                        # pseudo-iteration: average of the last 2 real iterations
                        curr_poses_est = [
                            0.5 * (all_poses_est[cls_idx][iter_i - 1][j] +
                                   all_poses_est[cls_idx][iter_i - 2][j])
                            for j in range(num)
                        ]
                    elif iter_i == num_iter - 1:
                        # pseudo-iteration: average of the last 4 real iterations
                        curr_poses_est = [
                            0.25 * (all_poses_est[cls_idx][iter_i - 2][j] +
                                    all_poses_est[cls_idx][iter_i - 3][j] +
                                    all_poses_est[cls_idx][iter_i - 4][j] +
                                    all_poses_est[cls_idx][iter_i - 5][j])
                            for j in range(num)
                        ]
                    else:
                        curr_poses_est = all_poses_est[cls_idx][iter_i]
                else:
                    curr_poses_est = all_poses_est[cls_idx][iter_i]

                cur_rot_rst = np.zeros((num, 1))
                cur_trans_rst = np.zeros((num, 1))

                for j in range(num):
                    r_dist_est, t_dist_est = calc_rt_dist_m(
                        curr_poses_est[j], curr_poses_gt[j])
                    # eggbox is treated as symmetric under a 180-degree rotation
                    # about z: when the raw error exceeds 90, re-score against
                    # the flipped estimate and keep that error instead
                    if cls_name == 'eggbox' and r_dist_est > 90:
                        # RT_z = diag(-1, -1, 1) with zero translation: pi about z
                        RT_z = np.array([[-1, 0, 0, 0], [0, -1, 0, 0],
                                         [0, 0, 1, 0]])
                        curr_pose_est_sym = se3_mul(curr_poses_est[j], RT_z)
                        r_dist_est, t_dist_est = calc_rt_dist_m(
                            curr_pose_est_sym, curr_poses_gt[j])
                        print('eggbox r_dist_est after symmetry: {}'.format(
                            r_dist_est))
                    cur_rot_rst[j, 0] = r_dist_est
                    cur_trans_rst[j, 0] = t_dist_est

                # cur_rot_rst = np.vstack(all_rot_err[cls_idx, iter_i])
                # cur_trans_rst = np.vstack(all_trans_err[cls_idx, iter_i])
                for thresh_idx in range(num_metric):
                    rot_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        cur_rot_rst < rot_thresh_list[thresh_idx])
                    trans_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        cur_trans_rst < trans_thresh_list[thresh_idx])
                    # "space" accuracy: rotation AND translation both under threshold
                    space_acc[cls_idx, iter_i, thresh_idx] = np.mean(
                        np.logical_and(
                            cur_rot_rst < rot_thresh_list[thresh_idx],
                            cur_trans_rst < trans_thresh_list[thresh_idx]))

            # per-class report: mean over all thresholds (shown as [-1, -1]),
            # then the selected threshold indices in show_list
            show_list = [1, 4, 9]
            print_and_log("------------ {} -----------".format(cls_name),
                          logger)
            print_and_log(
                "{:>24}: {:>7}, {:>7}, {:>7}".format(
                    "[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"),
                logger)
            for iter_i in range(num_iter):
                print_and_log("** iter {} **".format(iter_i + 1), logger)
                print_and_log(
                    "{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                        'average_accuracy', '[{:>2}, {:>4}]'.format(-1, -1),
                        np.mean(rot_acc[cls_idx, iter_i, :]) * 100,
                        np.mean(trans_acc[cls_idx, iter_i, :]) * 100,
                        np.mean(space_acc[cls_idx, iter_i, :]) * 100), logger)
                for i, show_idx in enumerate(show_list):
                    print_and_log(
                        "{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                            'average_accuracy', '[{:>2}, {:>4}]'.format(
                                rot_thresh_list[show_idx],
                                trans_thresh_list[show_idx]),
                            rot_acc[cls_idx, iter_i, show_idx] * 100,
                            trans_acc[cls_idx, iter_i, show_idx] * 100,
                            space_acc[cls_idx, iter_i, show_idx] * 100),
                        logger)
        print(' ')
        # overall performance: same report averaged over all valid classes
        for iter_i in range(num_iter):
            show_list = [1, 4, 9]
            print_and_log(
                "---------- performance over {} classes -----------".format(
                    num_valid_class), logger)
            print_and_log("** iter {} **".format(iter_i + 1), logger)
            print_and_log(
                "{:>24}: {:>7}, {:>7}, {:>7}".format(
                    "[rot_thresh, trans_thresh", "RotAcc", "TraAcc", "SpcAcc"),
                logger)
            print_and_log(
                "{:<16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                    'average_accuracy', '[{:>2}, {:>4}]'.format(-1, -1),
                    np.sum(rot_acc[:, iter_i, :]) /
                    (num_valid_class * num_metric) * 100,
                    np.sum(trans_acc[:, iter_i, :]) /
                    (num_valid_class * num_metric) * 100,
                    np.sum(space_acc[:, iter_i, :]) /
                    (num_valid_class * num_metric) * 100), logger)
            for i, show_idx in enumerate(show_list):
                print_and_log(
                    "{:>16}{:>8}: {:>7.2f}, {:>7.2f}, {:>7.2f}".format(
                        'average_accuracy',
                        '[{:>2}, {:>4}]'.format(rot_thresh_list[show_idx],
                                                trans_thresh_list[show_idx]),
                        np.sum(rot_acc[:, iter_i, show_idx]) /
                        num_valid_class * 100,
                        np.sum(trans_acc[:, iter_i, show_idx]) /
                        num_valid_class * 100,
                        np.sum(space_acc[:, iter_i, show_idx]) /
                        num_valid_class * 100), logger)
            print(' ')
Example #6
0
        src_quat = euler2quat(src_euler[0], src_euler[1],
                              src_euler[2]).reshape(1, -1)
        src_trans = src_pose_m[:, 3]
        pose_observed.append((np.hstack((src_quat, src_trans.reshape(1, 3)))))

        for rendered_idx in range(num_rendered_per_observed):
            tgt_euler = src_euler + np.random.normal(0, angle_std / 180 * pi,
                                                     3)
            x_error = np.random.normal(0, x_std, 1)[0]
            y_error = np.random.normal(0, y_std, 1)[0]
            z_error = np.random.normal(0, z_std, 1)[0]
            tgt_trans = src_trans + np.array([x_error, y_error, z_error])
            tgt_pose_m = np.hstack(
                (euler2mat(tgt_euler[0], tgt_euler[1],
                           tgt_euler[2]), tgt_trans.reshape((3, 1))))
            r_dist, t_dist = calc_rt_dist_m(tgt_pose_m, src_pose_m)
            transform = np.matmul(K, tgt_trans.reshape(3, 1))
            center_x = transform[0] / transform[2]
            center_y = transform[1] / transform[2]
            count = 0
            while r_dist > angle_max or not (16 < center_x <
                                             (640 - 16) and 16 < center_y <
                                             (480 - 16)):
                tgt_euler = src_euler + np.random.normal(
                    0, angle_std / 180 * pi, 3)
                x_error = np.random.normal(0, x_std, 1)[0]
                y_error = np.random.normal(0, y_std, 1)[0]
                z_error = np.random.normal(0, z_std, 1)[0]
                tgt_trans = src_trans + np.array([x_error, y_error, z_error])
                tgt_pose_m = np.hstack(
                    (euler2mat(tgt_euler[0], tgt_euler[1],
Example #7
0
def pred_eval(config,
              predictor,
              test_data,
              imdb_test,
              vis=False,
              ignore_cache=None,
              logger=None,
              pairdb=None):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb_test: image database
    :param vis: controls visualization
    :param ignore_cache: ignore the saved cache file
    :param logger: the logger instance
    :return:
    """
    logger.info(imdb_test.result_path)
    logger.info("test iter size: {}".format(config.TEST.test_iter))
    pose_err_file = os.path.join(
        imdb_test.result_path,
        imdb_test.name + "_pose_iter{}.pkl".format(config.TEST.test_iter))
    if os.path.exists(pose_err_file) and not ignore_cache and not vis:
        with open(pose_err_file, "rb") as fid:
            if six.PY3:
                [all_rot_err, all_trans_err, all_poses_est,
                 all_poses_gt] = cPickle.load(fid, encoding="latin1")
            else:
                [all_rot_err, all_trans_err, all_poses_est,
                 all_poses_gt] = cPickle.load(fid)
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
        pose_add_plots_dir = os.path.join(imdb_test.result_path, "add_plots")
        mkdir_p(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            "arp_2d_plots")
        mkdir_p(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir)
        return

    assert vis or not test_data.shuffle
    assert config.TEST.BATCH_PAIRS == 1
    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    num_pairs = len(pairdb)
    height = 480
    width = 640

    data_time, net_time, post_time = 0.0, 0.0, 0.0

    sum_EPE_all = 0.0
    num_inst_all = 0.0
    sum_EPE_viz = 0.0
    num_inst_viz = 0.0
    sum_EPE_vizbg = 0.0
    num_inst_vizbg = 0.0
    sum_PoseErr = [
        np.zeros((len(imdb_test.classes) + 1, 2))
        for batch_idx in range(config.TEST.test_iter)
    ]

    all_rot_err = [[[] for j in range(config.TEST.test_iter)]
                   for batch_idx in range(len(imdb_test.classes))
                   ]  # num_cls x test_iter
    all_trans_err = [[[] for j in range(config.TEST.test_iter)]
                     for batch_idx in range(len(imdb_test.classes))]

    all_poses_est = [[[] for j in range(config.TEST.test_iter)]
                     for batch_idx in range(len(imdb_test.classes))]
    all_poses_gt = [[[] for j in range(config.TEST.test_iter)]
                    for batch_idx in range(len(imdb_test.classes))]

    num_inst = np.zeros(len(imdb_test.classes) + 1)

    K = config.dataset.INTRINSIC_MATRIX
    if (config.TEST.test_iter > 1 or config.TEST.VISUALIZE) and True:
        print(
            "************* start setup render_glumpy environment... ******************"
        )
        if config.dataset.dataset.startswith("ModelNet"):
            from lib.render_glumpy.render_py_light_modelnet_multi import Render_Py_Light_ModelNet_Multi

            modelnet_root = config.modelnet_root
            texture_path = os.path.join(modelnet_root, "gray_texture.png")

            model_path_list = [
                os.path.join(config.dataset.model_dir,
                             "{}.obj".format(model_name))
                for model_name in config.dataset.class_name
            ]
            render_machine = Render_Py_Light_ModelNet_Multi(
                model_path_list,
                texture_path,
                K,
                width,
                height,
                config.dataset.ZNEAR,
                config.dataset.ZFAR,
                brightness_ratios=[0.7],
            )
        else:
            render_machine = Render_Py(
                config.dataset.model_dir,
                config.dataset.class_name,
                K,
                width,
                height,
                config.dataset.ZNEAR,
                config.dataset.ZFAR,
            )

        def render(render_machine, pose, cls_idx, K=None):
            if config.dataset.dataset.startswith("ModelNet"):
                idx = 2
                # generate random light_position
                if idx % 6 == 0:
                    light_position = [1, 0, 1]
                elif idx % 6 == 1:
                    light_position = [1, 1, 1]
                elif idx % 6 == 2:
                    light_position = [0, 1, 1]
                elif idx % 6 == 3:
                    light_position = [-1, 1, 1]
                elif idx % 6 == 4:
                    light_position = [-1, 0, 1]
                elif idx % 6 == 5:
                    light_position = [0, 0, 1]
                else:
                    raise Exception("???")
                light_position = np.array(light_position) * 0.5
                # inverse yz
                light_position[0] += pose[0, 3]
                light_position[1] -= pose[1, 3]
                light_position[2] -= pose[2, 3]

                colors = np.array([1, 1, 1])  # white light
                intensity = np.random.uniform(0.9, 1.1, size=(3, ))
                colors_randk = 0
                light_intensity = colors[colors_randk] * intensity

                # randomly choose a render machine
                rm_randk = 0  # random.randint(0, len(brightness_ratios) - 1)
                rgb_gl, depth_gl = render_machine.render(
                    cls_idx,
                    pose[:3, :3],
                    pose[:3, 3],
                    light_position,
                    light_intensity,
                    brightness_k=rm_randk,
                    r_type="mat",
                )
                rgb_gl = rgb_gl.astype("uint8")
            else:
                rgb_gl, depth_gl = render_machine.render(cls_idx,
                                                         pose[:3, :3],
                                                         pose[:, 3],
                                                         r_type="mat",
                                                         K=K)
                rgb_gl = rgb_gl.astype("uint8")
            return rgb_gl, depth_gl

        print(
            "***************setup render_glumpy environment succeed ******************"
        )

    # ---- Evaluation path A: poses already refined offline by ICP. ----
    # Loads each pair's "-pose_icp.txt" result, measures rotation/translation
    # error against the observed (ground-truth) pose, runs the standard pose
    # evaluations, and returns early.
    if config.TEST.PRECOMPUTED_ICP:
        print("precomputed_ICP")
        config.TEST.test_iter = 1  # single (precomputed) iteration
        # Per-class accumulators, each with one inner list for the single iteration.
        all_rot_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        all_trans_err = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]

        all_poses_est = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]
        all_poses_gt = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]

        xy_trans_err = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]
        z_trans_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        for idx in range(len(pairdb)):
            # Strips a 10-char suffix from the rendered-depth path to locate the
            # ICP pose file (presumably "-depth.png" — TODO confirm).
            pose_path = pairdb[idx]["depth_rendered"][:-10] + "-pose_icp.txt"
            pose_rendered_update = np.loadtxt(pose_path, skiprows=1)
            pose_observed = pairdb[idx]["pose_observed"]
            r_dist_est, t_dist_est = calc_rt_dist_m(pose_rendered_update,
                                                    pose_observed)
            # Translation error split into image-plane (xy) and depth (z) parts.
            xy_dist = np.linalg.norm(pose_rendered_update[:2, -1] -
                                     pose_observed[:2, -1])
            z_dist = np.linalg.norm(pose_rendered_update[-1, -1] -
                                    pose_observed[-1, -1])
            print(
                "{}: r_dist_est: {}, t_dist_est: {}, xy_dist: {}, z_dist: {}".
                format(idx, r_dist_est, t_dist_est, xy_dist, z_dist))
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            # store poses estimation and gt
            all_poses_est[class_id][0].append(pose_rendered_update)
            all_poses_gt[class_id][0].append(pairdb[idx]["pose_observed"])
            all_rot_err[class_id][0].append(r_dist_est)
            all_trans_err[class_id][0].append(t_dist_est)
            xy_trans_err[class_id][0].append(xy_dist)
            z_trans_err[class_id][0].append(z_dist)
        all_rot_err = np.array(all_rot_err)
        all_trans_err = np.array(all_trans_err)
        # NOTE(review): the summaries below index with class_id left over from
        # the *last* loop iteration, so they only describe one class — fine for
        # a single-class test set, misleading otherwise.
        print("rot = {} +/- {}".format(np.mean(all_rot_err[class_id][0]),
                                       np.std(all_rot_err[class_id][0])))
        print("trans = {} +/- {}".format(np.mean(all_trans_err[class_id][0]),
                                         np.std(all_trans_err[class_id][0])))
        # Errors reported in centimeters (x100), mean +/- std.
        num_list = all_trans_err[class_id][0]
        print("xyz: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))
        num_list = xy_trans_err[class_id][0]
        print("xy: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))
        num_list = z_trans_err[class_id][0]
        print("z: {:.2f} +/- {:.2f}".format(
            np.mean(num_list) * 100,
            np.std(num_list) * 100))

        # Standard pose metrics, plus ADD and ARP-2D curves written to disk.
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
        pose_add_plots_dir = os.path.join(imdb_test.result_path,
                                          "add_plots_precomputed_ICP")
        mkdir_p(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            "arp_2d_plots_precomputed_ICP")
        mkdir_p(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir)
        return

    # ---- Evaluation path B: evaluate the initial poses, before any ICP. ----
    # Mirrors the PRECOMPUTED_ICP branch above but reads "-pose.txt" (the
    # un-refined pose) and skips the per-pair printing and summary stats.
    if config.TEST.BEFORE_ICP:
        print("before_ICP")
        config.TEST.test_iter = 1  # single iteration
        # Per-class accumulators (same layout as the ICP branch).
        all_rot_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        all_trans_err = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]

        all_poses_est = [[[] for j in range(1)]
                         for batch_idx in range(len(imdb_test.classes))]
        all_poses_gt = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]

        xy_trans_err = [[[] for j in range(1)]
                        for batch_idx in range(len(imdb_test.classes))]
        z_trans_err = [[[] for j in range(1)]
                       for batch_idx in range(len(imdb_test.classes))]
        for idx in range(len(pairdb)):
            # Strip the 10-char depth suffix, load the pre-ICP pose file.
            pose_path = pairdb[idx]["depth_rendered"][:-10] + "-pose.txt"
            pose_rendered_update = np.loadtxt(pose_path, skiprows=1)
            pose_observed = pairdb[idx]["pose_observed"]
            r_dist_est, t_dist_est = calc_rt_dist_m(pose_rendered_update,
                                                    pose_observed)
            # Image-plane (xy) and depth (z) translation error components.
            xy_dist = np.linalg.norm(pose_rendered_update[:2, -1] -
                                     pose_observed[:2, -1])
            z_dist = np.linalg.norm(pose_rendered_update[-1, -1] -
                                    pose_observed[-1, -1])
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            # store poses estimation and gt
            all_poses_est[class_id][0].append(pose_rendered_update)
            all_poses_gt[class_id][0].append(pairdb[idx]["pose_observed"])
            all_rot_err[class_id][0].append(r_dist_est)
            all_trans_err[class_id][0].append(t_dist_est)
            xy_trans_err[class_id][0].append(xy_dist)
            z_trans_err[class_id][0].append(z_dist)

        all_trans_err = np.array(all_trans_err)
        # Pose metrics plus ADD / ARP-2D plots for the un-refined poses.
        imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
        pose_add_plots_dir = os.path.join(imdb_test.result_path,
                                          "add_plots_before_ICP")
        mkdir_p(pose_add_plots_dir)
        imdb_test.evaluate_pose_add(config,
                                    all_poses_est,
                                    all_poses_gt,
                                    output_dir=pose_add_plots_dir)
        pose_arp2d_plots_dir = os.path.join(imdb_test.result_path,
                                            "arp_2d_plots_before_ICP")
        mkdir_p(pose_arp2d_plots_dir)
        imdb_test.evaluate_pose_arp_2d(config,
                                       all_poses_est,
                                       all_poses_gt,
                                       output_dir=pose_arp2d_plots_dir)
        return

    # ------------------------------------------------------------------------------
    # ---- Main path: run the network and iteratively refine each pose. ----
    # data_time / net_time / post_time (defined before this chunk) accumulate
    # wall-clock spent in data prep, network forward, and post-processing.
    t_start = time.time()
    t = time.time()
    for idx, data_batch in enumerate(test_data):
        # A pose summing to -12 is the sentinel for "no valid point in the
        # initial pose": record a fixed large error and skip the pair.
        if np.sum(pairdb[idx]
                  ["pose_rendered"]) == -12:  # NO POINT VALID IN INIT POSE
            print(idx)
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            for pose_iter_idx in range(config.TEST.test_iter):
                all_poses_est[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_rendered"])
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_observed"])

                # Sentinel error values for the degenerate pose.
                r_dist = 1000
                t_dist = 1000
                all_rot_err[class_id][pose_iter_idx].append(r_dist)
                all_trans_err[class_id][pose_iter_idx].append(t_dist)
                sum_PoseErr[pose_iter_idx][class_id, :] += np.array(
                    [r_dist, t_dist])
                sum_PoseErr[pose_iter_idx][-1, :] += np.array([r_dist, t_dist])
                # post process
            # Periodic progress logging (per-sample average times).
            if idx % 50 == 0:
                logger.info(
                    "testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s".
                    format(
                        (idx + 1),
                        num_pairs,
                        data_time / ((idx + 1) * test_data.batch_size),
                        net_time / ((idx + 1) * test_data.batch_size),
                        post_time / ((idx + 1) * test_data.batch_size),
                    ))
            print("in test: NO POINT_VALID IN rendered")
            continue
        data_time += time.time() - t

        t = time.time()

        pose_rendered = pairdb[idx]["pose_rendered"]
        # NOTE(review): this repeats the == -12 check above, which already
        # `continue`d, so this branch is unreachable; it also differs from the
        # first branch (increments num_inst, records no errors) — the two
        # should probably be unified.
        if np.sum(pose_rendered) == -12:
            print(idx)
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            num_inst[class_id] += 1
            num_inst[-1] += 1
            for pose_iter_idx in range(config.TEST.test_iter):
                all_poses_est[class_id][pose_iter_idx].append(pose_rendered)
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_observed"])

            # post process
            if idx % 50 == 0:
                logger.info(
                    "testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s".
                    format(
                        (idx + 1),
                        num_pairs,
                        data_time / ((idx + 1) * test_data.batch_size),
                        net_time / ((idx + 1) * test_data.batch_size),
                        post_time / ((idx + 1) * test_data.batch_size),
                    ))

            t = time.time()
            continue

        # First network forward pass for this batch.
        output_all = predictor.predict(data_batch)
        net_time += time.time() - t

        t = time.time()
        # Collect per-output results: predicted se3, optional flow and mask.
        rst_iter = []
        for output in output_all:
            cur_rst = {}
            cur_rst["se3"] = np.squeeze(
                output["se3_output"].asnumpy()).astype("float32")

            # Flow is only materialized outside FAST_TEST mode (HWC+batch order).
            if not config.TEST.FAST_TEST and config.network.PRED_FLOW:
                cur_rst["flow"] = np.squeeze(
                    output["flow_est_crop_output"].asnumpy().transpose(
                        (2, 3, 1, 0))).astype("float16")
            else:
                cur_rst["flow"] = None
            # Predicted mask is used only when the mask update source is the
            # network prediction (not "init"/"box_rendered").
            if config.network.PRED_MASK and config.TEST.UPDATE_MASK not in [
                    "init", "box_rendered"
            ]:
                mask_pred = np.squeeze(
                    output["mask_observed_pred_output"].asnumpy()).astype(
                        "float32")
                cur_rst["mask_pred"] = mask_pred

            rst_iter.append(cur_rst)

        post_time += time.time() - t
        # sample_ratio = 1  # 0.01
        for batch_idx in range(0, test_data.batch_size):
            # if config.TEST.VISUALIZE and not (r_dist>15 and t_dist>0.05):
            #     continue # 3388, 5326
            # calculate the flow error --------------------------------------------
            t = time.time()
            if config.network.PRED_FLOW and not config.TEST.FAST_TEST:
                # evaluate optical flow
                flow_gt = par_generate_gt(config, pairdb[idx])
                # NOTE(review): this inner PRED_FLOW check is redundant — the
                # enclosing condition already guarantees it.
                if config.network.PRED_FLOW:
                    all_diff = calc_EPE_one_pair(rst_iter[batch_idx], flow_gt,
                                                 "flow")
                # Accumulate end-point-error over all / visible / visible+bg pixels.
                sum_EPE_all += all_diff["epe_all"]
                num_inst_all += all_diff["num_all"]
                sum_EPE_viz += all_diff["epe_viz"]
                num_inst_viz += all_diff["num_viz"]
                sum_EPE_vizbg += all_diff["epe_vizbg"]
                num_inst_vizbg += all_diff["num_vizbg"]

            # calculate the se3 error ---------------------------------------------
            # evaluate se3 estimation
            pose_rendered = pairdb[idx]["pose_rendered"]
            class_id = imdb_test.classes.index(pairdb[idx]["gt_class"])
            num_inst[class_id] += 1
            num_inst[-1] += 1
            post_time += time.time() - t
            # iterative refine se3 estimation --------------------------------------------------
            # Each iteration applies the predicted se3 delta to the current
            # rendered pose, measures the error, and (if more iterations or
            # visualization follow) re-renders and re-runs the network.
            for pose_iter_idx in range(config.TEST.test_iter):
                t = time.time()
                # NOTE(review): uses rst_iter[0] while the flow evaluation above
                # indexes rst_iter[batch_idx] — only equivalent when
                # batch_size == 1; confirm before running with larger batches.
                pose_rendered_update = RT_transform(
                    pose_rendered,
                    rst_iter[0]["se3"][:-3],
                    rst_iter[0]["se3"][-3:],
                    config.dataset.trans_means,
                    config.dataset.trans_stds,
                    config.network.ROT_COORD,
                )

                # calculate error
                r_dist, t_dist = calc_rt_dist_m(pose_rendered_update,
                                                pairdb[idx]["pose_observed"])

                # store poses estimation and gt
                all_poses_est[class_id][pose_iter_idx].append(
                    pose_rendered_update)
                all_poses_gt[class_id][pose_iter_idx].append(
                    pairdb[idx]["pose_observed"])

                all_rot_err[class_id][pose_iter_idx].append(r_dist)
                all_trans_err[class_id][pose_iter_idx].append(t_dist)
                sum_PoseErr[pose_iter_idx][class_id, :] += np.array(
                    [r_dist, t_dist])
                sum_PoseErr[pose_iter_idx][-1, :] += np.array([r_dist, t_dist])
                if config.TEST.VISUALIZE:
                    print("idx {}, iter {}: rError: {}, tError: {}".format(
                        idx + batch_idx, pose_iter_idx + 1, r_dist, t_dist))

                post_time += time.time() - t

                # # if more than one iteration
                if pose_iter_idx < (config.TEST.test_iter -
                                    1) or config.TEST.VISUALIZE:
                    t = time.time()
                    # get refined image
                    # Per-pair camera intrinsics, if present next to the image
                    # (10-char image suffix stripped).
                    K_path = pairdb[idx]["image_observed"][:-10] + "-K.txt"
                    if os.path.exists(K_path):
                        K = np.loadtxt(K_path)
                    image_refined, depth_refined = render(
                        render_machine,
                        pose_rendered_update,
                        config.dataset.class_name.index(
                            pairdb[idx]["gt_class"]),
                        K=K,
                    )
                    image_refined = image_refined[:, :, :3]

                    # update minibatch
                    update_package = [{
                        "image_rendered": image_refined,
                        "src_pose": pose_rendered_update
                    }]
                    if config.network.INPUT_DEPTH:
                        update_package[0]["depth_rendered"] = depth_refined
                    if config.network.INPUT_MASK:
                        # Binarize the refined depth into a rendered mask
                        # (0.2 threshold separates object from background).
                        mask_rendered_refined = np.zeros(depth_refined.shape)
                        mask_rendered_refined[depth_refined > 0.2] = 1
                        update_package[0][
                            "mask_rendered"] = mask_rendered_refined
                        if config.network.PRED_MASK:
                            # init, box_rendered, mask_rendered, box_observed, mask_observed
                            if config.TEST.UPDATE_MASK == "box_rendered":
                                input_names = [
                                    blob_name[0]
                                    for blob_name in data_batch.provide_data[0]
                                ]
                                update_package[0][
                                    "mask_observed"] = np.squeeze(
                                        data_batch.data[0][input_names.index(
                                            "mask_rendered")].asnumpy()
                                        [batch_idx])  # noqa
                            elif config.TEST.UPDATE_MASK == "init":
                                pass
                            else:
                                # BUGFIX: report config.TEST.UPDATE_MASK — the
                                # value actually branched on above; the old code
                                # read config.network.UPDATE_MASK here.
                                raise Exception(
                                    "Unknown UPDATE_MASK type: {}".format(
                                        config.TEST.UPDATE_MASK))

                    pose_rendered = pose_rendered_update
                    data_batch = update_data_batch(config, data_batch,
                                                   update_package)

                    data_time += time.time() - t

                    # forward and get rst
                    # Re-run the network on the updated batch for the next
                    # refinement iteration (skipped after the final one).
                    if pose_iter_idx < config.TEST.test_iter - 1:
                        t = time.time()
                        output_all = predictor.predict(data_batch)
                        net_time += time.time() - t

                        t = time.time()
                        rst_iter = []
                        for output in output_all:
                            cur_rst = {}
                            if config.network.REGRESSOR_NUM == 1:
                                cur_rst["se3"] = np.squeeze(
                                    output["se3_output"].asnumpy()).astype(
                                        "float32")

                            if not config.TEST.FAST_TEST and config.network.PRED_FLOW:
                                cur_rst["flow"] = np.squeeze(
                                    output["flow_est_crop_output"].asnumpy().
                                    transpose((2, 3, 1, 0))).astype("float16")
                            else:
                                cur_rst["flow"] = None

                            if config.network.PRED_MASK and config.TEST.UPDATE_MASK not in [
                                    "init", "box_rendered"
                            ]:
                                mask_pred = np.squeeze(
                                    output["mask_observed_pred_output"].
                                    asnumpy()).astype("float32")
                                cur_rst["mask_pred"] = mask_pred

                            rst_iter.append(cur_rst)
                            post_time += time.time() - t
        # post process
        # Periodic progress logging with per-sample average timings.
        if idx % 50 == 0:
            logger.info(
                "testing {}/{} data {:.4f}s net {:.4f}s calc_gt {:.4f}s".
                format(
                    (idx + 1),
                    num_pairs,
                    data_time / ((idx + 1) * test_data.batch_size),
                    net_time / ((idx + 1) * test_data.batch_size),
                    post_time / ((idx + 1) * test_data.batch_size),
                ))

        t = time.time()

    all_rot_err = np.array(all_rot_err)
    all_trans_err = np.array(all_trans_err)

    # save inference results
    # Cache errors and pose lists so later runs can skip inference
    # (protocol=2 keeps the pickle readable from Python 2).
    if not config.TEST.VISUALIZE:
        with open(pose_err_file, "wb") as f:
            logger.info("saving result cache to {}".format(pose_err_file))
            cPickle.dump(
                [all_rot_err, all_trans_err, all_poses_est, all_poses_gt],
                f,
                protocol=2)
            logger.info("done")

    # Report mean end-point-error over all / visible+bg / visible pixels;
    # max(..., 1.0) guards against division by zero when no pixels counted.
    if config.network.PRED_FLOW:
        logger.info("evaluate flow:")
        logger.info("EPE all: {}".format(sum_EPE_all / max(num_inst_all, 1.0)))
        logger.info("EPE ignore unvisible: {}".format(
            sum_EPE_vizbg / max(num_inst_vizbg, 1.0)))
        logger.info("EPE visible: {}".format(sum_EPE_viz /
                                             max(num_inst_viz, 1.0)))

    # Final pose metrics plus ADD and ARP-2D plots written under result_path.
    logger.info("evaluate pose:")
    imdb_test.evaluate_pose(config, all_poses_est, all_poses_gt)
    # evaluate pose add
    pose_add_plots_dir = os.path.join(imdb_test.result_path, "add_plots")
    mkdir_p(pose_add_plots_dir)
    imdb_test.evaluate_pose_add(config,
                                all_poses_est,
                                all_poses_gt,
                                output_dir=pose_add_plots_dir)
    pose_arp2d_plots_dir = os.path.join(imdb_test.result_path, "arp_2d_plots")
    mkdir_p(pose_arp2d_plots_dir)
    imdb_test.evaluate_pose_arp_2d(config,
                                   all_poses_est,
                                   all_poses_gt,
                                   output_dir=pose_arp2d_plots_dir)

    logger.info("using {} seconds in total".format(time.time() - t_start))