Example #1
def pose_from_predictions_train(pred_rots,
                                pred_transes,
                                eps=1e-4,
                                is_allo=True):
    """for train
    Args:
        pred_rots:
        pred_transes:
        eps:
        is_allo:

    Returns:

    """
    translation = pred_transes

    if pred_rots.ndim == 2 and pred_rots.shape[-1] == 4:
        pred_quats = pred_rots
        quat_allo = pred_quats / (torch.norm(pred_quats, dim=1, keepdim=True) +
                                  eps)
        if is_allo:
            quat_ego = allocentric_to_egocentric_torch(translation,
                                                       quat_allo,
                                                       eps=eps)
        else:
            quat_ego = quat_allo
        rot_ego = quat2mat_torch(quat_ego)
    elif pred_rots.ndim == 3 and pred_rots.shape[-1] == 3:  # Nx3x3
        if is_allo:
            rot_ego = allo_to_ego_mat_torch(translation, pred_rots, eps=eps)
        else:
            rot_ego = pred_rots
    else:  # otherwise rot_ego would be unbound and raise a NameError
        raise ValueError(f"Unsupported pred_rots shape: {pred_rots.shape}")
    return rot_ego, translation
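A minimal smoke-test sketch (hedged: assumes quat2mat_torch, allocentric_to_egocentric_torch, and allo_to_ego_mat_torch are importable as in the other examples):

import torch

# batch of 8 unnormalized allocentric quaternions and translations
pred_rots = torch.randn(8, 4)
pred_transes = torch.randn(8, 3)
rot_ego, translation = pose_from_predictions_train(pred_rots, pred_transes, is_allo=True)
assert rot_ego.shape == (8, 3, 3) and translation.shape == (8, 3)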
Example #2
def render_dib_vc_batch(ren,
                        Rs,
                        ts,
                        Ks,
                        obj_ids,
                        models,
                        rot_type="quat",
                        H=480,
                        W=640,
                        near=0.01,
                        far=100.0,
                        with_depth=False):
    """
    Args:
        ren: A DIB-renderer
        models: All models loaded by load_objs
    """
    assert ren.mode in ["VertexColorBatch"], ren.mode
    bs = len(Rs)
    if len(Ks) == 1:
        Ks = [Ks[0] for _ in range(bs)]
    ren.set_camera_parameters_from_RT_K(Rs,
                                        ts,
                                        Ks,
                                        height=H,
                                        width=W,
                                        near=near,
                                        far=far,
                                        rot_type=rot_type)
    colors = [models[_id]["colors"] for _id in obj_ids]  # b x [1, p, 3]
    points = [[models[_id]["vertices"], models[_id]["faces"][0].long()]
              for _id in obj_ids]

    # points: list of [vertices, faces]
    # colors: list of colors
    predictions, im_probs, _, im_masks = ren.forward(points=points,
                                                     colors=colors)
    if with_depth:
        # transform xyz
        if not isinstance(Rs, torch.Tensor):
            Rs = torch.stack(Rs)  # list of tensors -> [B, 4] or [B, 3, 3]
        if rot_type == "quat":
            R_mats = quat2mat_torch(Rs)
        else:
            R_mats = Rs
        xyzs = [
            transform_pts_Rt_th(models[obj_id]["vertices"][0], R_mats[_id],
                                ts[_id])[None]
            for _id, obj_id in enumerate(obj_ids)
        ]
        ren_xyzs, _, _, _ = ren.forward(points=points, colors=xyzs)
        depth = ren_xyzs[:, :, :, 2]  # bhw
    else:
        depth = None
    # bxhxwx3 rgb, bhw1 prob, bhw1 mask, bhw depth
    return predictions, im_probs, im_masks, depth
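A hedged call sketch (the DIBRenderer constructor follows its use in Example #4; models as produced by load_objs; R0/R1, t0/t1, K are placeholder tensors):

ren = DIBRenderer(height=480, width=640, mode="VertexColorBatch")  # assumed constructor
color, prob, mask, depth = render_dib_vc_batch(
    ren, Rs=[R0, R1], ts=[t0, t1], Ks=[K], obj_ids=[0, 1],
    models=models, rot_type="mat", with_depth=True)
# color: [2, H, W, 3], prob/mask: [2, H, W, 1], depth: [2, H, W]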
Example #3
def test_mat2quat_torch():
    import numpy as np
    import torch
    # hedged: axangle2quat / quat2mat most likely come from transforms3d.quaternions;
    # mat2quat_batch is assumed to be importable from the repo's pose utils
    from transforms3d.quaternions import axangle2quat, quat2mat
    from core.utils.pose_utils import quat2mat_torch

    axis = np.random.rand(3)
    angle = np.random.rand()  # scalar angle; axangle2quat expects a scalar theta, not a (1,) array
    # quat = axangle2quat([1, 2, 3], 0.7)
    quat = axangle2quat(axis, angle)
    print("quat:\n", quat)
    mat = quat2mat(quat)
    print("mat:\n", mat)
    mat_th = torch.tensor(mat.astype("float32"))[None].to("cuda")
    print("mat_th:\n", mat_th)
    quat_th = mat2quat_batch(mat_th)
    print("quat_th:\n", quat_th)
    mat_2 = quat2mat_torch(quat_th)
    print("mat_2:\n", mat_2)
    diff_mat = mat_th - mat_2
    print("mat_diff:\n", diff_mat)
    diff_quat = quat - quat_th.cpu().numpy()
    print("diff_quat:\n", diff_quat)
Example #4
def render_dib_tex_batch(ren,
                         Rs,
                         ts,
                         Ks,
                         obj_ids,
                         models,
                         rot_type="quat",
                         H=480,
                         W=640,
                         near=0.01,
                         far=100.0,
                         with_depth=False):
    """Render a batch with the texture-mode DIB renderer; see render_dib_vc_batch
    for the shared argument conventions."""
    assert ren.mode in ["TextureBatch"], ren.mode
    bs = len(Rs)
    if len(Ks) == 1:
        Ks = [Ks[0] for _ in range(bs)]
    ren.set_camera_parameters_from_RT_K(Rs,
                                        ts,
                                        Ks,
                                        height=H,
                                        width=W,
                                        near=near,
                                        far=far,
                                        rot_type=rot_type)
    # points: list of [vertices, faces]
    points = [[models[_id]["vertices"], models[_id]["faces"][0].long()]
              for _id in obj_ids]
    uv_bxpx2 = [models[_id]["uvs"] for _id in obj_ids]
    texture_bx3xthxtw = [models[_id]["texture"] for _id in obj_ids]
    ft_fx3_list = [models[_id]["face_textures"][0] for _id in obj_ids]

    # texture-mode inputs: per-vertex UVs, texture images, per-face texture indices
    dib_ren_im, dib_ren_prob, _, dib_ren_mask = ren.forward(
        points=points,
        uv_bxpx2=uv_bxpx2,
        texture_bx3xthxtw=texture_bx3xthxtw,
        ft_fx3=ft_fx3_list)

    if with_depth:
        # transform xyz
        if not isinstance(Rs, torch.Tensor):
            Rs = torch.stack(Rs)  # list of tensors -> [B, 4] or [B, 3, 3]
        if rot_type == "quat":
            R_mats = quat2mat_torch(Rs)
        else:
            R_mats = Rs
        xyzs = [
            transform_pts_Rt_th(models[obj_id]["vertices"][0], R_mats[_id],
                                ts[_id])[None]
            for _id, obj_id in enumerate(obj_ids)
        ]
        dib_ren_vc_batch = DIBRenderer(height=H,
                                       width=W,
                                       mode="VertexColorBatch")
        dib_ren_vc_batch.set_camera_parameters(ren.camera_params)
        ren_xyzs, _, _, _ = dib_ren_vc_batch.forward(points=points,
                                                     colors=xyzs)
        depth = ren_xyzs[:, :, :, 2]  # bhw
    else:
        depth = None
    return dib_ren_im, dib_ren_prob, dib_ren_mask, depth  # bxhxwx3 rgb, bhw1 prob/mask, bhw depth
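The depth branch re-renders object-space XYZ through a separate vertex-color renderer because the texture-mode forward has no per-vertex color channel; only the z component of the rasterized XYZ is kept. A hedged call sketch (constructor assumed by symmetry with the VertexColorBatch one above; Rs/ts/K and models are placeholders):

ren = DIBRenderer(height=480, width=640, mode="TextureBatch")  # assumed constructor
im, prob, mask, depth = render_dib_tex_batch(
    ren, Rs, ts, [K], obj_ids, models, rot_type="quat", with_depth=True)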
Example #5
def pose_from_predictions_train(pred_rots,
                                pred_centroids,
                                pred_z_vals,
                                roi_cams,
                                eps=1e-4,
                                is_allo=True):
    """for train
    Args:
        pred_rots:
        pred_centroids:
        pred_z_vals: [B, 1]
        roi_cams: absolute cams
        eps:
        is_allo:

    Returns:

    """
    if roi_cams.dim() == 2:
        roi_cams.unsqueeze_(0)
    assert roi_cams.dim() == 3, roi_cams.dim()
    # absolute coords
    cx = pred_centroids[:, 0:1]  # [#roi, 1]
    cy = pred_centroids[:, 1:2]  # [#roi, 1]

    z = pred_z_vals

    # backproject regressed centroid with regressed z
    """
    fx * tx + px * tz = z * cx
    fy * ty + py * tz = z * cy
    tz = z
    ==>
    fx * tx / tz = cx - px
    fy * ty / tz = cy - py
    ==>
    tx = (cx - px) * tz / fx
    ty = (cy - py) * tz / fy
    """
    # NOTE: z must be [B,1]
    translation = torch.cat(
        [
            z * (cx - roi_cams[:, 0:1, 2]) / roi_cams[:, 0:1, 0],
            z * (cy - roi_cams[:, 1:2, 2]) / roi_cams[:, 1:2, 1],
            z,
        ],
        dim=1,
    )

    if pred_rots.ndim == 2 and pred_rots.shape[-1] == 4:
        pred_quats = pred_rots
        quat_allo = pred_quats / (torch.norm(pred_quats, dim=1, keepdim=True) +
                                  eps)
        if is_allo:
            quat_ego = allocentric_to_egocentric_torch(translation,
                                                       quat_allo,
                                                       eps=eps)
        else:
            quat_ego = quat_allo
        rot_ego = quat2mat_torch(quat_ego)
    elif pred_rots.ndim == 3 and pred_rots.shape[-1] == 3:  # Nx3x3
        if is_allo:
            rot_ego = allo_to_ego_mat_torch(translation, pred_rots, eps=eps)
        else:
            rot_ego = pred_rots
    else:  # otherwise rot_ego would be unbound and raise a NameError
        raise ValueError(f"Unsupported pred_rots shape: {pred_rots.shape}")
    return rot_ego, translation
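A small numeric check of the backprojection derived above (made-up intrinsics: fx = fy = 500, px = 320, py = 240):

import torch

K = torch.tensor([[[500.0, 0.0, 320.0],
                   [0.0, 500.0, 240.0],
                   [0.0, 0.0, 1.0]]])
pred_centroids = torch.tensor([[420.0, 140.0]])  # absolute (cx, cy)
pred_z_vals = torch.tensor([[2.0]])              # tz
pred_rots = torch.eye(3).unsqueeze(0)            # identity rotation, is_allo=False path
rot, t = pose_from_predictions_train(pred_rots, pred_centroids, pred_z_vals, K, is_allo=False)
# tx = (420 - 320) * 2 / 500 = 0.4, ty = (140 - 240) * 2 / 500 = -0.4
print(t)  # tensor([[ 0.4000, -0.4000,  2.0000]])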
Example #6
    def set_camera_parameters_from_RT_K(self,
                                        Rs,
                                        ts,
                                        Ks,
                                        height,
                                        width,
                                        near=0.01,
                                        far=10.0,
                                        rot_type='mat'):
        """
        Rs: a list of rotations tensor
        ts: a list of translations tensor
        Ks: a list of camera intrinsic matrices or a single matrix
        ----
        [cam_view_R, cam_view_pos, cam_proj]
        """
        """
        aspect_ratio = width / height
        fov_x, fov_y = K_to_fov(K, height, width)
        # camera_projection_mtx = perspectiveprojectionnp(self.camera_fov_y,
        #         ratio=aspect_ratio, near=near, far=far)
        camera_projection_mtx = perspectiveprojectionnp(fov_y,
                ratio=aspect_ratio, near=near, far=far)
        """
        assert rot_type in ['mat', 'quat'], rot_type
        bs = len(Rs)
        single_K = False
        # anything that is not a list (e.g. a single 2D array/tensor) is treated
        # as one K shared by the whole batch
        if not isinstance(Ks, list) \
                or (isinstance(Ks, (np.ndarray, torch.Tensor)) and Ks.ndim == 2):
            K = Ks
            camera_proj_mtx = projectiveprojection_real(
                K, 0, 0, width, height, near, far)
            camera_proj_mtx = torch.as_tensor(
                camera_proj_mtx).float().cuda()  # 4x4
            single_K = True

        camera_view_mtx = []
        camera_view_shift = []
        if not single_K:
            camera_proj_mtx = []
        for i in range(bs):
            R = Rs[i]
            t = ts[i]
            if not isinstance(R, torch.Tensor):
                R = torch.tensor(R, dtype=torch.float32, device='cuda:0')
            if not isinstance(t, torch.Tensor):
                t = torch.tensor(t, dtype=torch.float32, device='cuda:0')
            if rot_type == 'quat':
                R = quat2mat_torch(R.unsqueeze(0))[0]
            cam_view_R = torch.matmul(self.yz_flip.to(R), R)
            cam_view_t = -torch.matmul(R.t(), t)  # cam pos

            camera_view_mtx.append(cam_view_R)
            camera_view_shift.append(cam_view_t)
            if not single_K:
                K = Ks[i]
                cam_proj_mtx = projectiveprojection_real(
                    K, 0, 0, width, height, near, far)
                cam_proj_mtx = torch.tensor(cam_proj_mtx).float().cuda()  # 4x4
                camera_proj_mtx.append(cam_proj_mtx)
        camera_view_mtx = torch.stack(camera_view_mtx).cuda()  # bx3x3
        camera_view_shift = torch.stack(camera_view_shift).cuda()  # bx3
        if not single_K:
            camera_proj_mtx = torch.stack(camera_proj_mtx)  # bx4x4

        # print("camera view matrix: \n", camera_view_mtx, camera_view_mtx.shape) # bx3x3, camera rot?
        # print('camera view shift: \n', camera_view_shift, camera_view_shift.shape) # bx3, camera trans?
        # print('camera projection mat: \n', camera_proj_mtx, camera_proj_mtx.shape) # projection matrix, 3x1
        self.camera_params = [
            camera_view_mtx, camera_view_shift, camera_proj_mtx
        ]
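A hedged sketch of setting parameters for a batch that shares one K (ren is a DIBRenderer as in the other examples; a single 2D array takes the single_K path, while a list of per-object Ks also works):

import numpy as np

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)
Rs = [np.eye(3, dtype=np.float32) for _ in range(4)]
ts = [np.array([0.0, 0.0, 1.0], dtype=np.float32) for _ in range(4)]
ren.set_camera_parameters_from_RT_K(Rs, ts, K, height=480, width=640, rot_type="mat")
cam_view_R, cam_view_pos, cam_proj = ren.camera_params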
Example #7
    def forward(
        self,
        x,
        gt_xyz=None,
        gt_xyz_bin=None,
        gt_mask_trunc=None,
        gt_mask_visib=None,
        gt_mask_obj=None,
        gt_region=None,
        gt_allo_quat=None,
        gt_ego_quat=None,
        gt_allo_rot6d=None,
        gt_ego_rot6d=None,
        gt_ego_rot=None,
        gt_points=None,
        sym_infos=None,
        gt_trans=None,
        gt_trans_ratio=None,
        roi_classes=None,
        roi_coord_2d=None,
        roi_cams=None,
        roi_centers=None,
        roi_whs=None,
        roi_extents=None,
        resize_ratios=None,
        do_loss=False,
    ):
        cfg = self.cfg
        r_head_cfg = cfg.MODEL.CDPN.ROT_HEAD
        t_head_cfg = cfg.MODEL.CDPN.TRANS_HEAD
        pnp_net_cfg = cfg.MODEL.CDPN.PNP_NET

        # x.shape [bs, 3, 256, 256]
        if self.concat:
            features, x_f64, x_f32, x_f16 = self.backbone(x)  # features: [bs, 2048, 8, 8]
            # the rot head outputs mask / xyz-coord / region maps at e.g. [bs, C, 64, 64]
            mask, coor_x, coor_y, coor_z, region = self.rot_head_net(
                features, x_f64, x_f32, x_f16)
        else:
            features = self.backbone(x)  # features: [bs, 2048, 8, 8]
            mask, coor_x, coor_y, coor_z, region = self.rot_head_net(features)

        # TODO: remove this trans_head_net
        # trans = self.trans_head_net(features)

        device = x.device
        bs = x.shape[0]
        num_classes = r_head_cfg.NUM_CLASSES

        out_res = cfg.MODEL.CDPN.BACKBONE.OUTPUT_RES

        if r_head_cfg.ROT_CLASS_AWARE:
            assert roi_classes is not None
            coor_x = coor_x.view(bs, num_classes, self.r_out_dim // 3, out_res,
                                 out_res)
            coor_x = coor_x[torch.arange(bs).to(device), roi_classes]
            coor_y = coor_y.view(bs, num_classes, self.r_out_dim // 3, out_res,
                                 out_res)
            coor_y = coor_y[torch.arange(bs).to(device), roi_classes]
            coor_z = coor_z.view(bs, num_classes, self.r_out_dim // 3, out_res,
                                 out_res)
            coor_z = coor_z[torch.arange(bs).to(device), roi_classes]

        if r_head_cfg.MASK_CLASS_AWARE:
            assert roi_classes is not None
            mask = mask.view(bs, num_classes, self.mask_out_dim, out_res,
                             out_res)
            mask = mask[torch.arange(bs).to(device), roi_classes]

        if r_head_cfg.REGION_CLASS_AWARE:
            assert roi_classes is not None
            region = region.view(bs, num_classes, self.region_out_dim, out_res,
                                 out_res)
            region = region[torch.arange(bs).to(device), roi_classes]

        # -----------------------------------------------
        # get rot and trans from pnp_net
        # NOTE: use softmax for bins (the last dim is bg)
        if coor_x.shape[1] > 1 and coor_y.shape[1] > 1 and coor_z.shape[1] > 1:
            coor_x_softmax = F.softmax(coor_x[:, :-1, :, :], dim=1)
            coor_y_softmax = F.softmax(coor_y[:, :-1, :, :], dim=1)
            coor_z_softmax = F.softmax(coor_z[:, :-1, :, :], dim=1)
            coor_feat = torch.cat(
                [coor_x_softmax, coor_y_softmax, coor_z_softmax], dim=1)
        else:
            coor_feat = torch.cat([coor_x, coor_y, coor_z], dim=1)  # BCHW

        if pnp_net_cfg.WITH_2D_COORD:
            assert roi_coord_2d is not None
            coor_feat = torch.cat([coor_feat, roi_coord_2d], dim=1)

        # NOTE: for region, the 1st dim is bg
        region_softmax = F.softmax(region[:, 1:, :, :], dim=1)

        mask_atten = None
        if pnp_net_cfg.MASK_ATTENTION != "none":
            mask_atten = get_mask_prob(cfg, mask)

        region_atten = None
        if pnp_net_cfg.REGION_ATTENTION:
            region_atten = region_softmax

        pred_rot_, pred_t_ = self.pnp_net(coor_feat,
                                          region=region_atten,
                                          extents=roi_extents,
                                          mask_attention=mask_atten)
        if pnp_net_cfg.R_ONLY:  # override trans pred
            pred_t_ = self.trans_head_net(features)

        # convert pred_rot to rot mat -------------------------
        rot_type = pnp_net_cfg.ROT_TYPE
        if rot_type in ["ego_quat", "allo_quat"]:
            pred_rot_m = quat2mat_torch(pred_rot_)
        elif rot_type in ["ego_log_quat", "allo_log_quat"]:
            pred_rot_m = quat2mat_torch(quaternion_lf.qexp(pred_rot_))
        elif rot_type in ["ego_lie_vec", "allo_lie_vec"]:
            pred_rot_m = lie_algebra.lie_vec_to_rot(pred_rot_)
        elif rot_type in ["ego_rot6d", "allo_rot6d"]:
            pred_rot_m = ortho6d_to_mat_batch(pred_rot_)
        else:
            raise ValueError(f"Unknown rot_type: {rot_type}")
        # convert pred_rot_m and pred_t to ego pose -----------------------------
        if pnp_net_cfg.TRANS_TYPE == "centroid_z":
            pred_ego_rot, pred_trans = pose_from_pred_centroid_z(
                pred_rot_m,
                pred_centroids=pred_t_[:, :2],
                pred_z_vals=pred_t_[:, 2:3],  # must be [B, 1]
                roi_cams=roi_cams,
                roi_centers=roi_centers,
                resize_ratios=resize_ratios,
                roi_whs=roi_whs,
                eps=1e-4,
                is_allo="allo" in pnp_net_cfg.ROT_TYPE,
                z_type=pnp_net_cfg.Z_TYPE,
                # TODO: sometimes we need it to be differentiable during test
                is_train=do_loss,
            )
        elif pnp_net_cfg.TRANS_TYPE == "centroid_z_abs":
            # abs 2d obj center and abs z
            pred_ego_rot, pred_trans = pose_from_pred_centroid_z_abs(
                pred_rot_m,
                pred_centroids=pred_t_[:, :2],
                pred_z_vals=pred_t_[:, 2:3],  # must be [B, 1]
                roi_cams=roi_cams,
                eps=1e-4,
                is_allo="allo" in pnp_net_cfg.ROT_TYPE,
                # TODO: sometimes we need it to be differentiable during test
                is_train=do_loss,
            )
        elif pnp_net_cfg.TRANS_TYPE == "trans":
            # TODO: maybe denormalize trans
            pred_ego_rot, pred_trans = pose_from_pred(
                pred_rot_m,
                pred_t_,
                eps=1e-4,
                is_allo="allo" in pnp_net_cfg.ROT_TYPE,
                is_train=do_loss,
            )
        else:
            raise ValueError(
                f"Unknown pnp_net trans type: {pnp_net_cfg.TRANS_TYPE}")

        if not do_loss:  # test
            out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
            if cfg.TEST.USE_PNP:
                # TODO: move the pnp/ransac inside forward
                out_dict.update({
                    "mask": mask,
                    "coor_x": coor_x,
                    "coor_y": coor_y,
                    "coor_z": coor_z,
                    "region": region
                })
        else:
            out_dict = {}
            assert ((gt_xyz is not None) and (gt_trans is not None)
                    and (gt_trans_ratio is not None)
                    and (gt_region is not None))
            mean_re, mean_te = compute_mean_re_te(pred_trans, pred_rot_m,
                                                  gt_trans, gt_ego_rot)
            vis_dict = {
                "vis/error_R": mean_re,
                "vis/error_t": mean_te * 100,  # cm
                "vis/error_tx": np.abs(pred_trans[0, 0].detach().item() - gt_trans[0, 0].detach().item()) * 100,  # cm
                "vis/error_ty": np.abs(pred_trans[0, 1].detach().item() - gt_trans[0, 1].detach().item()) * 100,  # cm
                "vis/error_tz": np.abs(pred_trans[0, 2].detach().item() - gt_trans[0, 2].detach().item()) * 100,  # cm
                "vis/tx_pred": pred_trans[0, 0].detach().item(),
                "vis/ty_pred": pred_trans[0, 1].detach().item(),
                "vis/tz_pred": pred_trans[0, 2].detach().item(),
                "vis/tx_net": pred_t_[0, 0].detach().item(),
                "vis/ty_net": pred_t_[0, 1].detach().item(),
                "vis/tz_net": pred_t_[0, 2].detach().item(),
                "vis/tx_gt": gt_trans[0, 0].detach().item(),
                "vis/ty_gt": gt_trans[0, 1].detach().item(),
                "vis/tz_gt": gt_trans[0, 2].detach().item(),
                "vis/tx_rel_gt": gt_trans_ratio[0, 0].detach().item(),
                "vis/ty_rel_gt": gt_trans_ratio[0, 1].detach().item(),
                "vis/tz_rel_gt": gt_trans_ratio[0, 2].detach().item(),
            }

            loss_dict = self.gdrn_loss(
                cfg=self.cfg,
                out_mask=mask,
                gt_mask_trunc=gt_mask_trunc,
                gt_mask_visib=gt_mask_visib,
                gt_mask_obj=gt_mask_obj,
                out_x=coor_x,
                out_y=coor_y,
                out_z=coor_z,
                gt_xyz=gt_xyz,
                gt_xyz_bin=gt_xyz_bin,
                out_region=region,
                gt_region=gt_region,
                out_trans=pred_trans,
                gt_trans=gt_trans,
                out_rot=pred_ego_rot,
                gt_rot=gt_ego_rot,
                out_centroid=pred_t_[:, :2],  # TODO: get these from trans head
                out_trans_z=pred_t_[:, 2],
                gt_trans_ratio=gt_trans_ratio,
                gt_points=gt_points,
                sym_infos=sym_infos,
                extents=roi_extents,
                # roi_classes=roi_classes,
            )

            if cfg.MODEL.CDPN.USE_MTL:
                for _name in self.loss_names:
                    if f"loss_{_name}" in loss_dict:
                        vis_dict[f"vis_lw/{_name}"] = torch.exp(-getattr(
                            self, f"log_var_{_name}")).detach().item()
            for _k, _v in vis_dict.items():
                if "vis/" in _k or "vis_lw/" in _k:
                    if isinstance(_v, torch.Tensor):
                        _v = _v.item()
                    vis_dict[_k] = _v
            storage = get_event_storage()
            storage.put_scalars(**vis_dict)

            return out_dict, loss_dict
        return out_dict
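A hedged sketch of the two call modes (model and the batch tensors are placeholders; when do_loss=True, the full set of gt_* inputs consumed by gdrn_loss is required in practice):

# training: returns (out_dict, loss_dict)
out_dict, loss_dict = model(
    x, gt_xyz=gt_xyz, gt_region=gt_region, gt_ego_rot=gt_ego_rot,
    gt_trans=gt_trans, gt_trans_ratio=gt_trans_ratio,
    roi_cams=roi_cams, roi_centers=roi_centers, roi_whs=roi_whs,
    roi_extents=roi_extents, resize_ratios=resize_ratios, do_loss=True)
# inference: returns {"rot": [B, 3, 3], "trans": [B, 3]}
out_dict = model(
    x, roi_cams=roi_cams, roi_centers=roi_centers, roi_whs=roi_whs,
    roi_extents=roi_extents, resize_ratios=resize_ratios, do_loss=False)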
Example #8
def pose_from_predictions_train(
    pred_rots,
    pred_centroids,
    pred_z_vals,
    roi_cams,
    roi_centers,
    resize_ratios,
    roi_whs,
    eps=1e-4,
    is_allo=True,
    z_type="REL",
):
    """for train
    Args:
        pred_rots:
        pred_centroids:
        pred_z_vals: [B, 1]
        roi_cams: absolute cams
        roi_centers:
        roi_scales:
        roi_whs: (bw,bh) for bboxes
        eps:
        is_allo:
        z_type: REL | ABS | LOG | NEG_LOG

    Returns:

    """
    if roi_cams.dim() == 2:
        roi_cams.unsqueeze_(0)
    assert roi_cams.dim() == 3, roi_cams.dim()
    # absolute coords
    c = torch.stack(
        [
            (pred_centroids[:, 0] * roi_whs[:, 0]) + roi_centers[:, 0],
            (pred_centroids[:, 1] * roi_whs[:, 1]) + roi_centers[:, 1],
        ],
        dim=1,
    )

    cx = c[:, 0:1]  # [#roi, 1]
    cy = c[:, 1:2]  # [#roi, 1]

    # unnormalize regressed z
    if z_type == "ABS":
        z = pred_z_vals
    elif z_type == "REL":
        # z_1 / z_2 = s_2 / s_1 ==> z_1 = s_2 / s_1 * z_2
        z = pred_z_vals * resize_ratios.view(-1, 1)
    else:
        raise ValueError(f"Unknown z_type: {z_type}")

    # backproject regressed centroid with regressed z
    """
    fx * tx + px * tz = z * cx
    fy * ty + py * tz = z * cy
    tz = z
    ==>
    fx * tx / tz = cx - px
    fy * ty / tz = cy - py
    ==>
    tx = (cx - px) * tz / fx
    ty = (cy - py) * tz / fy
    """
    # NOTE: z must be [B,1]
    translation = torch.cat(
        [z * (cx - roi_cams[:, 0:1, 2]) / roi_cams[:, 0:1, 0], z * (cy - roi_cams[:, 1:2, 2]) / roi_cams[:, 1:2, 1], z],
        dim=1,
    )

    if pred_rots.ndim == 2 and pred_rots.shape[-1] == 4:
        pred_quats = pred_rots
        quat_allo = pred_quats / (torch.norm(pred_quats, dim=1, keepdim=True) + eps)
        if is_allo:
            quat_ego = allocentric_to_egocentric_torch(translation, quat_allo, eps=eps)
        else:
            quat_ego = quat_allo
        rot_ego = quat2mat_torch(quat_ego)
    elif pred_rots.ndim == 3 and pred_rots.shape[-1] == 3:  # Nx3x3
        if is_allo:
            rot_ego = allo_to_ego_mat_torch(translation, pred_rots, eps=eps)
        else:
            rot_ego = pred_rots
    else:  # otherwise rot_ego would be unbound and raise a NameError
        raise ValueError(f"Unsupported pred_rots shape: {pred_rots.shape}")
    return rot_ego, translation
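A small numeric sketch of the REL decoding path (made-up ROI geometry; the centroid is predicted relative to the ROI and z is unnormalized by the resize ratio):

import torch

K = torch.tensor([[[500.0, 0.0, 320.0],
                   [0.0, 500.0, 240.0],
                   [0.0, 0.0, 1.0]]])
pred_centroids = torch.tensor([[0.25, -0.10]])  # relative offsets within the ROI
roi_centers = torch.tensor([[320.0, 240.0]])
roi_whs = torch.tensor([[128.0, 128.0]])        # (bw, bh)
resize_ratios = torch.tensor([2.0])             # s_2 / s_1
pred_z_vals = torch.tensor([[1.0]])             # -> z = 2.0 after unnormalization
rot, t = pose_from_predictions_train(
    torch.eye(3).unsqueeze(0), pred_centroids, pred_z_vals,
    K, roi_centers, resize_ratios, roi_whs, is_allo=False, z_type="REL")
# cx = 0.25 * 128 + 320 = 352, cy = -0.10 * 128 + 240 = 227.2, z = 2.0
# tx = (352 - 320) * 2 / 500 = 0.128, ty = (227.2 - 240) * 2 / 500 = -0.0512
print(t)  # tensor([[ 0.1280, -0.0512,  2.0000]])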