def pose_from_predictions_train(pred_rots, pred_transes, eps=1e-4, is_allo=True):
    """for train
    Args:
        pred_rots: [B, 4] quaternions or [B, 3, 3] rotation matrices
        pred_transes: [B, 3] translations
        eps: small value to avoid division by zero when normalizing quaternions
        is_allo: whether pred_rots are allocentric rotations

    Returns:
        rot_ego: [B, 3, 3] egocentric rotation matrices
        translation: [B, 3] translations
    """
    translation = pred_transes

    if pred_rots.ndim == 2 and pred_rots.shape[-1] == 4:
        pred_quats = pred_rots
        quat_allo = pred_quats / (torch.norm(pred_quats, dim=1, keepdim=True) + eps)
        if is_allo:
            quat_ego = allocentric_to_egocentric_torch(translation, quat_allo, eps=eps)
        else:
            quat_ego = quat_allo
        rot_ego = quat2mat_torch(quat_ego)
    if pred_rots.ndim == 3 and pred_rots.shape[-1] == 3:  # Nx3x3
        if is_allo:
            rot_ego = allo_to_ego_mat_torch(translation, pred_rots, eps=eps)
        else:
            rot_ego = pred_rots
    return rot_ego, translation
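# Usage sketch (illustrative, not part of the pipeline): exercises the quaternion
# branch of the two-argument pose_from_predictions_train variant above with random
# inputs. Assumes this module's helpers (allocentric_to_egocentric_torch,
# quat2mat_torch) are in scope; shapes follow the docstring.
def _example_pose_from_pred():
    import torch

    B = 4
    pred_quats = torch.randn(B, 4)  # unnormalized; the function normalizes them
    pred_transes = torch.randn(B, 3)
    pred_transes[:, 2] = pred_transes[:, 2].abs() + 0.5  # positive depth in front of the camera
    rot_ego, trans = pose_from_predictions_train(pred_quats, pred_transes, is_allo=True)
    assert rot_ego.shape == (B, 3, 3) and trans.shape == (B, 3)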
def render_dib_vc_batch(
    ren, Rs, ts, Ks, obj_ids, models, rot_type="quat", H=480, W=640, near=0.01, far=100.0, with_depth=False
):
    """
    Args:
        ren: A DIB-renderer
        models: All models loaded by load_objs
    """
    assert ren.mode in ["VertexColorBatch"], ren.mode
    bs = len(Rs)
    if len(Ks) == 1:
        Ks = [Ks[0] for _ in range(bs)]
    ren.set_camera_parameters_from_RT_K(Rs, ts, Ks, height=H, width=W, near=near, far=far, rot_type=rot_type)
    colors = [models[_id]["colors"] for _id in obj_ids]  # b x [1, p, 3]
    points = [[models[_id]["vertices"], models[_id]["faces"][0].long()] for _id in obj_ids]
    # points: list of [vertices, faces]
    # colors: list of colors
    predictions, im_probs, _, im_masks = ren.forward(points=points, colors=colors)
    if with_depth:
        # transform xyz
        if not isinstance(Rs, torch.Tensor):
            Rs = torch.stack(Rs)  # list
        if rot_type == "quat":
            R_mats = quat2mat_torch(Rs)
        else:
            R_mats = Rs
        xyzs = [
            transform_pts_Rt_th(models[obj_id]["vertices"][0], R_mats[_id], ts[_id])[None]
            for _id, obj_id in enumerate(obj_ids)
        ]
        ren_xyzs, _, _, _ = ren.forward(points=points, colors=xyzs)
        depth = ren_xyzs[:, :, :, 2]  # bhw
    else:
        depth = None
    # bxhxwx3 rgb, bhw1 prob, bhw1 mask, bhw depth
    return predictions, im_probs, im_masks, depth
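# Minimal sketch of the per-object model dict this renderer reads, inferred from
# the key accesses above ("vertices", "faces", "colors"). The leading batch dim
# and dtypes are assumptions, not the load_objs contract; the renderer itself
# runs on CUDA, hence the default device.
def _make_dummy_vc_model(num_pts=100, num_faces=50, device="cuda"):
    import torch

    return {
        "vertices": torch.rand(1, num_pts, 3, device=device),  # [1, p, 3]
        "faces": torch.randint(0, num_pts, (1, num_faces, 3), device=device),  # [1, f, 3]
        "colors": torch.rand(1, num_pts, 3, device=device),  # [1, p, 3] per-vertex rgb
    }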
def test_mat2quat_torch():
    import numpy as np
    import torch
    from transforms3d.quaternions import axangle2quat, quat2mat

    from core.utils.pose_utils import quat2mat_torch

    # NOTE: mat2quat_batch is assumed to be defined/imported in this module
    axis = np.random.rand(3)
    angle = np.random.rand()  # scalar; transforms3d expects a scalar theta
    # quat = axangle2quat([1, 2, 3], 0.7)
    quat = axangle2quat(axis, angle)
    print("quat:\n", quat)
    mat = quat2mat(quat)
    print("mat:\n", mat)
    mat_th = torch.tensor(mat.astype("float32"))[None].to("cuda")
    print("mat_th:\n", mat_th)
    quat_th = mat2quat_batch(mat_th)
    print("quat_th:\n", quat_th)
    mat_2 = quat2mat_torch(quat_th)
    print("mat_2:\n", mat_2)
    diff_mat = mat_th - mat_2
    print("mat_diff:\n", diff_mat)
    diff_quat = quat - quat_th.cpu().numpy()
    print("diff_quat:\n", diff_quat)
def render_dib_tex_batch(
    ren, Rs, ts, Ks, obj_ids, models, rot_type="quat", H=480, W=640, near=0.01, far=100.0, with_depth=False
):
    assert ren.mode in ["TextureBatch"], ren.mode
    bs = len(Rs)
    if len(Ks) == 1:
        Ks = [Ks[0] for _ in range(bs)]
    ren.set_camera_parameters_from_RT_K(Rs, ts, Ks, height=H, width=W, near=near, far=far, rot_type=rot_type)
    # points: list of [vertices, faces]
    points = [[models[_id]["vertices"], models[_id]["faces"][0].long()] for _id in obj_ids]
    uv_bxpx2 = [models[_id]["uvs"] for _id in obj_ids]
    texture_bx3xthxtw = [models[_id]["texture"] for _id in obj_ids]
    ft_fx3_list = [models[_id]["face_textures"][0] for _id in obj_ids]
    dib_ren_im, dib_ren_prob, _, dib_ren_mask = ren.forward(
        points=points, uv_bxpx2=uv_bxpx2, texture_bx3xthxtw=texture_bx3xthxtw, ft_fx3=ft_fx3_list
    )
    if with_depth:
        # transform xyz
        if not isinstance(Rs, torch.Tensor):
            Rs = torch.stack(Rs)  # list
        if rot_type == "quat":
            R_mats = quat2mat_torch(Rs)
        else:
            R_mats = Rs
        xyzs = [
            transform_pts_Rt_th(models[obj_id]["vertices"][0], R_mats[_id], ts[_id])[None]
            for _id, obj_id in enumerate(obj_ids)
        ]
        # render xyz maps with a vertex-color renderer sharing the same camera params
        dib_ren_vc_batch = DIBRenderer(height=H, width=W, mode="VertexColorBatch")
        dib_ren_vc_batch.set_camera_parameters(ren.camera_params)
        ren_xyzs, _, _, _ = dib_ren_vc_batch.forward(points=points, colors=xyzs)
        depth = ren_xyzs[:, :, :, 2]  # bhw
    else:
        depth = None
    # bxhxwx3 rgb, bhw1 prob/mask, bhw depth
    return dib_ren_im, dib_ren_prob, dib_ren_mask, depth
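# Companion sketch to _make_dummy_vc_model above: the extra keys the textured
# path reads ("uvs", "texture", "face_textures"). Shapes follow the argument
# names in the call (uv_bxpx2, texture_bx3xthxtw, ft_fx3) and are otherwise
# assumptions.
def _make_dummy_tex_model(num_pts=100, num_faces=50, tex_hw=(64, 64), device="cuda"):
    import torch

    model = _make_dummy_vc_model(num_pts, num_faces, device)
    th, tw = tex_hw
    model["uvs"] = torch.rand(1, num_pts, 2, device=device)  # [1, p, 2]
    model["texture"] = torch.rand(1, 3, th, tw, device=device)  # [1, 3, th, tw]
    model["face_textures"] = torch.randint(0, num_pts, (1, num_faces, 3), device=device)  # [1, f, 3]
    return model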
def pose_from_predictions_train(pred_rots, pred_centroids, pred_z_vals, roi_cams, eps=1e-4, is_allo=True):
    """for train
    Args:
        pred_rots: [B, 4] quaternions or [B, 3, 3] rotation matrices
        pred_centroids: [B, 2] predicted 2d centroids in absolute image coords
        pred_z_vals: [B, 1] predicted absolute depths
        roi_cams: absolute cams
        eps: small value to avoid division by zero
        is_allo: whether pred_rots are allocentric rotations

    Returns:
        rot_ego, translation
    """
    if roi_cams.dim() == 2:
        roi_cams.unsqueeze_(0)
    assert roi_cams.dim() == 3, roi_cams.dim()
    # absolute coords
    cx = pred_centroids[:, 0:1]  # [#roi, 1]
    cy = pred_centroids[:, 1:2]  # [#roi, 1]
    z = pred_z_vals

    # backproject regressed centroid with regressed z
    """
    fx * tx + px * tz = z * cx
    fy * ty + py * tz = z * cy
    tz = z
    ==>
    fx * tx / tz = cx - px
    fy * ty / tz = cy - py
    ==>
    tx = (cx - px) * tz / fx
    ty = (cy - py) * tz / fy
    """
    # NOTE: z must be [B, 1]
    translation = torch.cat(
        [
            z * (cx - roi_cams[:, 0:1, 2]) / roi_cams[:, 0:1, 0],
            z * (cy - roi_cams[:, 1:2, 2]) / roi_cams[:, 1:2, 1],
            z,
        ],
        dim=1,
    )

    if pred_rots.ndim == 2 and pred_rots.shape[-1] == 4:
        pred_quats = pred_rots
        quat_allo = pred_quats / (torch.norm(pred_quats, dim=1, keepdim=True) + eps)
        if is_allo:
            quat_ego = allocentric_to_egocentric_torch(translation, quat_allo, eps=eps)
        else:
            quat_ego = quat_allo
        rot_ego = quat2mat_torch(quat_ego)
    if pred_rots.ndim == 3 and pred_rots.shape[-1] == 3:  # Nx3x3
        if is_allo:
            rot_ego = allo_to_ego_mat_torch(translation, pred_rots, eps=eps)
        else:
            rot_ego = pred_rots
    return rot_ego, translation
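# Quick numeric check of the backprojection derivation above (a sketch, pure
# torch; the intrinsics are illustrative): project a known translation with a
# pinhole K, then recover it from the projected centroid (cx, cy) and z.
def _check_backproject():
    import torch

    K = torch.tensor([[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]])
    t = torch.tensor([0.1, -0.05, 0.7])
    uvw = K @ t
    cx, cy = uvw[0] / uvw[2], uvw[1] / uvw[2]  # projected 2d centroid
    z = t[2]
    tx = (cx - K[0, 2]) * z / K[0, 0]
    ty = (cy - K[1, 2]) * z / K[1, 1]
    assert torch.allclose(torch.stack([tx, ty, z]), t, atol=1e-5)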
def set_camera_parameters_from_RT_K(self, Rs, ts, Ks, height, width, near=0.01, far=10.0, rot_type="mat"):
    """
    Rs: a list of rotations tensor
    ts: a list of translations tensor
    Ks: a list of camera intrinsic matrices or a single matrix
    ----
    [cam_view_R, cam_view_pos, cam_proj]
    """
    """
    aspect_ratio = width / height
    fov_x, fov_y = K_to_fov(K, height, width)
    # camera_projection_mtx = perspectiveprojectionnp(self.camera_fov_y,
    #                                                 ratio=aspect_ratio, near=near, far=far)
    camera_projection_mtx = perspectiveprojectionnp(fov_y, ratio=aspect_ratio, near=near, far=far)
    """
    assert rot_type in ["mat", "quat"], rot_type
    bs = len(Rs)

    # NOTE: treat Ks as a single shared K only when it is one 2d matrix;
    # a batched ndarray/tensor of Ks falls through to the per-item loop below
    single_K = isinstance(Ks, (np.ndarray, torch.Tensor)) and Ks.ndim == 2
    if single_K:
        camera_proj_mtx = projectiveprojection_real(Ks, 0, 0, width, height, near, far)
        camera_proj_mtx = torch.as_tensor(camera_proj_mtx).float().cuda()  # 4x4

    camera_view_mtx = []
    camera_view_shift = []
    if not single_K:
        camera_proj_mtx = []
    for i in range(bs):
        R = Rs[i]
        t = ts[i]
        if not isinstance(R, torch.Tensor):
            R = torch.tensor(R, dtype=torch.float32, device="cuda:0")
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, dtype=torch.float32, device="cuda:0")
        if rot_type == "quat":
            R = quat2mat_torch(R.unsqueeze(0))[0]
        cam_view_R = torch.matmul(self.yz_flip.to(R), R)
        cam_view_t = -torch.matmul(R.t(), t)  # cam pos
        camera_view_mtx.append(cam_view_R)
        camera_view_shift.append(cam_view_t)
        if not single_K:
            K = Ks[i]
            cam_proj_mtx = projectiveprojection_real(K, 0, 0, width, height, near, far)
            cam_proj_mtx = torch.tensor(cam_proj_mtx).float().cuda()  # 4x4
            camera_proj_mtx.append(cam_proj_mtx)
    camera_view_mtx = torch.stack(camera_view_mtx).cuda()  # bx3x3
    camera_view_shift = torch.stack(camera_view_shift).cuda()  # bx3
    if not single_K:
        camera_proj_mtx = torch.stack(camera_proj_mtx)  # bx3x1 or bx4x4
    # print("camera view matrix: \n", camera_view_mtx, camera_view_mtx.shape)  # bx3x3, camera rot?
    # print("camera view shift: \n", camera_view_shift, camera_view_shift.shape)  # bx3, camera trans?
    # print("camera projection mat: \n", camera_proj_mtx, camera_proj_mtx.shape)  # projection matrix, 3x1
    self.camera_params = [camera_view_mtx, camera_view_shift, camera_proj_mtx]
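# Sketch of the view-matrix construction above for an identity pose, assuming
# yz_flip = diag(1, -1, -1) (the usual OpenCV-to-GL axis flip; check the class
# definition for the actual buffer). The camera position is -R^T t, i.e. the
# camera center expressed in model/world coordinates.
def _example_view_matrix():
    import torch

    yz_flip = torch.diag(torch.tensor([1.0, -1.0, -1.0]))
    R = torch.eye(3)
    t = torch.tensor([0.0, 0.0, 0.6])
    cam_view_R = yz_flip @ R  # flip y/z axes to match GL conventions
    cam_view_pos = -R.t() @ t  # camera center: [0, 0, -0.6]
    print(cam_view_R, cam_view_pos)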
def forward(
    self,
    x,
    gt_xyz=None,
    gt_xyz_bin=None,
    gt_mask_trunc=None,
    gt_mask_visib=None,
    gt_mask_obj=None,
    gt_region=None,
    gt_allo_quat=None,
    gt_ego_quat=None,
    gt_allo_rot6d=None,
    gt_ego_rot6d=None,
    gt_ego_rot=None,
    gt_points=None,
    sym_infos=None,
    gt_trans=None,
    gt_trans_ratio=None,
    roi_classes=None,
    roi_coord_2d=None,
    roi_cams=None,
    roi_centers=None,
    roi_whs=None,
    roi_extents=None,
    resize_ratios=None,
    do_loss=False,
):
    cfg = self.cfg
    r_head_cfg = cfg.MODEL.CDPN.ROT_HEAD
    t_head_cfg = cfg.MODEL.CDPN.TRANS_HEAD
    pnp_net_cfg = cfg.MODEL.CDPN.PNP_NET

    # x.shape [bs, 3, 256, 256]
    if self.concat:
        features, x_f64, x_f32, x_f16 = self.backbone(x)  # features.shape [bs, 2048, 8, 8]
        # joints.shape [bs, 1152, 64, 64]
        mask, coor_x, coor_y, coor_z, region = self.rot_head_net(features, x_f64, x_f32, x_f16)
    else:
        features = self.backbone(x)  # features.shape [bs, 2048, 8, 8]
        # joints.shape [bs, 1152, 64, 64]
        mask, coor_x, coor_y, coor_z, region = self.rot_head_net(features)

    # TODO: remove this trans_head_net
    # trans = self.trans_head_net(features)

    device = x.device
    bs = x.shape[0]
    num_classes = r_head_cfg.NUM_CLASSES
    out_res = cfg.MODEL.CDPN.BACKBONE.OUTPUT_RES

    if r_head_cfg.ROT_CLASS_AWARE:
        assert roi_classes is not None
        coor_x = coor_x.view(bs, num_classes, self.r_out_dim // 3, out_res, out_res)
        coor_x = coor_x[torch.arange(bs).to(device), roi_classes]
        coor_y = coor_y.view(bs, num_classes, self.r_out_dim // 3, out_res, out_res)
        coor_y = coor_y[torch.arange(bs).to(device), roi_classes]
        coor_z = coor_z.view(bs, num_classes, self.r_out_dim // 3, out_res, out_res)
        coor_z = coor_z[torch.arange(bs).to(device), roi_classes]

    if r_head_cfg.MASK_CLASS_AWARE:
        assert roi_classes is not None
        mask = mask.view(bs, num_classes, self.mask_out_dim, out_res, out_res)
        mask = mask[torch.arange(bs).to(device), roi_classes]

    if r_head_cfg.REGION_CLASS_AWARE:
        assert roi_classes is not None
        region = region.view(bs, num_classes, self.region_out_dim, out_res, out_res)
        region = region[torch.arange(bs).to(device), roi_classes]

    # -----------------------------------------------
    # get rot and trans from pnp_net
    # NOTE: use softmax for bins (the last dim is bg)
    if coor_x.shape[1] > 1 and coor_y.shape[1] > 1 and coor_z.shape[1] > 1:
        coor_x_softmax = F.softmax(coor_x[:, :-1, :, :], dim=1)
        coor_y_softmax = F.softmax(coor_y[:, :-1, :, :], dim=1)
        coor_z_softmax = F.softmax(coor_z[:, :-1, :, :], dim=1)
        coor_feat = torch.cat([coor_x_softmax, coor_y_softmax, coor_z_softmax], dim=1)
    else:
        coor_feat = torch.cat([coor_x, coor_y, coor_z], dim=1)  # BCHW

    if pnp_net_cfg.WITH_2D_COORD:
        assert roi_coord_2d is not None
        coor_feat = torch.cat([coor_feat, roi_coord_2d], dim=1)

    # NOTE: for region, the 1st dim is bg
    region_softmax = F.softmax(region[:, 1:, :, :], dim=1)

    mask_atten = None
    if pnp_net_cfg.MASK_ATTENTION != "none":
        mask_atten = get_mask_prob(cfg, mask)

    region_atten = None
    if pnp_net_cfg.REGION_ATTENTION:
        region_atten = region_softmax

    pred_rot_, pred_t_ = self.pnp_net(coor_feat, region=region_atten, extents=roi_extents, mask_attention=mask_atten)

    if pnp_net_cfg.R_ONLY:  # override trans pred
        pred_t_ = self.trans_head_net(features)

    # convert pred_rot to rot mat -------------------------
    rot_type = pnp_net_cfg.ROT_TYPE
    if rot_type in ["ego_quat", "allo_quat"]:
        pred_rot_m = quat2mat_torch(pred_rot_)
    elif rot_type in ["ego_log_quat", "allo_log_quat"]:
        pred_rot_m = quat2mat_torch(quaternion_lf.qexp(pred_rot_))
    elif rot_type in ["ego_lie_vec", "allo_lie_vec"]:
        pred_rot_m = lie_algebra.lie_vec_to_rot(pred_rot_)
    elif rot_type in ["ego_rot6d", "allo_rot6d"]:
        pred_rot_m = ortho6d_to_mat_batch(pred_rot_)
    else:
        raise ValueError(f"Unknown rot_type: {rot_type}")

    # convert pred_rot_m and pred_t to ego pose -----------------------------
    if pnp_net_cfg.TRANS_TYPE == "centroid_z":
        pred_ego_rot, pred_trans = pose_from_pred_centroid_z(
            pred_rot_m,
            pred_centroids=pred_t_[:, :2],
            pred_z_vals=pred_t_[:, 2:3],  # must be [B, 1]
            roi_cams=roi_cams,
            roi_centers=roi_centers,
            resize_ratios=resize_ratios,
            roi_whs=roi_whs,
            eps=1e-4,
            is_allo="allo" in pnp_net_cfg.ROT_TYPE,
            z_type=pnp_net_cfg.Z_TYPE,
            # is_train=True
            is_train=do_loss,  # TODO: sometimes we need it to be differentiable during test
        )
    elif pnp_net_cfg.TRANS_TYPE == "centroid_z_abs":
        # abs 2d obj center and abs z
        pred_ego_rot, pred_trans = pose_from_pred_centroid_z_abs(
            pred_rot_m,
            pred_centroids=pred_t_[:, :2],
            pred_z_vals=pred_t_[:, 2:3],  # must be [B, 1]
            roi_cams=roi_cams,
            eps=1e-4,
            is_allo="allo" in pnp_net_cfg.ROT_TYPE,
            # is_train=True
            is_train=do_loss,  # TODO: sometimes we need it to be differentiable during test
        )
    elif pnp_net_cfg.TRANS_TYPE == "trans":
        # TODO: maybe denormalize trans
        pred_ego_rot, pred_trans = pose_from_pred(
            pred_rot_m, pred_t_, eps=1e-4, is_allo="allo" in pnp_net_cfg.ROT_TYPE, is_train=do_loss
        )
    else:
        raise ValueError(f"Unknown pnp_net trans type: {pnp_net_cfg.TRANS_TYPE}")

    if not do_loss:  # test
        out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
        if cfg.TEST.USE_PNP:
            # TODO: move the pnp/ransac inside forward
            out_dict.update({"mask": mask, "coor_x": coor_x, "coor_y": coor_y, "coor_z": coor_z, "region": region})
    else:
        out_dict = {}
        assert (
            (gt_xyz is not None)
            and (gt_trans is not None)
            and (gt_trans_ratio is not None)
            and (gt_region is not None)
        )
        mean_re, mean_te = compute_mean_re_te(pred_trans, pred_rot_m, gt_trans, gt_ego_rot)
        vis_dict = {
            "vis/error_R": mean_re,
            "vis/error_t": mean_te * 100,  # cm
            "vis/error_tx": np.abs(pred_trans[0, 0].detach().item() - gt_trans[0, 0].detach().item()) * 100,  # cm
            "vis/error_ty": np.abs(pred_trans[0, 1].detach().item() - gt_trans[0, 1].detach().item()) * 100,  # cm
            "vis/error_tz": np.abs(pred_trans[0, 2].detach().item() - gt_trans[0, 2].detach().item()) * 100,  # cm
            "vis/tx_pred": pred_trans[0, 0].detach().item(),
            "vis/ty_pred": pred_trans[0, 1].detach().item(),
            "vis/tz_pred": pred_trans[0, 2].detach().item(),
            "vis/tx_net": pred_t_[0, 0].detach().item(),
            "vis/ty_net": pred_t_[0, 1].detach().item(),
            "vis/tz_net": pred_t_[0, 2].detach().item(),
            "vis/tx_gt": gt_trans[0, 0].detach().item(),
            "vis/ty_gt": gt_trans[0, 1].detach().item(),
            "vis/tz_gt": gt_trans[0, 2].detach().item(),
            "vis/tx_rel_gt": gt_trans_ratio[0, 0].detach().item(),
            "vis/ty_rel_gt": gt_trans_ratio[0, 1].detach().item(),
            "vis/tz_rel_gt": gt_trans_ratio[0, 2].detach().item(),
        }

        loss_dict = self.gdrn_loss(
            cfg=self.cfg,
            out_mask=mask,
            gt_mask_trunc=gt_mask_trunc,
            gt_mask_visib=gt_mask_visib,
            gt_mask_obj=gt_mask_obj,
            out_x=coor_x,
            out_y=coor_y,
            out_z=coor_z,
            gt_xyz=gt_xyz,
            gt_xyz_bin=gt_xyz_bin,
            out_region=region,
            gt_region=gt_region,
            out_trans=pred_trans,
            gt_trans=gt_trans,
            out_rot=pred_ego_rot,
            gt_rot=gt_ego_rot,
            out_centroid=pred_t_[:, :2],  # TODO: get these from trans head
            out_trans_z=pred_t_[:, 2],
            gt_trans_ratio=gt_trans_ratio,
            gt_points=gt_points,
            sym_infos=sym_infos,
            extents=roi_extents,
            # roi_classes=roi_classes,
        )

        if cfg.MODEL.CDPN.USE_MTL:
            for _name in self.loss_names:
                if f"loss_{_name}" in loss_dict:
                    vis_dict[f"vis_lw/{_name}"] = torch.exp(-getattr(self, f"log_var_{_name}")).detach().item()
        for _k, _v in vis_dict.items():
            if "vis/" in _k or "vis_lw/" in _k:
                if isinstance(_v, torch.Tensor):
                    _v = _v.item()
                vis_dict[_k] = _v
        storage = get_event_storage()
        storage.put_scalars(**vis_dict)

        return out_dict, loss_dict
    return out_dict
def pose_from_predictions_train(
    pred_rots,
    pred_centroids,
    pred_z_vals,
    roi_cams,
    roi_centers,
    resize_ratios,
    roi_whs,
    eps=1e-4,
    is_allo=True,
    z_type="REL",
):
    """for train
    Args:
        pred_rots: [B, 4] quaternions or [B, 3, 3] rotation matrices
        pred_centroids: [B, 2] centroids relative to the roi (unnormalized below)
        pred_z_vals: [B, 1]
        roi_cams: absolute cams
        roi_centers: [B, 2] roi centers in absolute image coords
        resize_ratios: used to unnormalize z when z_type is REL
        roi_whs: (bw, bh) for bboxes
        eps: small value to avoid division by zero
        is_allo: whether pred_rots are allocentric rotations
        z_type: REL | ABS (LOG and NEG_LOG are not handled here)

    Returns:
        rot_ego, translation
    """
    if roi_cams.dim() == 2:
        roi_cams.unsqueeze_(0)
    assert roi_cams.dim() == 3, roi_cams.dim()
    # absolute coords
    c = torch.stack(
        [
            (pred_centroids[:, 0] * roi_whs[:, 0]) + roi_centers[:, 0],
            (pred_centroids[:, 1] * roi_whs[:, 1]) + roi_centers[:, 1],
        ],
        dim=1,
    )
    cx = c[:, 0:1]  # [#roi, 1]
    cy = c[:, 1:2]  # [#roi, 1]

    # unnormalize regressed z
    if z_type == "ABS":
        z = pred_z_vals
    elif z_type == "REL":
        # z_1 / z_2 = s_2 / s_1 ==> z_1 = s_2 / s_1 * z_2
        z = pred_z_vals * resize_ratios.view(-1, 1)
    else:
        raise ValueError(f"Unknown z_type: {z_type}")

    # backproject regressed centroid with regressed z
    """
    fx * tx + px * tz = z * cx
    fy * ty + py * tz = z * cy
    tz = z
    ==>
    fx * tx / tz = cx - px
    fy * ty / tz = cy - py
    ==>
    tx = (cx - px) * tz / fx
    ty = (cy - py) * tz / fy
    """
    # NOTE: z must be [B, 1]
    translation = torch.cat(
        [
            z * (cx - roi_cams[:, 0:1, 2]) / roi_cams[:, 0:1, 0],
            z * (cy - roi_cams[:, 1:2, 2]) / roi_cams[:, 1:2, 1],
            z,
        ],
        dim=1,
    )

    if pred_rots.ndim == 2 and pred_rots.shape[-1] == 4:
        pred_quats = pred_rots
        quat_allo = pred_quats / (torch.norm(pred_quats, dim=1, keepdim=True) + eps)
        if is_allo:
            quat_ego = allocentric_to_egocentric_torch(translation, quat_allo, eps=eps)
        else:
            quat_ego = quat_allo
        rot_ego = quat2mat_torch(quat_ego)
    if pred_rots.ndim == 3 and pred_rots.shape[-1] == 3:  # Nx3x3
        if is_allo:
            rot_ego = allo_to_ego_mat_torch(translation, pred_rots, eps=eps)
        else:
            rot_ego = pred_rots
    return rot_ego, translation
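# Usage sketch for the relative-z path of the variant defined directly above:
# with z_type="REL" the network output is interpreted as z scaled by the crop
# resize ratio, so it is multiplied back by resize_ratios. All values here are
# dummy/illustrative, including the intrinsics.
def _example_centroid_z():
    import torch

    B = 2
    K = torch.tensor([[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]])
    roi_cams = K[None].repeat(B, 1, 1)
    pred_rots = torch.randn(B, 4)  # quats; normalized inside the function
    pred_centroids = torch.rand(B, 2) - 0.5  # offsets relative to the roi, in box units
    pred_z_vals = torch.rand(B, 1) + 0.5  # [B, 1]
    roi_centers = torch.tensor([[320.0, 240.0], [100.0, 80.0]])
    resize_ratios = torch.tensor([0.5, 0.8])
    roi_whs = torch.tensor([[128.0, 128.0], [64.0, 96.0]])
    rot_ego, trans = pose_from_predictions_train(
        pred_rots, pred_centroids, pred_z_vals, roi_cams, roi_centers, resize_ratios, roi_whs, z_type="REL"
    )
    assert rot_ego.shape == (B, 3, 3) and trans.shape == (B, 3)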