def se3_vee(x):
    """
    Map an se(3) Lie algebra element from its (4, 4) matrix rep to a vector.

    The top-left (3, 3) block is converted with ``so3_vee`` and the first three
    entries of the last column are appended, i.e. the result is
    ``cat([so3_vee(x[..., :3, :3]), x[..., :3, 3]], -1)``. The bottom row of
    ``x`` is ignored. Presumably the inverse of this module's se(3) hat / fill
    map — TODO confirm against that function.

    :param x: Lie algebra in matrix rep of shape (..., 4, 4)
    :return: Lie algebra in vector rep of shape (..., 6) (assuming ``so3_vee``
        returns 3 components)
    """
    assert x.shape[-2:] == (4, 4), f"expected shape (..., 4, 4), got {tuple(x.shape)}"
    # Keep the top three rows; the bottom (filler) row carries no information.
    top_rows, _ = x.split([3, 1], -2)
    # Left (3, 3) block is the rotational part, last column the translational part.
    so3_alg, r3_alg = top_rows.split([3, 1], -1)
    v_so3 = so3_vee(so3_alg)
    return torch.cat([v_so3, r3_alg.squeeze(-1)], -1)
def se3_log(r):
    """
    Logarithm map of SE(3).

    For a homogeneous transform ``[[R, t], [0, 1]]`` this returns
    ``[[W, V^{-1} t], [0, 0]]`` (built via ``se3_fill(..., "alg")``) where
    ``W = so3_log(R)`` and

        V^{-1} = I - W / 2 + d(theta) * W @ W,
        d(theta) = (1 - (theta / 2) * cot(theta / 2)) / theta**2,

    with ``theta`` the norm of ``so3_vee(W)``. This assumes ``so3_log``
    returns a true matrix logarithm (``matrix_exp(W) == R``) — the ``-W / 2``
    sign follows from V = I + W/2 + W^2/6 + ... in that case.

    :param r: group element of shape (..., 4, 4)
    :return: Algebra element in matrix basis of shape (..., 4, 4)
    """
    # Top three rows hold [R | t]; the bottom row is the constant [0, 0, 0, 1].
    se3, _ = r.split([3, 1], -2)
    so3, r3 = se3.split([3, 1], -1)
    so3_alg = so3_log(so3)
    theta = so3_vee(so3_alg).norm(p=2, dim=-1, keepdim=True)  # (..., 1)

    # Near theta == 0 the closed form of d(theta) is 0/0 and suffers heavy
    # cancellation, so switch to its Taylor series below this angle.
    small = theta < 1e-3
    # Feed a harmless value into the closed form where the Taylor branch is
    # selected: torch.where evaluates both branches, and a NaN/inf in the
    # unused one would still poison gradients.
    theta_safe = torch.where(small, torch.ones_like(theta), theta)
    half = theta_safe / 2
    # Half-angle form: theta*sin(theta) / (2*(1 - cos(theta))) == (theta/2)*cot(theta/2);
    # this avoids the cancellation in (1 - cos(theta)) and stays finite at theta == pi.
    d_closed = (1 - half * torch.cos(half) / torch.sin(half)) / theta_safe**2
    # x*cot(x) = 1 - x^2/3 - x^4/45 - ...  with x = theta/2
    #   => d(theta) = 1/12 + theta^2/720 + O(theta^4)
    d_taylor = 1.0 / 12 + theta**2 / 720
    d = torch.where(small, d_taylor, d_closed)

    eye = torch.eye(3, device=r.device, dtype=r.dtype)
    # d[..., None] has shape (..., 1, 1) so it broadcasts over the 3x3 matrices.
    Vinv = eye - so3_alg / 2 + d[..., None] * (so3_alg @ so3_alg)
    r3_alg = Vinv @ r3
    return se3_fill(so3_alg, r3_alg.squeeze(-1), "alg")
def _inverse_set(self, y):
    """Take ``so3_log`` of ``y``, convert it with ``so3_vee``, and pass the result to ``self._xset``."""
    algebra_vec = so3_vee(so3_log(y))
    return self._xset(algebra_vec)
def _inverse(self, y):
    """Return ``so3_vee(so3_log(y))``: the log of ``y`` in vector representation."""
    log_matrix = so3_log(y)
    log_vector = so3_vee(log_matrix)
    return log_vector