def compute_geo_costs(rot, trans, Ex, Kinv, hp0, hp1, tau, Kinv_n=None):
    """Compute several geometric cost maps for point correspondences under a relative pose.

    Args:
        rot: axis-angle rotation, converted with ``kornia.angle_axis_to_rotation_matrix``;
            presumably shape (B, 3) -- TODO confirm.
        trans: translation, indexed as ``trans[:, :, np.newaxis]`` (so rank 2);
            presumably (B, 3).
        Ex: matrix handed (transposed over its last two axes, moved to CUDA) to
            ``sampson_err``; presumably the essential matrix (B, 3, 3) -- verify.
        Kinv: batched 3x3 matrix whose inverse maps rays back to pixels, so presumably
            the inverse intrinsics of view 0.
        hp0, hp1: homogeneous points, rearranged with ``permute(0, 2, 1)``;
            presumably (B, N, 3) for views 0 and 1 -- TODO confirm.
        tau: scale applied to ``hp0`` before it is subtracted from the
            rotation-compensated ``hp1``.
        Kinv_n: matrix used instead of ``Kinv`` for view 1; defaults to ``Kinv``.

    Returns:
        Tuple ``(mcost00, mcost01, mcost1, mcost2, mcost3, mcost4, p3dmag, mcost10)``:
        - mcost00: L2 norm of the 2D residual between de-homogenized ``H01 @ hp1``
          and the first two coordinates of ``hp0``.
        - mcost01: the same residual in the reverse direction (``H01^-1 @ hp0`` vs ``hp1``).
        - mcost1: output of ``sampson_err`` on the ``Kinv``/``Kinv_n``-normalized points.
        - mcost2: negated dot product of the unit 2D direction towards the projected
          translation and the unit 2D parallax, multiplied by the sign of ``trans``'s
          last component.
        - mcost3: ``sqrt(clamp(1 - mcost4**2, 0, 1)) * p3dmag * sign(mcost4)``.
        - mcost4: negated dot product of ``trans`` with the 3D parallax, divided by the
          parallax magnitude only (a true cosine only if ``trans`` has unit norm).
        - p3dmag: magnitude of the 3D parallax, keeping a singleton axis 1.
        - mcost10: 2D analogue of mcost3 built from mcost2 and the 2D parallax magnitude.
    """
    if Kinv_n is None:
        Kinv_n = Kinv
    R01 = kornia.angle_axis_to_rotation_matrix(rot)
    # Rotation-only homography: Kinv.inverse() @ R @ Kinv_n.
    H01 = Kinv.inverse().matmul(R01).matmul(Kinv_n)
    comp_hp1 = H01.matmul(hp1.permute(0, 2, 1))
    foe = (comp_hp1 - tau * hp0.permute(0, 2, 1))
    parallax3d = Kinv.matmul(foe)
    p3dmag = parallax3d.norm(2, 1)[:, np.newaxis]
    # 2D parallax: de-homogenize the compensated points (divide by last coord), keep x/y.
    parallax2d = (comp_hp1 / comp_hp1[:, -1:] - hp0.permute(0, 2, 1))[:, :2]
    p2dmag = parallax2d.norm(2, 1)[:, np.newaxis]
    p2dnorm = parallax2d / (1e-9 + p2dmag)
    # Project the translation through Kinv.inverse() and de-homogenize it; the 1e-9
    # terms guard against division by zero.
    foe_cam = Kinv.inverse().matmul(trans[:, :, np.newaxis])
    foe_cam = foe_cam[:, :2] / (1e-9 + foe_cam[:, -1:])
    direct = foe_cam - hp0.permute(0, 2, 1)[:, :2]
    directn = direct / (1e-9 + direct.norm(2, 1)[:, np.newaxis])

    # cost metrics: 0) R-homography+symterr; 1) sampson 2) 2D angular (P+P) 3) 3D distance 4) 3D angular (P+P)
    ##TODO validate
    comp_hp0 = H01.inverse().matmul(hp0.permute(0, 2, 1))
    mcost00 = parallax2d.norm(2, 1)
    mcost01 = (comp_hp0 / comp_hp0[:, -1:] - hp1.permute(0, 2, 1))[:, :2].norm(2, 1)
    # NOTE(review): Ex.cuda() hard-codes the default CUDA device; this fails or
    # mismatches devices when the other inputs live on CPU or another GPU.
    mcost1 = sampson_err(Kinv.matmul(hp0.permute(0, 2, 1)),
                         Kinv_n.matmul(hp1.permute(0, 2, 1)),
                         Ex.cuda().permute(0, 2, 1))  # variable K
    mcost2 = -(trans[:, -1:, np.newaxis]).sign() * (directn * p2dnorm).sum(1, keepdims=True)
    mcost4 = -(trans[:, :, np.newaxis] * parallax3d).sum(1, keepdims=True) / (p3dmag + 1e-9)
    # sqrt(1 - c^2) is the sine of the angle whose cosine is c (clamped into [0, 1]).
    mcost3 = torch.clamp(1 - mcost4.pow(2), 0, 1).sqrt() * p3dmag * mcost4.sign()
    mcost10 = torch.clamp(1 - mcost2.pow(2), 0, 1).sqrt() * p2dmag * mcost2.sign()
    return mcost00, mcost01, mcost1, mcost2, mcost3, mcost4, p3dmag, mcost10
def generate_scene(num_views: int, num_points: int) -> Dict[str, torch.Tensor]:
    """Create a random synthetic multi-view scene.

    All views share one random intrinsics matrix and one random point cloud. Each
    view gets its own random rotation and translation. The z-translation is pushed
    forward when needed so that every point has non-negative depth.

    Args:
        num_views: number of cameras N.
        num_points: number of 3D points M.

    Returns:
        Dict with keys ``K`` (1x3x3), ``R`` (Nx3x3), ``t`` (Nx3x1), ``P``
        (projection matrices), ``points3d`` (1xMx3) and ``points2d`` (projected points).
    """
    # Shared point cloud in the unit cube.
    cloud = torch.rand(1, num_points, 3)

    # Shared camera intrinsics.
    intrinsics = epipolar.random_intrinsics(0.0, 100.0)

    # Per-view axis-angle: random non-negative direction scaled to an angle in [0, 2*pi).
    theta = torch.rand(num_views, 1) * kornia.pi * 2.0
    axis = torch.rand(num_views, 3)
    axis_angle = theta * axis / torch.norm(axis, dim=1, keepdim=True)
    rotations = kornia.angle_axis_to_rotation_matrix(axis_angle)  # Nx3x3, agrees with cv2.Rodrigues

    # Per-view translation, sampled component by component.
    t_x = torch.empty(num_views).uniform_(-0.5, 0.5)
    t_y = torch.empty(num_views).uniform_(-0.5, 0.5)
    t_z = torch.empty(num_views).uniform_(-1.0, 2.0)
    translations = torch.stack([t_x, t_y, t_z], dim=1).unsqueeze(-1)  # Nx3x1

    # Push each camera back so that the closest point is in front of it.
    cloud_in_cam = (rotations @ cloud.transpose(-2, -1)) + translations
    nearest_depth = cloud_in_cam[:, 2].min(dim=1).values
    translations[:, 2, 0] = torch.where(nearest_depth < 0, t_z - nearest_depth + 1.0, t_z)

    projections = epipolar.projection_from_KRt(intrinsics, rotations, translations)
    image_points = kornia.transform_points(projections, cloud.expand(num_views, -1, -1))

    return {
        "K": intrinsics,
        "R": rotations,
        "t": translations,
        "P": projections,
        "points3d": cloud,
        "points2d": image_points,
    }
def get_projective_transform(center: torch.Tensor, angles: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    r"""Calculate the projection matrix for a scaled 3D rotation about a center.

    .. warning::
        This API signature is experimental and might change in the future.

    The function computes the projection matrix given the rotation center, the
    rotation angles per axis and the per-axis scale factors.

    Args:
        center: center of the rotation (x,y,z) in the source with shape :math:`(B, 3)`.
        angles: angle axis vector containing the rotation angles in degrees in the form
            of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
            the rotation matrix from axis-angle.
        scales: scale factor for x-y-z-directions with shape :math:`(B, 3)`.

    Returns:
        the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.

    Raises:
        AssertionError: if ``center`` or ``angles`` is not of shape :math:`(B, 3)`, or if
            they differ in device or dtype.

    .. note::
        This function is often used in conjunction with :func:`warp_affine3d`.
    """
    if not (len(center.shape) == 2 and center.shape[-1] == 3):
        raise AssertionError(center.shape)
    if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
        raise AssertionError(angles.shape)
    if center.device != angles.device:
        raise AssertionError(center.device, angles.device)
    if center.dtype != angles.dtype:
        raise AssertionError(center.dtype, angles.dtype)

    # create rotation matrix, then right-multiply by diag(scales)
    angle_axis_rad: torch.Tensor = K.deg2rad(angles)
    rmat: torch.Tensor = K.angle_axis_to_rotation_matrix(angle_axis_rad)  # Bx3x3
    scaling_matrix: torch.Tensor = K.eye_like(3, rmat)
    scaling_matrix = scaling_matrix * scales.unsqueeze(dim=1)
    rmat = rmat @ scaling_matrix.to(rmat)

    # define matrix to move forth and back to origin; build it directly on the
    # center's device/dtype instead of allocating on CPU and converting afterwards
    from_origin_mat = torch.eye(4, device=center.device, dtype=center.dtype)[None].repeat(
        rmat.shape[0], 1, 1)  # Bx4x4
    from_origin_mat[..., :3, -1] += center
    to_origin_mat = _torch_inverse_cast(from_origin_mat)

    # append translation with zeros
    proj_mat = projection_from_Rt(rmat, torch.zeros_like(center)[..., None])  # Bx3x4

    # chain 4x4 transforms: translate to origin, rotate/scale, translate back
    proj_mat = convert_affinematrix_to_homography3d(proj_mat)  # Bx4x4
    proj_mat = from_origin_mat @ proj_mat @ to_origin_mat

    return proj_mat[..., :3, :]  # Bx3x4
def aa2matrot(pose):
    '''Convert axis-angle joint rotations into flattened rotation matrices.

    :param pose: axis-angle tensor whose last dimension is 3; documented as
        Nx1xnum_jointsx3 (any shape reshapeable to (-1, 3) is converted).
    :return: pose_matrot: Nx1xnum_jointsx9, each 3x3 matrix flattened row-major.
    '''
    n_batch = pose.size(0)
    flat_axis_angle = pose.reshape(-1, 3)
    mats = kornia.angle_axis_to_rotation_matrix(flat_axis_angle)[:, :3, :3]
    return mats.contiguous().view(n_batch, 1, -1, 9)
def getDiffMap_tensor(hyp, objectCoordinates, sampling, camMat):
    """Compute a clamped per-pixel reprojection-error map for a pose hypothesis.

    Args:
        hyp: pose hypothesis; ``hyp[0, :3]`` is an axis-angle rotation and
            ``hyp[0, 3:6]`` the translation.
        objectCoordinates: object-space points indexed as ``[:, :, xyz]``
            (presumably HxWx3 -- confirm against callers).
        sampling: pixel positions indexed as ``[:, :, xy]``.
        camMat: pinhole camera matrix; ``[0, 0]`` is the focal length and
            ``[0, 2]`` / ``[1, 2]`` the principal point.

    Returns:
        Tensor of shape ``(1, *objectCoordinates.shape[:2])`` with the Euclidean
        reprojection error per pixel, clamped to [0, 100].
    """
    rot = kornia.angle_axis_to_rotation_matrix(hyp[0, :3].view(1, -1)).view([3, 3]).float()
    focal = camMat[0, 0]
    center_x = camMat[0, 2]
    center_y = camMat[1, 2]

    # Rotate each object point (row vector times R^T) and then translate it per axis.
    cam_pts = torch.matmul(objectCoordinates, rot.t())
    for axis in range(3):
        cam_pts[:, :, axis] += hyp[0, 3 + axis]

    # Pinhole projection, followed by the distance to the sampled pixel.
    depth = cam_pts[:, :, 2]
    proj_x = focal * cam_pts[:, :, 0] / depth + center_x
    proj_y = focal * cam_pts[:, :, 1] / depth + center_y
    err_x = proj_x - sampling[:, :, 0]
    err_y = proj_y - sampling[:, :, 1]
    distance = torch.sqrt(err_x**2 + err_y**2)

    return torch.unsqueeze(torch.clamp(distance, 0., 100.), 0)
def get_skew_mat(transx, rotx):
    """Return ``R^T @ [t]_x`` for a batch of axis-angle rotations and translations.

    Args:
        transx: translations, indexed as ``transx[:, k]`` for k in 0..2, i.e. (B, 3).
        rotx: axis-angle rotations converted with
            ``kornia.angle_axis_to_rotation_matrix``; presumably (B, 3).

    Returns:
        (B, 3, 3) tensor: the transpose of each rotation matrix times the
        skew-symmetric (cross-product) matrix of the matching translation.
    """
    rot_t = kornia.angle_axis_to_rotation_matrix(rotx).permute(0, 2, 1)

    # Allocate the skew matrix on the rotation's device and dtype. A bare
    # torch.zeros(...) is always a default-dtype CPU tensor, which made the final
    # matmul fail for CUDA or float64 rotations.
    tx = torch.zeros(transx.shape[0], 3, 3, dtype=rot_t.dtype, device=rot_t.device)
    tx[:, 0, 1] = -transx[:, 2]
    tx[:, 0, 2] = transx[:, 1]
    tx[:, 1, 0] = transx[:, 2]
    tx[:, 1, 2] = -transx[:, 0]
    tx[:, 2, 0] = -transx[:, 1]
    tx[:, 2, 1] = transx[:, 0]
    return rot_t.matmul(tx)
def get_res(pts2d, pts3d, K, P):
    """Sum of reprojection-error norms for the first pose in ``P``, plus a feasibility flag.

    ``P[0, 0:3]`` is an axis-angle rotation and ``P[0, 3:6]`` a translation.

    Returns:
        (total_error, feasible): ``total_error`` is the Python float sum of per-point
        L2 reprojection errors. ``feasible`` is True when the translation's z component
        is positive and every transformed point has non-negative depth.
    """
    num_pts = pts2d.size(0)
    pose = P[0]

    z_translation_positive = pose[5].item() > 0

    rot = kn.angle_axis_to_rotation_matrix(pose[0:3].view(1, 3))
    extrinsic = torch.cat((rot[0, 0:3, 0:3].view(3, 3), pose[3:6].view(3, 1)), dim=-1)

    homog = torch.cat((pts3d, torch.ones(num_pts, 1, device=pts3d.device)), dim=-1)
    in_cam = homog.mm(extrinsic.transpose(0, 1))
    all_in_front = in_cam[:, 2].min().item() >= 0

    projected = in_cam.mm(K.transpose(0, 1))
    depth = projected[:, 2].view(num_pts, 1)
    residual = pts2d - projected[:, 0:2].div(depth)

    return torch.norm(residual, dim=1).sum().item(), z_translation_positive and all_in_front
def test_triplet_amq(self, axis, device, dtype, atol, rtol):
    """Round trip: angle-axis -> rotation matrix -> WXYZ quaternion -> angle-axis."""
    components = [0.0, 0.0, 0.0]
    components[axis] = kornia.pi / 2.0
    angle_axis = torch.tensor([components], device=device, dtype=dtype)
    assert angle_axis.shape[-1] == 3

    rot_m = kornia.angle_axis_to_rotation_matrix(angle_axis)
    assert rot_m.shape[-1] == 3
    assert rot_m.shape[-2] == 3

    wxyz = QuaternionCoeffOrder.WXYZ
    quaternion = kornia.rotation_matrix_to_quaternion(rot_m, order=wxyz)
    assert quaternion.shape[-1] == 4

    recovered = kornia.quaternion_to_angle_axis(quaternion, order=wxyz)
    assert_close(recovered, angle_axis, atol=atol, rtol=rtol)
def test_angle_axis_to_rotation_matrix(batch_size, device, dtype):
    """Random angle-axis inputs give orthonormal matrices, and the op passes gradcheck."""
    angle_axis = torch.rand(batch_size, 3, device=device, dtype=dtype)
    identity = create_eye_batch(batch_size, 3, device=device, dtype=dtype)

    # R @ R^T must be the identity for every batch element.
    rmat = kornia.angle_axis_to_rotation_matrix(angle_axis)
    assert_allclose(rmat @ rmat.transpose(1, 2), identity, atol=1e-4, rtol=1e-4)

    # Numerical vs analytical gradient check.
    grad_input = tensor_to_gradcheck_var(angle_axis)
    assert gradcheck(
        kornia.angle_axis_to_rotation_matrix,
        (grad_input,),
        raise_exception=True,
    )
def batch_project(P, pts3d, K, angle_axis=True):
    """Project shared 3D points through a batch of poses and one intrinsics matrix.

    Args:
        P: batch of poses. With ``angle_axis=True`` each row holds an axis-angle
            rotation (cols 0..2) and a translation (cols 3..5); otherwise ``P`` is
            used directly as a batch of 3x4 projection matrices.
        pts3d: (n, 3) points shared by all poses.
        K: 3x3 intrinsics matrix.
        angle_axis: whether ``P`` must be converted from the 6-value form.

    Returns:
        (bs, n, 2) tensor of projected pixel coordinates.
    """
    num_pts = pts3d.size(0)
    batch = P.size(0)
    homog = torch.cat((pts3d, torch.ones(num_pts, 1, device=P.device)), dim=-1)

    if not angle_axis:
        proj_mats = P
    else:
        rots = kn.angle_axis_to_rotation_matrix(P[:, 0:3].view(batch, 3))
        proj_mats = torch.cat((rots[:, 0:3, 0:3], P[:, 3:6].view(batch, 3, 1)), dim=-1)

    in_cam = homog.matmul(proj_mats.transpose(1, 2))
    on_image = in_cam.matmul(K.t())
    depth = on_image[:, :, 2].view(batch, num_pts, 1)
    return on_image[:, :, 0:2].div(depth)
def test_angle_axis_to_rotation_matrix(batch_size, device_type):
    """Random angle-axis inputs give orthonormal (4x4) matrices, and gradcheck passes."""
    dev = torch.device(device_type)
    angle_axis = torch.rand(batch_size, 3).to(dev)
    identity = utils.create_eye_batch(batch_size, 4).to(dev)

    # The product with its own transpose must give the identity batch.
    rmat = kornia.angle_axis_to_rotation_matrix(angle_axis)
    assert check_equal_torch(torch.matmul(rmat, rmat.transpose(1, 2)), identity)

    grad_input = utils.tensor_to_gradcheck_var(angle_axis)
    assert gradcheck(kornia.angle_axis_to_rotation_matrix, (grad_input,),
                     raise_exception=True)
def test_triplet_qam_xyzw(self, axis, device, dtype, atol, rtol):
    """Round trip: XYZW quaternion -> angle-axis -> rotation matrix -> quaternion.

    The XYZW order is expected to emit a UserWarning on each conversion.
    """
    coeffs = [0.0] * 4
    coeffs[axis] = 1.0
    quaternion = torch.tensor([coeffs], device=device, dtype=dtype)
    assert quaternion.shape[-1] == 4

    xyzw = QuaternionCoeffOrder.XYZW
    with pytest.warns(UserWarning):
        angle_axis = kornia.quaternion_to_angle_axis(quaternion, order=xyzw)
    assert angle_axis.shape[-1] == 3

    rot_m = kornia.angle_axis_to_rotation_matrix(angle_axis)
    assert rot_m.shape[-1] == 3
    assert rot_m.shape[-2] == 3

    with pytest.warns(UserWarning):
        recovered = kornia.rotation_matrix_to_quaternion(rot_m, order=xyzw)
    assert_close(recovered, quaternion, atol=atol, rtol=rtol)
def get_projective_transform(center: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    r"""Calculates the projection matrix for a 3D rotation.

    The function computes the projection matrix given the center and angles per axis.

    Args:
        center (torch.Tensor): center of the rotation in the source with shape :math:`(B, 3)`.
        angles (torch.Tensor): angle axis vector containing the rotation angles in degrees in the form
            of (rx, ry, rz) with shape :math:`(B, 3)`. Internally it calls Rodrigues to compute
            the rotation matrix from axis-angle.

    Returns:
        torch.Tensor: the projection matrix of 3D rotation with shape :math:`(B, 3, 4)`.

    Raises:
        AssertionError: if ``center`` or ``angles`` is not of shape :math:`(B, 3)`, or if
            they differ in device or dtype.
    """
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``;
    # the exception type and payload are the same as the former asserts.
    if not (len(center.shape) == 2 and center.shape[-1] == 3):
        raise AssertionError(center.shape)
    if not (len(angles.shape) == 2 and angles.shape[-1] == 3):
        raise AssertionError(angles.shape)
    if center.device != angles.device:
        raise AssertionError((center.device, angles.device))
    if center.dtype != angles.dtype:
        raise AssertionError((center.dtype, angles.dtype))

    # create rotation matrix
    angle_axis_rad: torch.Tensor = K.deg2rad(angles)
    rmat: torch.Tensor = K.angle_axis_to_rotation_matrix(angle_axis_rad)  # Bx3x3

    # define matrix to move forth and back to origin
    from_origin_mat = torch.eye(4, device=center.device, dtype=center.dtype)[None].repeat(
        rmat.shape[0], 1, 1)  # Bx4x4
    from_origin_mat[..., :3, -1] += center
    to_origin_mat = from_origin_mat.inverse()

    # append translation with zeros
    proj_mat = projection_from_Rt(rmat, torch.zeros_like(center)[..., None])  # Bx3x4

    # chain 4x4 transforms: translate to origin, rotate, translate back
    proj_mat = matrix_to_homogeneous(proj_mat)  # Bx4x4
    proj_mat = from_origin_mat @ proj_mat @ to_origin_mat

    return proj_mat[..., :3, :]  # Bx3x4
def optimize_nodes(
        self,
        match_count: int,
        optimized_node_count: int,
        batch_edge_count: int,
        graph_nodes_i: torch.Tensor,
        source_anchors: torch.Tensor,
        source_weights: torch.Tensor,
        source_points_filtered: torch.Tensor,
        source_colors_filtered: torch.Tensor,
        correspondence_weights_filtered: torch.Tensor,
        xy_pixels_warped_filtered: torch.Tensor,
        target_matches_filtered: torch.Tensor,
        graph_edge_pairs_filtered: torch.Tensor,
        graph_edge_weights_pairs: torch.Tensor,
        num_neighbors: int,
        fx: torch.Tensor,
        fy: torch.Tensor,
        cx: torch.Tensor,
        cy: torch.Tensor,
        batch_convergence_info
) -> Tuple[bool, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run Gauss-Newton iterations to estimate per-node rotations and translations.

    Each iteration stacks the data residual/Jacobian with the ARAP residual/Jacobian
    (when ``batch_edge_count > 0``), solves the linear system with
    ``self.solve_linear_system``, and applies the resulting increment: an axis-angle
    rotation left-multiplied onto the current rotations, plus an additive translation.

    Returns:
        ``(ill_posed_system, residuals, rotations_current, translations_current)``.
        ``ill_posed_system`` is True if the linear solve reported failure (iteration
        stops early). ``residuals`` is the last stacked residual vector, or None when
        no iteration ran. Rotations are (node_count, 3, 3); translations are
        (node_count, 3, 1).

    Side effects:
        Appends per-iteration losses to ``batch_convergence_info["data"]``,
        ``["total"]`` and, when edges exist, ``["arap"]``.
    """
    # NOTE(review): if vec_to_skew_mat is a tensor, .to() returns a new tensor and this
    # result is discarded (a no-op); if it is an nn.Module, the move happens in place.
    self.vec_to_skew_mat.to(source_anchors.device)
    float_dtype = source_weights.dtype
    device = source_weights.device

    gauss_newton_iteration_count = self.gn_num_iter

    # The parameters in GN solver are 3 parameters for rotation and 3 parameters for
    # translation for every node. All node rotation parameters are listed first, and
    # then all node translation parameters are listed.
    # transform_delta = [rotations_current, translations_current]
    rotations_current = torch.eye(3, dtype=float_dtype, device=device).view(1, 3, 3).repeat(
        optimized_node_count, 1, 1)
    translations_current = torch.zeros((optimized_node_count, 3, 1), dtype=float_dtype, device=device)

    if self.gn_debug:
        print(
            f"\tMatch count: {match_count} || Node count: {optimized_node_count} || Edges count: {batch_edge_count}"
        )

    # Initialize helper structures: index vectors k, k+3, k+6, ... for k in {0, 1, 2}.
    # They are created directly on `device` as int64. The former
    # out=torch.cuda.LongTensor() forced CUDA, so it crashed on CPU and clashed with
    # non-default GPUs.
    data_increment_vec_0_3, data_increment_vec_1_3, data_increment_vec_2_3 = (
        torch.arange(offset, match_count * 3, 3, dtype=torch.long, device=device)  # (match_count)
        for offset in range(3)
    )

    arap_increment_vec_0_3 = None
    arap_increment_vec_1_3 = None
    arap_increment_vec_2_3 = None
    arap_one_vec = None
    if batch_edge_count > 0:
        arap_increment_vec_0_3, arap_increment_vec_1_3, arap_increment_vec_2_3 = (
            torch.arange(offset, batch_edge_count * 3, 3, dtype=torch.long, device=device)  # (batch_edge_count)
            for offset in range(3)
        )
        arap_one_vec = torch.ones(batch_edge_count, dtype=float_dtype, device=device)

    ill_posed_system = False
    residuals = None

    for i_iteration in range(gauss_newton_iteration_count):
        residual_data, jacobian_data = \
            self.compute_data_residual_and_jacobian(data_increment_vec_0_3, data_increment_vec_1_3,
                                                    data_increment_vec_2_3,
                                                    match_count, optimized_node_count, graph_nodes_i,
                                                    source_anchors, source_weights,
                                                    source_points_filtered, source_colors_filtered,
                                                    correspondence_weights_filtered,
                                                    xy_pixels_warped_filtered,
                                                    target_matches_filtered, i_iteration, fx, fy, cx, cy,
                                                    rotations_current, translations_current)
        loss_arap = None
        if batch_edge_count > 0:
            residual_arap, jacobian_arap = \
                self.compute_arap_residual_and_jacobian(arap_increment_vec_0_3, arap_increment_vec_1_3,
                                                        arap_increment_vec_2_3, arap_one_vec,
                                                        graph_edge_pairs_filtered, graph_edge_weights_pairs,
                                                        graph_nodes_i, batch_edge_count, optimized_node_count,
                                                        num_neighbors, rotations_current,
                                                        translations_current)
            loss_arap = torch.norm(residual_arap).item()

            residuals = torch.cat((residual_data, residual_arap), 0)
            jacobian = torch.cat((jacobian_data, jacobian_arap), 0)
        else:
            residuals = residual_data
            jacobian = jacobian_data

        success, transform_delta = self.solve_linear_system(
            residuals, jacobian, batch_convergence_info)
        ill_posed_system = not success
        if ill_posed_system:
            break

        # Increment the current rotation and translation.
        rotation_increments = kornia.angle_axis_to_rotation_matrix(
            transform_delta[:optimized_node_count * 3].view(optimized_node_count, 3))
        translation_increments = transform_delta[optimized_node_count * 3:].view(
            optimized_node_count, 3, 1)

        rotations_current = torch.matmul(rotation_increments, rotations_current)
        translations_current = translations_current + translation_increments

        loss_data = torch.norm(residual_data).item()
        loss_total = torch.norm(residuals).item()

        batch_convergence_info["data"].append(loss_data)
        batch_convergence_info["total"].append(loss_total)
        if batch_edge_count > 0:
            batch_convergence_info["arap"].append(loss_arap)

        if self.gn_debug:
            if batch_edge_count > 0:
                print(
                    f"\t\t-->Iteration: {i_iteration}. "
                    f"Loss: \tdata = {loss_data:.3f}, \tarap = {loss_arap:.3f}, \ttotal = {loss_total:.3f}"
                )
            else:
                print(
                    f"\t\t-->Iteration: {i_iteration}. Loss: \tdata = {loss_data:.3f}, \ttotal = {loss_total:.3f}"
                )

    return ill_posed_system, residuals, rotations_current, translations_current
def angle_axis_to_rot_matrix(euler):
    """Thin wrapper around ``kornia.angle_axis_to_rotation_matrix``.

    Despite its name, ``euler`` is forwarded unchanged as an *axis-angle* vector,
    not interpreted as Euler angles.
    """
    rotation = kornia.angle_axis_to_rotation_matrix(angle_axis=euler)
    return rotation
def test_smoke(self, device):
    """An unbatched zero axis-angle vector yields a single 3x3 rotation matrix.

    The input is now created on the parametrized ``device``. Previously the fixture
    was ignored and the test always ran on CPU.
    """
    angle_axis = torch.zeros(3, device=device)
    rotation_matrix = kornia.angle_axis_to_rotation_matrix(angle_axis)
    assert rotation_matrix.shape == (3, 3)
def backward(ctx, grad_output):
    """Backward pass computed by implicit differentiation of a 6-DoF pose solution.

    For each batch element, builds the Jacobians of six scalar conditions ``f_j``
    (coefficient-weighted reprojection residuals, j = 0..5) with respect to the pose
    ``y`` (``P_6d``) and the inputs (2D points, 3D points, ``K``). It then forms
    ``dy/dinput = -(df/dy)^-1 @ df/dinput`` and chains it with ``grad_output``.

    Returns:
        ``(grad_x, grad_z, grad_K, None)``: gradients for the 2D points (per batch
        element), the 3D points and ``K`` (both summed over the batch, since they are
        shared), and None for a fourth forward input -- presumably one that needs no
        gradient; verify against ``forward``.
    """
    pts2d, P_6d, pts3d, K = ctx.saved_tensors
    device = pts2d.device
    bs = pts2d.size(0)
    n = pts2d.size(1)
    m = 6  # pose dimension: 3 axis-angle + 3 translation

    grad_x = torch.zeros_like(pts2d)
    grad_z = torch.zeros_like(pts3d)
    grad_K = torch.zeros_like(K)

    for i in range(bs):
        # NOTE(review): these use the default dtype, not the inputs' dtype -- confirm
        # this is intended for float64 inputs.
        J_fy = torch.zeros(m,m, device=device)
        J_fx = torch.zeros(m,2*n, device=device)
        J_fz = torch.zeros(m,3*n, device=device)
        J_fK = torch.zeros(m, 9, device=device)

        coefs = get_coefs(P_6d[i].view(1,6), pts3d, K)

        # Fresh leaf copies so that each f_j.backward() fills .grad with one Jacobian row.
        pts2d_flat = pts2d[i].clone().view(-1).detach().requires_grad_()
        P_6d_flat = P_6d[i].clone().view(-1).detach().requires_grad_()
        pts3d_flat = pts3d.clone().view(-1).detach().requires_grad_()
        K_flat = K.clone().view(-1).detach().requires_grad_()

        for j in range(m):
            # Backward normally runs with grad disabled; re-enable it to build a graph.
            # NOTE(review): called as a function, not a context manager, so grad mode
            # stays enabled afterwards; `with torch.enable_grad():` would scope it.
            torch.set_grad_enabled(True)
            if j > 0:
                # .grad is None before the first backward(), so only clear from j=1 on.
                pts2d_flat.grad.zero_()
                P_6d_flat.grad.zero_()
                pts3d_flat.grad.zero_()
                K_flat.grad.zero_()

            # Rebuild the 3x4 pose matrix [R | t] and project the 3D points with K.
            R = kn.angle_axis_to_rotation_matrix(P_6d_flat[0:m-3].view(1,3))
            P = torch.cat((R[0,0:3,0:3].view(3,3), P_6d_flat[m-3:m].view(3,1)),dim=-1)
            KP = torch.mm(K_flat.view(3,3), P)
            pts2d_i = pts2d_flat.view(n,2).transpose(0,1)
            pts3d_i = torch.cat((pts3d_flat.view(n,3),torch.ones(n,1,device=device)),dim=-1).t()
            proj_i = KP.mm(pts3d_i)
            Si = proj_i[2,:].view(1,n)

            # Depth-scaled residual x*s - (KP X)[:2], weighted by the j-th coefficients.
            r = pts2d_i*Si-proj_i[0:2,:]
            coef = coefs[:,:,j].transpose(0,1) # size: [2,n]
            fj = (coef*r).sum()
            fj.backward()
            J_fy[j,:] = P_6d_flat.grad.clone()
            J_fx[j,:] = pts2d_flat.grad.clone()
            J_fz[j,:] = pts3d_flat.grad.clone()
            J_fK[j,:] = K_flat.grad.clone()

        # Implicit function theorem: dy/dx = -(df/dy)^-1 df/dx.
        inv_J_fy = torch.inverse(J_fy)
        J_yx = (-1) * torch.mm(inv_J_fy, J_fx)
        J_yz = (-1) * torch.mm(inv_J_fy, J_fz)
        J_yK = (-1) * torch.mm(inv_J_fy, J_fK)

        # Chain rule with the incoming pose gradient. pts3d and K are shared across the
        # batch, so their gradients accumulate.
        grad_x[i] = grad_output[i].view(1,m).mm(J_yx).view(n,2)
        grad_z += grad_output[i].view(1,m).mm(J_yz).view(n,3)
        grad_K += grad_output[i].view(1,m).mm(J_yK).view(3,3)

    return grad_x, grad_z, grad_K, None
def P6d2PM(P6d):
    """Convert (bs, 6) poses [axis-angle | translation] into (bs, 3, 4) [R | t] matrices."""
    batch = P6d.size(0)
    rotation = kn.angle_axis_to_rotation_matrix(P6d[:, 0:3].view(batch, 3))
    rotation = rotation[:, 0:3, 0:3].view(batch, 3, 3)
    translation = P6d[:, 3:6].view(batch, 3, 1)
    return torch.cat((rotation, translation), dim=-1)