def test_from_matrix():
    """Check SE2.from_matrix on a single 3x3 matrix.

    Covers an exact identity matrix and a slightly perturbed one that is
    passed through ``normalize=True``; both results must be SE2 objects with
    an SO2 rotation, a length-2 translation, and a matrix form accepted by
    ``SE2.is_valid_matrix``.
    """
    # Exact identity: already a valid transformation matrix.
    identity = SE2.from_matrix(torch.eye(3))
    assert isinstance(identity, SE2)
    assert isinstance(identity.rot, SO2)
    assert identity.trans.shape == (2,)
    assert SE2.is_valid_matrix(identity.as_matrix()).all()

    # Every entry shifted by 1e-3: only valid after normalization.
    perturbed = torch.eye(3).add_(1e-3)
    repaired = SE2.from_matrix(perturbed, normalize=True)
    assert isinstance(repaired, SE2)
    assert isinstance(repaired.rot, SO2)
    assert repaired.trans.shape == (2,)
    assert SE2.is_valid_matrix(repaired.as_matrix()).all()
def test_from_matrix_batch():
    """Check SE2.from_matrix on a batch of five 3x3 matrices.

    First with five identity matrices, then with the same batch after one
    element (index 3) has been corrupted and ``normalize=True`` is requested.
    """
    batch = SE2.from_matrix(torch.eye(3).repeat(5, 1, 1))
    assert isinstance(batch, SE2)
    assert batch.trans.shape == (5, 2)
    assert SE2.is_valid_matrix(batch.as_matrix()).all()

    # Corrupt a single batch element in place, then ask for normalization.
    corrupted = batch.as_matrix()
    corrupted[3, :, :].add_(0.1)
    repaired = SE2.from_matrix(corrupted, normalize=True)
    assert isinstance(repaired, SE2)
    assert repaired.trans.shape == (5, 2)
    assert SE2.is_valid_matrix(repaired.as_matrix()).all()
def mean_position(self, mean, logvar):
    """Fuse the per-head pose estimates and return the fused translation.

    Each row of ``mean`` / ``logvar`` is split evenly into
    ``self._head_count`` chunks. Every chunk is turned into an SE2 matrix,
    mapped to its 3-vector logarithm, and the chunks belonging to one row are
    combined as an information-weighted average: the per-head matrices
    ``S^T S`` (``S`` from ``self.get_inverse_sigma_matrix``) are summed,
    inverted, and used to weight the per-head log vectors.

    Args:
        mean: 2-D tensor; row count is the batch size.
            # presumably translation followed by a rotation vector per
            # head -- TODO confirm against make_matrix
        logvar: 2-D tensor with the same row count as ``mean``.

    Returns:
        Tensor of shape ``(batch_size, 2)`` on ``mean.device`` holding the
        entries ``[0, 2]`` and ``[1, 2]`` of the fused SE2 matrix.
    """
    batch_size = mean.shape[0]
    head_count = self._head_count
    flat_count = batch_size * head_count

    # One row per (sample, head) pair, expressed in log coordinates.
    mean = mean.view(flat_count, mean.shape[1] // head_count)
    head_matrix = self.make_matrix(mean[:, :2], mean[:, 2:])
    head_log = SE2.log(SE2.from_matrix(head_matrix, normalize=False))
    if head_log.dim() < 2:
        # SE2.log returned an unbatched vector; restore the batch axis.
        head_log = head_log[None]

    # Per-head S^T S and their per-sample sum over heads.
    logvar = logvar.view(flat_count, logvar.shape[1] // head_count)
    inv_sigma = self.get_inverse_sigma_matrix(logvar)
    head_information = torch.bmm(inv_sigma.transpose(1, 2), inv_sigma)
    fused_information = head_information.reshape(
        -1, head_count, 3, 3).sum(dim=1)
    fused_covariance = torch.inverse(fused_information)

    # Weight each head's log vector by fused_covariance @ head_information,
    # then add the heads of every sample together.
    weights = torch.bmm(
        fused_covariance.repeat_interleave(head_count, 0), head_information)
    weighted_log = torch.bmm(weights, head_log[:, :, None])[:, :, 0]
    fused_log = weighted_log.reshape(-1, head_count, 3).sum(dim=1)

    fused_matrix = SE2.exp(fused_log).as_matrix()
    if fused_matrix.dim() < 3:
        fused_matrix = fused_matrix[None]

    positions = torch.zeros(batch_size, 2, device=mean.device)
    positions[:, 0] = fused_matrix[:, 0, 2]
    positions[:, 1] = fused_matrix[:, 1, 2]
    return positions
def sample(self, mean, logvar):
    """Draw one pose sample per batch row from the fused head distribution.

    The per-head estimates are fused exactly as in the information-weighted
    average used elsewhere in this class: per-head matrices ``S^T S`` are
    summed over heads, inverted, and used to weight the per-head SE2 log
    vectors. A perturbation ``L @ eps`` (``L`` the Cholesky factor of the
    fused covariance plus ``1e-4 * I``, ``eps`` standard normal) is mapped
    through ``SE2.exp`` and right-multiplied onto the fused mean matrix.

    Args:
        mean: 2-D tensor; row count is the batch size. Each row is split
            evenly into ``self._head_count`` chunks.
        logvar: 2-D tensor with the same row count, split the same way.

    Returns:
        NumPy array of shape ``(batch_size, 3)``: columns 0 and 1 are the
        ``[0, 2]`` and ``[1, 2]`` entries of the sampled matrix, column 2 is
        ``atan2(m[1, 0], m[0, 0])``.
    """
    batch_size = mean.shape[0]
    mean = mean.view(batch_size * self._head_count,
                     mean.shape[1] // self._head_count)
    mean_matrix = self.make_matrix(mean[:, :2], mean[:, 2:])
    log_mean = SE2.log(SE2.from_matrix(mean_matrix, normalize=False))
    if log_mean.dim() < 2:
        # SE2.log returned an unbatched vector; restore the batch axis.
        log_mean = log_mean[None]

    logvar = logvar.view(batch_size * self._head_count,
                         logvar.shape[1] // self._head_count)
    inverse_sigma_matrix = self.get_inverse_sigma_matrix(logvar)
    inverse_covariance_matrix = torch.bmm(
        inverse_sigma_matrix.transpose(1, 2), inverse_sigma_matrix)
    result_inverse_covariance_matrix = torch.sum(
        inverse_covariance_matrix.reshape(-1, self._head_count, 3, 3), dim=1)
    result_covariance_matrix = torch.inverse(
        result_inverse_covariance_matrix)

    # Weight every head's log vector, then add the heads of each sample.
    factors = torch.bmm(
        result_covariance_matrix.repeat_interleave(self._head_count, 0),
        inverse_covariance_matrix)
    scaled_log_mean = torch.bmm(factors, log_mean[:, :, None])[:, :, 0]
    result_log_mean = torch.sum(
        scaled_log_mean.reshape(-1, self._head_count, 3), dim=1)
    mean_matrix = SE2.exp(result_log_mean).as_matrix()
    if mean_matrix.dim() < 3:
        mean_matrix = mean_matrix[None]

    try:
        # The small diagonal term keeps the factorization from failing on
        # covariances that are only numerically non-positive-definite.
        sigma_matrix = torch.cholesky(
            result_covariance_matrix
            + torch.eye(3, device=mean.device) * 1e-4)
    except RuntimeError as msg:
        # Best-effort fallback: dump the offending matrices and continue
        # with a fixed isotropic factor instead of aborting.
        # NOTE(review): the fallback scale is 1e4 while the regularizer is
        # 1e-4 -- confirm the large value is intended.
        print(inverse_covariance_matrix)
        print(result_inverse_covariance_matrix)
        print(result_covariance_matrix)
        print("Cholesky error", msg)
        sigma_matrix = (torch.eye(3, device=mean.device) * 1e4).expand(
            batch_size, 3, 3)

    epsilon = torch.randn(batch_size, 3, device=mean.device)
    delta = torch.bmm(sigma_matrix, epsilon[:, :, None])[:, :, 0]
    delta_matrix = SE2.exp(delta).as_matrix()
    if delta_matrix.dim() < 3:
        delta_matrix = delta_matrix[None]
    position_matrix = torch.bmm(mean_matrix, delta_matrix)

    # Allocate on the same device as the data so the three column writes are
    # plain on-device copies; the result is moved to the host once below.
    positions = torch.zeros(batch_size, 3, device=mean.device)
    positions[:, 0] = position_matrix[:, 0, 2]
    positions[:, 1] = position_matrix[:, 1, 2]
    positions[:, 2] = torch.atan2(position_matrix[:, 1, 0],
                                  position_matrix[:, 0, 0])
    return positions.cpu().detach().numpy()
def test_from_matrix():
    """Check that SE2.from_matrix(I) yields identity rotation and zero shift.

    The value comparisons are asserted; a bare ``torch.allclose(...)`` call
    only returns a bool and would let a wrong result pass silently.
    """
    T_good = SE2.from_matrix(torch.eye(3))

    # Structural checks first so a type/shape problem is reported as such.
    assert isinstance(T_good, SE2)
    assert isinstance(T_good.rot, SO2)
    assert T_good.trans.shape == (2,)

    expected_rot = tensor([
        [1, 0],
        [0, 1],
    ], dtype=torch.float32)
    expected_trans = tensor([0, 0], dtype=torch.float32)
    assert torch.allclose(expected_rot, T_good.rot.mat)
    assert torch.allclose(expected_trans, T_good.trans)
def test_dot():
    """Check SE2 decomposition and SE2.dot against hand-computed values.

    The transform is a 90-degree rotation with translation (-0.5, 0.5).
    Every comparison is asserted; a bare ``torch.allclose(...)`` call only
    returns a bool and would leave the test without any effective check.
    """
    T = torch.Tensor([[0, -1, -0.5],
                      [1, 0, 0.5],
                      [0, 0, 1]])
    T_SE2 = SE2.from_matrix(T)

    # Rotation angle and translation recovered from the matrix.
    assert torch.allclose(tensor([np.pi / 2], dtype=torch.float32),
                          T_SE2.rot.to_angle())
    assert torch.allclose(tensor([-0.5, 0.5], dtype=torch.float32),
                          T_SE2.trans)

    # Plain 2-vector: R @ (1, 2) + t = (-2, 1) + (-0.5, 0.5).
    pt = torch.Tensor([1, 2])
    Tpt_SE2 = T_SE2.dot(pt)
    assert torch.allclose(tensor([-2 - 0.5, 1 + 0.5], dtype=torch.float32),
                          Tpt_SE2)

    # Homogeneous 3-vector: same result with the trailing 1 preserved.
    pth = torch.Tensor([1, 2, 1])
    Tpth_SE2 = T_SE2.dot(pth)
    assert torch.allclose(
        tensor([-2 - 0.5, 1 + 0.5, 1], dtype=torch.float32), Tpth_SE2)
def test_dot():
    """Compare SE2.dot with the equivalent dense-matrix arithmetic.

    Three cases are exercised on a single transform: composition with
    itself, application to a 2-vector, and application to a homogeneous
    3-vector.
    """
    mat = torch.Tensor([[0, -1, -0.5],
                        [1, 0, 0.5],
                        [0, 0, 1]])
    pose = SE2.from_matrix(mat)
    point = torch.Tensor([1, 2])
    point_h = torch.Tensor([1, 2, 1])

    # Transform composed with itself.
    composed_expected = torch.mm(mat, mat)
    composed = pose.dot(pose).as_matrix()
    assert utils.allclose(composed, composed_expected)

    # 2-vector: rotate, then translate.
    point_expected = torch.matmul(mat[0:2, 0:2], point) + mat[0:2, 2]
    moved = pose.dot(point)
    assert utils.allclose(moved, point_expected)

    # Homogeneous vector: full matrix product; its first two entries must
    # agree with the 2-vector result.
    point_h_expected = torch.matmul(mat, point_h)
    moved_h = pose.dot(point_h)
    assert utils.allclose(moved_h, point_h_expected)
    assert utils.allclose(moved_h[0:2], point_expected)
def log_prob(self, value, mean, logvar):
    """Evaluate the Gaussian-style score of ``value`` around ``mean`` on SE2.

    The relative transform ``mean^-1 * value`` is mapped to its SE2 log
    vector, multiplied by the matrix from ``self.get_inverse_sigma_matrix``,
    and combined as
    ``sum(d**2) / 2 + 0.5 * log_determinant + 3 * log(sqrt(2 * pi))``.

    NOTE(review): the quadratic term is added, not subtracted, so the value
    grows as ``value`` moves away from ``mean`` -- it has the sign of a
    negative log-likelihood despite the name. Confirm against callers.

    Args:
        value: indexable pair; ``value[0]`` and the L2-normalized
            ``value[1]`` are passed to ``self.make_matrix``.
            # presumably (translation, rotation vector) -- TODO confirm
        mean: 2-D tensor; columns 0:2 and the normalized columns 2:4 are
            passed to ``self.make_matrix``.
        logvar: 1-D or 2-D tensor; a 1-D input is broadcast over the rows of
            ``mean``.

    Returns:
        1-D tensor with one score per row of the log-vector batch.
    """
    if logvar.dim() < 2:
        # Share a single log-variance vector across every row of ``mean``.
        logvar = logvar[None].expand(mean.shape[0], logvar.shape[0])

    normalize = torch.nn.functional.normalize
    value_matrix = self.make_matrix(value[0], normalize(value[1]))
    mean_rotation = normalize(mean[:, 2:4])
    mean_matrix = self.make_matrix(
        mean[:, 0:2], mean_rotation).expand_as(value_matrix)

    # Residual transform between mean and value, in log coordinates.
    relative = torch.bmm(self.pose_matrix_inverse(mean_matrix), value_matrix)
    tangent = SE2.log(SE2.from_matrix(relative, normalize=False))
    if tangent.dim() < 2:
        tangent = tangent[None]

    rows, width = tangent.shape[0], tangent.shape[1]
    inv_sigma = self.get_inverse_sigma_matrix(logvar).expand(
        rows, width, width)
    whitened = torch.bmm(inv_sigma, tangent[:, :, None])[:, :, 0]

    log_determinant = self.get_logvar_determinant(logvar)
    quadratic = torch.sum(whitened ** 2 / 2., dim=1)
    normalizer = 3 * math.log(math.sqrt(2 * math.pi))
    return quadratic + 0.5 * log_determinant + normalizer
def test_dot_batch():
    """Compare batched SE2.dot with the equivalent dense-matrix arithmetic.

    ``T1`` is a batch of five identical transforms and ``T2`` a single
    transform. Both are composed with each other and applied to single
    points, point lists (3x2 / 3x3 homogeneous) and batched point lists
    (5x3x2 / 5x3x3). Each expected value is built from the matrix of the
    transform under test, so the checks stay meaningful if ``T1`` and ``T2``
    are ever given different values.
    """
    T1 = torch.Tensor([[0, -1, -0.5],
                       [1, 0, 0.5],
                       [0, 0, 1]]).expand(5, 3, 3)
    T2 = torch.Tensor([[0, -1, -0.5],
                       [1, 0, 0.5],
                       [0, 0, 1]])
    T1_SE2 = SE2.from_matrix(T1)
    T2_SE2 = SE2.from_matrix(T2)

    pt1 = torch.Tensor([1, 2])
    pt2 = torch.Tensor([4, 5])
    pt3 = torch.Tensor([7, 8])
    pts = torch.cat(
        [pt1.unsqueeze(dim=0), pt2.unsqueeze(dim=0), pt3.unsqueeze(dim=0)],
        dim=0)  # 3x2
    ptsbatch = pts.unsqueeze(dim=0).expand(5, 3, 2)

    pt1h = torch.Tensor([1, 2, 1])
    pt2h = torch.Tensor([4, 5, 1])
    pt3h = torch.Tensor([7, 8, 1])
    ptsh = torch.cat(
        [pt1h.unsqueeze(dim=0), pt2h.unsqueeze(dim=0), pt3h.unsqueeze(dim=0)],
        dim=0)  # 3x3
    ptshbatch = ptsh.unsqueeze(dim=0).expand(5, 3, 3)

    # --- transform composition -------------------------------------------
    T1T1 = torch.bmm(T1, T1)
    T1T1_SE2 = T1_SE2.dot(T1_SE2).as_matrix()
    assert T1T1_SE2.shape == T1.shape and utils.allclose(T1T1_SE2, T1T1)

    T1T2 = torch.matmul(T1, T2)
    T1T2_SE2 = T1_SE2.dot(T2_SE2).as_matrix()
    assert T1T2_SE2.shape == T1.shape and utils.allclose(T1T2_SE2, T1T2)

    # --- batched transform, single point ---------------------------------
    T1pt1 = torch.matmul(T1[:, 0:2, 0:2], pt1) + T1[:, 0:2, 2]
    T1pt1_SE2 = T1_SE2.dot(pt1)
    assert (T1pt1_SE2.shape == (T1.shape[0], pt1.shape[0])
            and utils.allclose(T1pt1_SE2, T1pt1))

    T1pt1h = torch.matmul(T1, pt1h)
    T1pt1h_SE2 = T1_SE2.dot(pt1h)
    assert (T1pt1h_SE2.shape == (T1.shape[0], pt1h.shape[0])
            and utils.allclose(T1pt1h_SE2, T1pt1h)
            and utils.allclose(T1pt1h_SE2[:, 0:2], T1pt1_SE2))

    T1pt2 = torch.matmul(T1[:, 0:2, 0:2], pt2) + T1[:, 0:2, 2]
    T1pt2_SE2 = T1_SE2.dot(pt2)
    assert (T1pt2_SE2.shape == (T1.shape[0], pt2.shape[0])
            and utils.allclose(T1pt2_SE2, T1pt2))

    T1pt2h = torch.matmul(T1, pt2h)
    T1pt2h_SE2 = T1_SE2.dot(pt2h)
    assert (T1pt2h_SE2.shape == (T1.shape[0], pt2h.shape[0])
            and utils.allclose(T1pt2h_SE2, T1pt2h)
            and utils.allclose(T1pt2h_SE2[:, 0:2], T1pt2_SE2))

    # --- batched transform, unbatched point list -------------------------
    T1pts = torch.bmm(
        T1[:, 0:2, 0:2],
        pts.unsqueeze(dim=0).expand(
            T1.shape[0], pts.shape[0], pts.shape[1]).transpose(2, 1)
    ).transpose(2, 1) + T1[:, 0:2, 2].unsqueeze(dim=1).expand(
        T1.shape[0], pts.shape[0], pts.shape[1])
    T1pts_SE2 = T1_SE2.dot(pts)
    assert (T1pts_SE2.shape == (T1.shape[0], pts.shape[0], pts.shape[1])
            and utils.allclose(T1pts_SE2, T1pts)
            and utils.allclose(T1pt1, T1pts[:, 0, :])
            and utils.allclose(T1pt2, T1pts[:, 1, :]))

    T1ptsh = torch.bmm(
        T1,
        ptsh.unsqueeze(dim=0).expand(
            T1.shape[0], ptsh.shape[0], ptsh.shape[1]).transpose(2, 1)
    ).transpose(2, 1)
    T1ptsh_SE2 = T1_SE2.dot(ptsh)
    assert (T1ptsh_SE2.shape == (T1.shape[0], ptsh.shape[0], ptsh.shape[1])
            and utils.allclose(T1ptsh_SE2, T1ptsh)
            and utils.allclose(T1pt1h, T1ptsh[:, 0, :])
            and utils.allclose(T1pt2h, T1ptsh[:, 1, :])
            and utils.allclose(T1ptsh_SE2[:, :, 0:2], T1pts_SE2))

    # --- batched transform, batched point list ---------------------------
    T1ptsbatch = torch.bmm(
        T1[:, 0:2, 0:2], ptsbatch.transpose(2, 1)
    ).transpose(2, 1) + T1[:, 0:2, 2].unsqueeze(dim=1).expand(ptsbatch.shape)
    T1ptsbatch_SE2 = T1_SE2.dot(ptsbatch)
    assert (T1ptsbatch_SE2.shape == ptsbatch.shape
            and utils.allclose(T1ptsbatch_SE2, T1ptsbatch)
            and utils.allclose(T1pt1, T1ptsbatch[:, 0, :])
            and utils.allclose(T1pt2, T1ptsbatch[:, 1, :]))

    T1ptshbatch = torch.bmm(T1, ptshbatch.transpose(2, 1)).transpose(2, 1)
    T1ptshbatch_SE2 = T1_SE2.dot(ptshbatch)
    assert (T1ptshbatch_SE2.shape == ptshbatch.shape
            and utils.allclose(T1ptshbatch_SE2, T1ptshbatch)
            and utils.allclose(T1pt1h, T1ptshbatch[:, 0, :])
            and utils.allclose(T1pt2h, T1ptshbatch[:, 1, :])
            and utils.allclose(T1ptshbatch_SE2[:, :, 0:2], T1ptsbatch_SE2))

    # --- single transform, batched point list ----------------------------
    # The translation must come from T2 itself (not T1): the expected value
    # has to be derived from the transform whose dot() is being checked.
    T2ptsbatch = torch.matmul(
        T2[0:2, 0:2], ptsbatch.transpose(2, 1)
    ).transpose(2, 1) + T2[0:2, 2].expand(ptsbatch.shape)
    T2ptsbatch_SE2 = T2_SE2.dot(ptsbatch)
    assert (T2ptsbatch_SE2.shape == ptsbatch.shape
            and utils.allclose(T2ptsbatch_SE2, T2ptsbatch)
            and utils.allclose(T2_SE2.dot(pt1), T2ptsbatch[:, 0, :])
            and utils.allclose(T2_SE2.dot(pt2), T2ptsbatch[:, 1, :]))

    T2ptshbatch = torch.matmul(T2, ptshbatch.transpose(2, 1)).transpose(2, 1)
    T2ptshbatch_SE2 = T2_SE2.dot(ptshbatch)
    assert (T2ptshbatch_SE2.shape == ptshbatch.shape
            and utils.allclose(T2ptshbatch_SE2, T2ptshbatch)
            and utils.allclose(T2_SE2.dot(pt1h), T2ptshbatch[:, 0, :])
            and utils.allclose(T2_SE2.dot(pt2h), T2ptshbatch[:, 1, :])
            and utils.allclose(T2ptshbatch_SE2[:, :, 0:2], T2ptsbatch_SE2))