Ejemplo n.º 1
0
class batchUpdaterPyMulti:
    """Rebuild a multi-context data batch around re-rendered, refined poses.

    For every sample, ``forward`` combines the current ``src_pose`` with the
    network's ``rot_est`` / ``trans_est`` through ``RT_transform.RT_transform``,
    renders the object at the resulting pose, and writes the new rendering and
    the new pose targets (from ``RT_transform.calc_RT_delta``) back into the
    batch's data and label blobs.
    """

    # Light directions available to the ModelNet renderer. The chosen one is
    # scaled by 0.5 and offset by the object translation before rendering.
    _LIGHT_DIRECTIONS = (
        (1, 0, 1),
        (1, 1, 1),
        (0, 1, 1),
        (-1, 1, 1),
        (-1, 0, 1),
        (0, 0, 1),
    )
    # Fixed selector into _LIGHT_DIRECTIONS (taken modulo its length); the
    # original code left a "random.randint(0, 100)" alternative commented out.
    _LIGHT_DIRECTION_IDX = 2
    # Rendered depth strictly above this value is treated as object pixels.
    _MASK_DEPTH_THRESHOLD = 0.2

    def __init__(self, big_cfg, height, width):
        """Cache config values and build the renderer.

        :param big_cfg: project config; read via attribute access here
            (``big_cfg.dataset``, ``big_cfg.network``, ``big_cfg.TRAIN``) and
            via item access in ``get_names``.
        :param height: height of the rendered images.
        :param width: width of the rendered images.
        """
        self.big_cfg = big_cfg
        self.model_dir = big_cfg.dataset.model_dir
        self.rot_coord = big_cfg.network.ROT_COORD
        # Reverse the channel order of the means and shape them (3, 1, 1) so
        # they broadcast over channel-first images.
        self.pixel_means = big_cfg.network.PIXEL_MEANS[[2, 1, 0]]
        self.pixel_means = self.pixel_means.reshape([3, 1,
                                                     1]).astype(np.float32)
        self.K = big_cfg.dataset.INTRINSIC_MATRIX
        self.T_means = big_cfg.dataset.trans_means
        self.T_stds = big_cfg.dataset.trans_stds
        self.height = height
        self.width = width
        self.zNear = big_cfg.dataset.ZNEAR
        self.zFar = big_cfg.dataset.ZFAR

        self.render_machine = None
        if big_cfg.dataset.dataset.startswith("ModelNet"):
            self.modelnet_root = big_cfg.modelnet_root
            self.texture_path = os.path.join(self.modelnet_root,
                                             "gray_texture.png")
            # Imported lazily so the light renderer is only required for
            # ModelNet datasets.
            from lib.render_glumpy.render_py_light_modelnet_multi import (
                Render_Py_Light_ModelNet_Multi, )

            self.model_path_list = [
                os.path.join(self.model_dir, "{}.obj".format(model_name))
                for model_name in big_cfg.dataset.class_name
            ]
            self.render_machine = Render_Py_Light_ModelNet_Multi(
                self.model_path_list,
                self.texture_path,
                self.K,
                self.width,
                self.height,
                self.zNear,
                self.zFar,
                brightness_ratios=[0.7],
            )
        else:
            self.render_machine = Render_Py(
                self.model_dir,
                big_cfg.dataset.class_name,
                self.K,
                self.width,
                self.height,
                self.zNear,
                self.zFar,
            )

        self.reinit = True

        # forward() asserts that every context's batch has exactly this size.
        self.batch_size = big_cfg.TRAIN.BATCH_PAIRS
        self.Kinv = np.linalg.inv(np.matrix(self.K))
        print("build render_machine: ", self.render_machine)

    def get_names(self, big_cfg):
        """Return the ordered names of the network outputs in ``preds``.

        ``forward`` uses the position of a name in this list to index
        ``preds``, so the order must match the order the network emits.

        :param big_cfg: config providing ``network.PRED_FLOW``,
            ``train_iter.SE3_DIST_LOSS`` and ``train_iter.SE3_PM_LOSS``.
        :return: list of output names.
        """
        pred = ["image_observed", "image_rendered"]
        if big_cfg.network.PRED_FLOW:
            pred.append("flow_est_crop")
            pred.append("flow_loss")
        pred.append("rot_est")
        pred.append("rot_gt")
        pred.append("trans_est")
        pred.append("trans_gt")
        if big_cfg.train_iter.SE3_DIST_LOSS:
            pred.append("rot_loss")
            pred.append("trans_loss")
        if big_cfg.train_iter.SE3_PM_LOSS:
            pred.append("point_matching_loss")
        # NOTE(review): this flag is read from the config stored at
        # construction time, not from the ``big_cfg`` argument -- confirm
        # that the two are always the same object.
        if self.big_cfg["network"]["PRED_MASK"]:
            pred.append("zoom_mask_prob")
            pred.append("zoom_mask_gt_observed")
            pred.append("mask_pred")  # unzoomed

        return pred

    @classmethod
    def _depth_to_mask(cls, depth_array):
        """Return a float array shaped like ``depth_array`` with 1 where depth exceeds the threshold."""
        mask = np.zeros(depth_array.shape)
        mask[depth_array > cls._MASK_DEPTH_THRESHOLD] = 1
        return mask

    def _render_pose(self, class_id, pose, use_light_renderer):
        """Render one object at ``pose`` and return ``(image, depth)``.

        :param class_id: class index forwarded to the renderer.
        :param pose: pose array; ``pose[:3, :3]`` is passed as rotation and
            ``pose[:3, 3]`` as translation.
        :param use_light_renderer: True to call the renderer with the extra
            light arguments (the ModelNet renderer built in ``__init__``).
        """
        rotation = pose[:3, :3]
        translation = pose[:3, 3]
        if not use_light_renderer:
            return self.render_machine.render(
                class_id,
                rotation,
                translation,
                r_type="mat",
            )

        direction = self._LIGHT_DIRECTIONS[self._LIGHT_DIRECTION_IDX %
                                           len(self._LIGHT_DIRECTIONS)]
        light_position = np.array(direction) * 0.5
        # inverse yz
        light_position[0] += pose[0, 3]
        light_position[1] -= pose[1, 3]
        light_position[2] -= pose[2, 3]

        colors = np.array([1, 1, 1])  # white light
        intensity = np.random.uniform(0.9, 1.1, size=(3, ))
        light_intensity = colors[0] * intensity

        return self.render_machine.render(
            class_id,
            rotation,
            translation,
            light_position,
            light_intensity,
            brightness_k=0,  # only one brightness ratio is configured
            r_type="mat",
        )

    def forward(self, data_batch, preds, big_cfg):
        """Re-render every sample at its refined pose and return an updated batch.

        :param data_batch: batch whose ``data`` / ``label`` are per-context
            lists of blobs, named by ``provide_data[0]`` / ``provide_label[0]``.
            The data must contain ``src_pose``, ``tgt_pose`` and
            ``class_index``, plus ``depth_gt_observed`` when
            ``network.PRED_FLOW`` is set.
        :param preds: network outputs indexed as ``preds[name_idx][ctx_idx]``
            in the order given by ``get_names``; ``rot_est`` and ``trans_est``
            are read.
        :param big_cfg: config forwarded to ``get_names``.
        :return: a new ``mx.io.DataBatch``. Any data or label blob named
            ``image_rendered``, ``depth_rendered``, ``src_pose``, ``rot`` or
            ``trans`` is replaced; ``flow`` / ``flow_weights`` are replaced
            when ``PRED_FLOW`` is set and ``mask_rendered`` when
            ``INPUT_MASK`` is set. The per-context lists of ``data_batch`` are
            modified in place as well.
        """
        data_array = data_batch.data
        label_array = data_batch.label
        num_ctx = len(data_array)
        pred_names = self.get_names(big_cfg)
        data_names = [x[0] for x in data_batch.provide_data[0]]
        label_names = [x[0] for x in data_batch.provide_label[0]]

        network_cfg = self.big_cfg.network
        pred_flow = network_cfg.PRED_FLOW
        input_mask = network_cfg.INPUT_MASK
        is_modelnet = self.big_cfg.dataset.dataset.startswith("ModelNet")

        # Copy everything needed to host memory up front, one blob per context.
        src_pose_idx = data_names.index("src_pose")
        src_pose_all = [
            data_array[ctx_i][src_pose_idx].asnumpy()
            for ctx_i in range(num_ctx)
        ]
        tgt_pose_idx = data_names.index("tgt_pose")
        tgt_pose_all = [
            data_array[ctx_i][tgt_pose_idx].asnumpy()
            for ctx_i in range(num_ctx)
        ]
        class_index_idx = data_names.index("class_index")
        class_index_all = [
            data_array[ctx_i][class_index_idx].asnumpy()
            for ctx_i in range(num_ctx)
        ]
        rot_est_blobs = preds[pred_names.index("rot_est")]
        rot_est_all = [rot_est_blobs[ctx_i].asnumpy() for ctx_i in range(num_ctx)]
        trans_est_blobs = preds[pred_names.index("trans_est")]
        trans_est_all = [
            trans_est_blobs[ctx_i].asnumpy() for ctx_i in range(num_ctx)
        ]
        if pred_flow:
            depth_gt_idx = data_names.index("depth_gt_observed")
            depth_gt_observed_all = [
                data_array[ctx_i][depth_gt_idx].asnumpy()
                for ctx_i in range(num_ctx)
            ]

        for ctx_i in range(num_ctx):
            batch_size = data_array[ctx_i][0].shape[0]
            assert batch_size == self.batch_size, "{} vs. {}".format(
                batch_size, self.batch_size)

            src_pose = src_pose_all[ctx_i]
            tgt_pose = tgt_pose_all[ctx_i]
            class_index = class_index_all[ctx_i]
            rot_est = rot_est_all[ctx_i]
            trans_est = trans_est_all[ctx_i]

            refined_image_array = np.zeros(
                (batch_size, 3, self.height, self.width))
            refined_depth_array = np.zeros(
                (batch_size, 1, self.height, self.width))
            rot_res_array = np.zeros((batch_size, 4))
            trans_res_array = np.zeros((batch_size, 3))
            refined_pose_array = np.zeros((batch_size, 3, 4))
            # K times the [R|t] returned by calc_se3; only the flow branch
            # below consumes it.
            KT_array = np.zeros((batch_size, 3, 4))

            for batch_idx in range(batch_size):
                target_pose = np.squeeze(tgt_pose[batch_idx])
                refined_pose = RT_transform.RT_transform(
                    np.squeeze(src_pose[batch_idx]),
                    np.squeeze(rot_est[batch_idx]),
                    np.squeeze(trans_est[batch_idx]),
                    self.T_means,
                    self.T_stds,
                    rot_coord=self.rot_coord,
                )

                refined_image, refined_depth = self._render_pose(
                    class_index[batch_idx].astype("int"),
                    refined_pose,
                    is_modelnet,
                )

                # Reverse the channel order, move channels first, and subtract
                # the (equally reordered) pixel means.
                refined_image = (refined_image[:, :, [2, 1, 0]].transpose(
                    [2, 0, 1]).astype(np.float32))
                refined_image -= self.pixel_means

                # New rot/trans targets between the refined and target pose.
                rot_res, trans_res = RT_transform.calc_RT_delta(
                    refined_pose,
                    target_pose,
                    self.T_means,
                    self.T_stds,
                    rot_coord=self.rot_coord,
                    rot_type="QUAT",
                )

                refined_pose_array[batch_idx] = refined_pose
                refined_image_array[batch_idx] = refined_image
                refined_depth_array[batch_idx] = refined_depth.reshape(
                    (1, self.height, self.width))
                rot_res_array[batch_idx] = rot_res
                trans_res_array[batch_idx] = trans_res

                se3_m = np.zeros([3, 4])
                se3_rotm, se3_t = RT_transform.calc_se3(
                    refined_pose, target_pose)
                se3_m[:, :3] = se3_rotm
                se3_m[:, 3] = se3_t
                KT_array[batch_idx] = np.dot(self.K, se3_m)

            update_package = {
                "image_rendered": refined_image_array,
                "depth_rendered": refined_depth_array,
                "src_pose": refined_pose_array,
                "rot": rot_res_array,
                "trans": trans_res_array,
            }
            if pred_flow:
                cur_ctx = data_array[ctx_i][0].context
                gpu_flow_machine = gpu_flow_wrapper(cur_ctx.device_id)
                refined_flow, refined_flow_valid = gpu_flow_machine(
                    refined_depth_array.astype(np.float32),
                    depth_gt_observed_all[ctx_i].astype(np.float32),
                    KT_array.astype(np.float32),
                    np.array(self.Kinv).astype(np.float32),
                )
                update_package["flow"] = refined_flow
                # Repeat the validity map along axis 1 to get the weights.
                update_package["flow_weights"] = np.tile(
                    refined_flow_valid, [1, 2, 1, 1])
            if input_mask:
                # The mask used to be built only when PRED_MASK was set, which
                # raised NameError here for INPUT_MASK without PRED_MASK.
                update_package["mask_rendered"] = self._depth_to_mask(
                    refined_depth_array)

            data_array[ctx_i] = self.update_data_batch(data_array[ctx_i],
                                                       data_names,
                                                       update_package)
            label_array[ctx_i] = self.update_data_batch(
                label_array[ctx_i], label_names, update_package)

        return mx.io.DataBatch(
            data=data_array,
            label=label_array,
            pad=data_batch.pad,
            index=data_batch.index,
            provide_data=data_batch.provide_data,
            provide_label=data_batch.provide_label,
        )

    def update_data_batch(self, data, data_names, update_package):
        """Replace the blobs of ``data`` that have a new value in ``update_package``.

        :param data: list of blobs for one context; modified in place.
        :param data_names: name of each blob, aligned with ``data``.
        :param update_package: dict mapping blob name to the new array.
        :return: the same ``data`` list, with matching entries replaced by
            ``mxnet.ndarray.array`` copies of the new values.
        """
        import mxnet.ndarray as nd

        for blob_idx, blob_name in enumerate(data_names):
            if blob_name not in update_package:
                continue
            data[blob_idx] = nd.array(update_package[blob_name])
        return data
Ejemplo n.º 2
0
def main():
    """Convert the original LINEMOD sequences into the LM6d "real" layout.

    For every class in ``classes``:

    * copy each ``rgb/<id>.png`` and ``depth/<id>.png`` from
      ``LM6d_origin_root/<NN>/`` to ``LM6d_new_root/<NN>/`` as
      ``{id+1:06d}-color.png`` / ``{id+1:06d}-depth.png``;
    * write ``{id+1:06d}-meta.mat`` with ``cls_indexes``, ``boxes`` and
      ``poses`` taken from ``gt.yml`` (translations divided by 1000);
    * render every annotated instance at its ground-truth pose and composite
      the masks, farthest instance first, into ``{id+1:06d}-label.png``;
    * write the list of converted frames to ``<real_set_dir>/<cls>_all.txt``.

    Relies on module-level configuration (``classes``, ``K``, ``width``,
    ``height``, ``ZNEAR``, ``ZFAR``, the root directories) and helpers defined
    elsewhere in this file.
    """
    sel_classes = classes
    model_dir = os.path.join(cur_dir, '../../data/LINEMOD_6D/LM6d_converted/models')
    render_machine = Render_Py(model_dir, classes, K, width, height, ZNEAR, ZFAR)
    for cls_idx, cls_name in enumerate(classes):
        if cls_name not in sel_classes:
            continue
        print(cls_idx, cls_name)

        # Per-class directories, computed once instead of once per path.
        cls_dir_name = '{:02d}'.format(class2idx(cls_name))
        origin_cls_dir = os.path.join(LM6d_origin_root, cls_dir_name)
        new_cls_dir = os.path.join(LM6d_new_root, cls_dir_name)

        real_indices = []
        images = sorted(fn for fn in os.listdir(os.path.join(origin_cls_dir, 'rgb'))
                        if '.png' in fn)

        gt_dict = load_gt(os.path.join(origin_cls_dir, 'gt.yml'))

        # Still loaded although currently unused: it holds the per-image
        # cam_K, i.e. np.array(info_dict[img_id]['cam_K']).reshape((3, 3)).
        info_dict = load_info(os.path.join(origin_cls_dir, 'info.yml'))

        for real_img in tqdm(images):
            old_color_path = os.path.join(origin_cls_dir, "rgb/{}".format(real_img))
            assert os.path.exists(old_color_path), old_color_path
            old_depth_path = os.path.join(origin_cls_dir, "depth/{}".format(real_img))
            assert os.path.exists(old_depth_path), old_depth_path
            img_id = int(real_img.replace('.png', ''))
            new_img_id = img_id + 1  # output ids are shifted by one

            new_color_path = os.path.join(new_cls_dir, "{:06d}-color.png".format(new_img_id))
            new_depth_path = os.path.join(new_cls_dir, "{:06d}-depth.png".format(new_img_id))
            mkdir_if_missing(os.path.dirname(new_color_path))

            copyfile(old_color_path, new_color_path)
            copyfile(old_depth_path, new_depth_path)

            # meta and label
            instances = gt_dict[img_id]
            num_instance = len(instances)
            meta_dict = {
                'cls_indexes': np.zeros((1, num_instance), dtype=np.int32),
                'boxes': np.zeros((num_instance, 4), dtype='float32'),
                'poses': np.zeros((3, 4, num_instance), dtype='float32'),
            }
            distances = []
            # One mask per instance, indexed like gt_dict[img_id]. Keying by
            # obj_id would let two instances of one object overwrite each
            # other.
            instance_labels = []
            for ins_id, instance in enumerate(instances):
                obj_id = instance['obj_id']
                meta_dict['cls_indexes'][0, ins_id] = obj_id
                meta_dict['boxes'][ins_id, :] = np.array(instance['obj_bb'])

                pose = np.zeros((3, 4))
                pose[:3, :3] = np.array(instance['cam_R_m2c']).reshape((3, 3))
                t = np.array(instance['cam_t_m2c']) / 1000.  # mm -> m
                pose[:3, 3] = t
                distances.append(t[2])
                meta_dict['poses'][:, :, ins_id] = pose

                # The renderer is indexed from 0, obj_id from 1.
                _, depth_gl = render_machine.render(obj_id - 1, pose[:3, :3], pose[:3, 3],
                                                    r_type='mat')
                label = np.zeros(depth_gl.shape)
                label[depth_gl != 0] = 1
                instance_labels.append(label)
            meta_path = os.path.join(new_cls_dir, "{:06d}-meta.mat".format(new_img_id))
            sio.savemat(meta_path, meta_dict)

            # Paint deeper objects first so that nearer ones overwrite them.
            dis_inds = sorted(range(len(distances)), key=lambda k: -distances[k])
            # Sized from the renderer configuration rather than a hard-coded
            # 480x640, so it always matches the rendered masks.
            res_label = np.zeros((height, width))
            for dis_id in dis_inds:
                cls_id = meta_dict['cls_indexes'][0, dis_id]
                res_label[instance_labels[dis_id] == 1] = cls_id

            label_path = os.path.join(new_cls_dir, "{:06d}-label.png".format(new_img_id))
            cv2.imwrite(label_path, res_label)

            def vis_check():
                """Debug helper: plot the inputs, last rendered depth and labels of this frame."""
                # Loaded here (not in the main loop) because only this debug
                # view needs the pixel data.
                color_img = cv2.imread(old_color_path, cv2.IMREAD_COLOR)
                depth = read_img(old_depth_path, 1)

                plt.figure(figsize=(8, 6), dpi=120)
                plt.subplot(2, 3, 1)
                plt.imshow(color_img[:, :, [2, 1, 0]])
                plt.title('color_img')

                plt.subplot(2, 3, 2)
                plt.imshow(depth)  # was depth_gl, duplicating the next panel
                plt.title('depth')

                plt.subplot(2, 3, 3)
                plt.imshow(depth_gl)
                plt.title('depth_gl')

                plt.subplot(2, 3, 4)
                plt.imshow(res_label)
                plt.title('res_label')

                plt.subplot(2, 3, 5)
                label_v1_path = os.path.join('/data/wanggu/Storage/LINEMOD_SIXD_wods/LM6d_render_v1/data/real',
                                             cls_dir_name,
                                             "{:06d}-label.png".format(new_img_id))
                assert os.path.exists(label_v1_path), label_v1_path
                label_v1 = read_img(label_v1_path, 1)
                plt.imshow(label_v1)
                plt.title('label_v1')

                plt.show()
            # vis_check()

            # real idx
            real_indices.append("{}/{:06d}".format(cls_dir_name, new_img_id))

        # one idx file for each video of each class
        real_idx_file = os.path.join(real_set_dir, "{}_all.txt".format(cls_name))
        with open(real_idx_file, 'w') as f:
            for real_idx in real_indices:
                f.write(real_idx + '\n')
Ejemplo n.º 3
0
def main():
    sel_classes = classes
    render_machine = Render_Py(model_dir, classes, K, width, height, ZNEAR,
                               ZFAR)
    for cls_idx, cls_name in enumerate(classes):
        if not cls_name in sel_classes:
            continue
        print(cls_idx, cls_name)
        observed_indices = []
        images = [
            fn for fn in os.listdir(
                os.path.join(LM6d_origin_root, '{:02d}'.format(
                    class2idx(cls_name)), 'rgb')) if '.png' in fn
        ]
        images = sorted(images)

        gt_path = os.path.join(LM6d_origin_root,
                               '{:02d}'.format(class2idx(cls_name)), 'gt.yml')
        gt_dict = load_gt(gt_path)

        info_path = os.path.join(LM6d_origin_root,
                                 '{:02d}'.format(class2idx(cls_name)),
                                 'info.yml')
        info_dict = load_info(info_path)

        for observed_img in tqdm(images):
            old_color_path = os.path.join(LM6d_origin_root,
                                          '{:02d}'.format(class2idx(cls_name)),
                                          "rgb/{}".format(observed_img))
            assert os.path.exists(old_color_path), old_color_path
            old_depth_path = os.path.join(LM6d_origin_root,
                                          '{:02d}'.format(class2idx(cls_name)),
                                          "depth/{}".format(observed_img))
            assert os.path.exists(old_depth_path), old_depth_path
            img_id = int(observed_img.replace('.png', ''))
            new_img_id = img_id + 1

            # K
            # K = np.array(info_dict[img_id]['cam_K']).reshape((3, 3))
            color_img = cv2.imread(old_color_path, cv2.IMREAD_COLOR)

            ## depth
            depth = read_img(old_depth_path, 1)
            # print(np.max(depth), np.min(depth))

            # print(color_img.shape)

            new_color_path = os.path.join(
                LM6d_new_root, '{:02d}'.format(class2idx(cls_name)),
                "{:06d}-color.png".format(new_img_id))
            new_depth_path = os.path.join(
                LM6d_new_root, '{:02d}'.format(class2idx(cls_name)),
                "{:06d}-depth.png".format(new_img_id))
            mkdir_if_missing(os.path.dirname(new_color_path))

            copyfile(old_color_path, new_color_path)
            copyfile(old_depth_path, new_depth_path)

            # meta and label
            meta_dict = {}
            num_instance = len(gt_dict[img_id])
            meta_dict['cls_indexes'] = np.zeros((1, num_instance),
                                                dtype=np.int32)
            meta_dict['boxes'] = np.zeros((num_instance, 4), dtype='float32')
            meta_dict['poses'] = np.zeros((3, 4, num_instance),
                                          dtype='float32')
            distances = []
            label_dict = {}
            for ins_id, instance in enumerate(gt_dict[img_id]):
                obj_id = instance['obj_id']
                meta_dict['cls_indexes'][0, ins_id] = obj_id
                obj_bb = np.array(instance['obj_bb'])
                meta_dict['boxes'][ins_id, :] = obj_bb
                # pose
                pose = np.zeros((3, 4))

                R = np.array(instance['cam_R_m2c']).reshape((3, 3))
                t = np.array(instance['cam_t_m2c']) / 1000.  # mm -> m
                pose[:3, :3] = R
                pose[:3, 3] = t
                distances.append(t[2])
                meta_dict['poses'][:, :, ins_id] = pose
                image_gl, depth_gl = render_machine.render(obj_id - 1,
                                                           pose[:3, :3],
                                                           pose[:3, 3],
                                                           r_type='mat')
                image_gl = image_gl.astype('uint8')
                label = np.zeros(depth_gl.shape)
                label[depth_gl != 0] = 1
                label_dict[obj_id] = label
            meta_path = os.path.join(LM6d_new_root,
                                     '{:02d}'.format(class2idx(cls_name)),
                                     "{:06d}-meta.mat".format(new_img_id))
            sio.savemat(meta_path, meta_dict)

            dis_inds = sorted(
                range(len(distances)),
                key=lambda k: -distances[k])  # put deeper objects first
            # label
            res_label = np.zeros((480, 640))
            for dis_id in dis_inds:
                cls_id = meta_dict['cls_indexes'][0, dis_id]
                tmp_label = label_dict[cls_id]
                # label
                res_label[tmp_label == 1] = cls_id

            label_path = os.path.join(LM6d_new_root,
                                      '{:02d}'.format(class2idx(cls_name)),
                                      "{:06d}-label.png".format(new_img_id))
            cv2.imwrite(label_path, res_label)

            # observed idx
            observed_indices.append("{:02d}/{:06d}".format(
                class2idx(cls_name), new_img_id))
            # print(rot_sym_m)
            rot_res = R_transform(pose_gt[:3, :3], rot_sym_m, rot_coord='model')
            rot_res_q = mat2quat(rot_res)

            rgb_gl, depth_gl = render_machine.render(cls_idx, rot_res_q, pose_gt[:, 3] + p_center)
            rgb_gl = rgb_gl.astype('uint8')
            return rgb_gl, depth_gl

        # transform the points to 2D
        for img_idx in img_indices:
            print(img_idx)
            pose_gt = np.loadtxt('./data/render_v5/data/render_real/{}/{}-pose.txt'.format(cls_name, img_idx),
                                 skiprows=1)
            im_c = cv2.imread('./data/render_v5/data/real/{}-color.png'.format(img_idx))

            im_render_gt, _ = render_machine.render(cls_idx, mat2quat(pose_gt[:3, :3]), pose_gt[:, 3])

            # symmetry axis
            angle = np.pi
            p_center = np.array([0, 0, 0])
            rot_axis = np.array([0, 0, -1])

            im_rot, depth_rot = rotate(angle, rot_axis, pose_gt,  p_center=p_center)
            im_rot_1, depth_rot_1 = rotate(angle=np.pi * 0.5, rot_axis=rot_axis, pose_gt=pose_gt, p_center=p_center)

            fig = plt.figure(figsize=(8, 6), dpi=120)
            plt.subplot(2, 3, 1)
            plt.imshow(im_c[:,:,[2, 1, 0]])
            plt.title('im_render_real')

            plt.subplot(2, 3, 2)