def sample(self, mean, logvar):
    """Draw one SE(2) pose sample per batch row from the fused multi-head estimate.

    Each of the ``self._head_count`` heads contributes a pose mean and an
    information matrix ``S^T S`` (``S`` comes from ``get_inverse_sigma_matrix``).
    The information matrices are summed and inverted to give a fused
    covariance. The heads' ``SE2.log`` means are information-weighted and
    summed to give a fused tangent mean, which is mapped back with ``SE2.exp``.
    A tangent perturbation ``L @ eps`` is then drawn, with ``L`` the Cholesky
    factor of the fused covariance and ``eps ~ N(0, I)``, and applied on the
    right of the fused mean pose.

    Args:
        mean: per-head pose parameters, viewed as
            ``(batch * head_count, mean.shape[1] // head_count)``. Columns
            ``[:2]`` and ``[2:]`` are passed to ``make_matrix``; presumably
            translation and a rotation encoding (TODO confirm).
        logvar: per-head log-variance parameters, viewed the same way and
            passed to ``get_inverse_sigma_matrix``.

    Returns:
        numpy array of shape ``(batch, 3)`` holding the sampled pose's
        translation column (``x``, ``y``) and ``atan2(R[1, 0], R[0, 0])``.
    """
    import warnings

    heads = self._head_count
    batch_size = mean.shape[0]

    # Fold heads into the batch dimension: one row per (sample, head).
    mean = mean.view(batch_size * heads, mean.shape[1] // heads)
    mean_matrix = self.make_matrix(mean[:, :2], mean[:, 2:])
    log_mean = SE2.log(SE2.from_matrix(mean_matrix, normalize=False))
    if log_mean.dim() < 2:
        # Restore the batch dimension if SE2.log returned a single vector.
        log_mean = log_mean[None]

    logvar = logvar.view(batch_size * heads, logvar.shape[1] // heads)
    inverse_sigma_matrix = self.get_inverse_sigma_matrix(logvar)
    inverse_covariance_matrix = torch.bmm(
        inverse_sigma_matrix.transpose(1, 2), inverse_sigma_matrix)

    # Fuse heads: information matrices add; the fused covariance is the inverse.
    result_inverse_covariance_matrix = torch.sum(
        inverse_covariance_matrix.reshape(-1, heads, 3, 3), dim=1)
    result_covariance_matrix = torch.inverse(result_inverse_covariance_matrix)

    # Fused tangent mean = sum_i Sigma_fused @ Lambda_i @ log(mu_i).
    factors = torch.bmm(
        result_covariance_matrix.repeat_interleave(heads, 0),
        inverse_covariance_matrix)
    scaled_log_mean = torch.bmm(factors, log_mean[:, :, None])[:, :, 0]
    result_log_mean = torch.sum(scaled_log_mean.reshape(-1, heads, 3), dim=1)
    mean_matrix = SE2.exp(result_log_mean).as_matrix()
    if mean_matrix.dim() < 3:
        mean_matrix = mean_matrix[None]

    try:
        # A small diagonal jitter keeps the factorization numerically stable.
        sigma_matrix = torch.cholesky(
            result_covariance_matrix + torch.eye(3, device=mean.device) * 1e-4)
    except RuntimeError as msg:
        # Fall back to a very broad isotropic spread instead of crashing,
        # but report it without dumping whole batch tensors to stdout.
        warnings.warn(f"Cholesky error {msg}", RuntimeWarning)
        sigma_matrix = (torch.eye(3, device=mean.device) * 1e4).expand(
            batch_size, 3, 3)

    epsilon = torch.randn(batch_size, 3, device=mean.device)
    delta = torch.bmm(sigma_matrix, epsilon[:, :, None])[:, :, 0]
    delta_matrix = SE2.exp(delta).as_matrix()
    if delta_matrix.dim() < 3:
        delta_matrix = delta_matrix[None]
    # Right-perturb the fused mean pose by the sampled tangent vector.
    position_matrix = torch.bmm(mean_matrix, delta_matrix)

    # Allocate on the working device so the result crosses to the host once.
    positions = torch.zeros(batch_size, 3, device=mean.device)
    positions[:, 0] = position_matrix[:, 0, 2]
    positions[:, 1] = position_matrix[:, 1, 2]
    positions[:, 2] = torch.atan2(
        position_matrix[:, 1, 0], position_matrix[:, 0, 0])
    return positions.cpu().detach().numpy()
def test_perturb_batch():
    """perturb() matches left-multiplication by exp(xi), shared or per pose."""
    poses = SE2.exp(0.1 * torch.Tensor([[1, 2, 3], [4, 5, 6]]))
    shared_target = copy.deepcopy(poses)
    per_pose_target = copy.deepcopy(poses)

    # One tangent vector applied to every pose in the batch.
    shared_xi = torch.Tensor([0.3, 0.2, 0.1])
    shared_target.perturb(shared_xi)
    expected_shared = SE2.exp(shared_xi).dot(poses)
    assert utils.allclose(shared_target.as_matrix(), expected_shared.as_matrix())

    # One tangent vector per pose.
    per_pose_xi = torch.Tensor([[0.3, 0.2, 0.1], [-0.1, -0.2, -0.3]])
    per_pose_target.perturb(per_pose_xi)
    expected_per_pose = SE2.exp(per_pose_xi).dot(poses)
    assert utils.allclose(
        per_pose_target.as_matrix(), expected_per_pose.as_matrix())
def mean_position(self, mean, logvar):
    """Return the translation of the information-fused multi-head mean pose.

    Each head contributes a pose mean and an information matrix ``S^T S``
    (``S`` from ``get_inverse_sigma_matrix``). The information matrices are
    summed and inverted into a fused covariance. That covariance weights each
    head's ``SE2.log`` mean, and the weighted sum is mapped back with
    ``SE2.exp``.

    Args:
        mean: per-head pose parameters, reshaped to
            ``(batch * head_count, mean.shape[1] // head_count)``.
        logvar: per-head log-variance parameters, reshaped the same way.

    Returns:
        Tensor of shape ``(batch, 2)`` (default dtype, on ``mean``'s device)
        holding entries ``[0, 2]`` and ``[1, 2]`` of the fused pose matrix.
    """
    heads = self._head_count
    n_batch = mean.shape[0]
    n_rows = n_batch * heads

    # One row per (sample, head).
    flat_mean = mean.view(n_rows, mean.shape[1] // heads)
    head_poses = self.make_matrix(flat_mean[:, :2], flat_mean[:, 2:])
    head_log = SE2.log(SE2.from_matrix(head_poses, normalize=False))
    if head_log.dim() < 2:
        head_log = head_log[None]

    flat_logvar = logvar.view(n_rows, logvar.shape[1] // heads)
    inv_sigma = self.get_inverse_sigma_matrix(flat_logvar)
    head_info = torch.bmm(inv_sigma.transpose(1, 2), inv_sigma)

    # Information adds across heads; the fused covariance is its inverse.
    fused_info = head_info.reshape(-1, heads, 3, 3).sum(dim=1)
    fused_cov = torch.inverse(fused_info)

    # Weighted tangent mean: sum_i Sigma_fused @ Lambda_i @ log(mu_i).
    weights = torch.bmm(fused_cov.repeat_interleave(heads, 0), head_info)
    weighted_log = torch.bmm(weights, head_log[:, :, None])[:, :, 0]
    fused_log = weighted_log.reshape(-1, heads, 3).sum(dim=1)

    fused_pose = SE2.exp(fused_log).as_matrix()
    if fused_pose.dim() < 3:
        fused_pose = fused_pose[None]

    positions = torch.zeros(n_batch, 2, device=flat_mean.device)
    positions[:, 0:2] = fused_pose[:, 0:2, 2]
    return positions
def test_normalize_batch():
    """normalize() repairs only the selected poses when given indices."""
    poses = SE2.exp(0.1 * torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    assert SE2.is_valid_matrix(poses.as_matrix()).all()

    # Corrupt every rotation block in place.
    poses.rot.mat.add_(0.1)
    validity = SE2.is_valid_matrix(poses.as_matrix())
    assert (validity == torch.ByteTensor([0, 0, 0])).all()

    poses.normalize(inds=[0, 2])
    validity = SE2.is_valid_matrix(poses.as_matrix())
    assert (validity == torch.ByteTensor([1, 0, 1])).all()

    poses.normalize()
    assert SE2.is_valid_matrix(poses.as_matrix()).all()
def sample(self, mean, logvar):
    """Draw one SE(2) pose sample per batch row around the predicted mean pose.

    The mean pose is built by ``make_matrix`` from columns ``0:2`` and the
    L2-normalized columns ``2:4`` of ``mean``; presumably these are
    translation and a (cos, sin) heading encoding (TODO confirm). A tangent
    perturbation ``sigma @ eps`` with ``eps ~ N(0, I)`` is applied on the right
    of that pose.

    Args:
        mean: per-row pose parameters; at least 4 columns are read.
        logvar: log-variance parameters for ``get_sigma_matrix``. A 1-D
            tensor is shared by every row of the batch.

    Returns:
        numpy array of shape ``(batch, 3)`` holding the sampled pose's
        translation column (``x``, ``y``) and ``atan2(R[1, 0], R[0, 0])``.
    """
    batch_size = mean.shape[0]
    mean_matrix = self.make_matrix(
        mean[:, 0:2], torch.nn.functional.normalize(mean[:, 2:4]))
    if logvar.dim() < 2:
        # Broadcast a single shared log-variance vector across the batch.
        logvar = logvar[None].expand(batch_size, logvar.shape[0])
    sigma_matrix = self.get_sigma_matrix(logvar)

    epsilon = torch.randn(batch_size, 3, device=mean.device)
    delta = torch.bmm(sigma_matrix, epsilon[:, :, None])[:, :, 0]
    delta_matrix = SE2.exp(delta).as_matrix()
    if delta_matrix.dim() < 3:
        # Restore the batch dimension if SE2.exp returned a single matrix.
        delta_matrix = delta_matrix[None]
    position_matrix = torch.bmm(mean_matrix, delta_matrix)

    # Allocate on the working device so the result crosses to the host once.
    positions = torch.zeros(batch_size, 3, device=mean.device)
    positions[:, 0] = position_matrix[:, 0, 2]
    positions[:, 1] = position_matrix[:, 1, 2]
    positions[:, 2] = torch.atan2(
        position_matrix[:, 1, 0], position_matrix[:, 0, 0])
    return positions.cpu().detach().numpy()
def test_adjoint_batch():
    """A batch of two poses yields a stack of two 3x3 adjoints."""
    xi = 0.1 * torch.Tensor([[1, 2, 3], [4, 5, 6]])
    adjoint = SE2.exp(xi).adjoint()
    assert adjoint.shape == (2, 3, 3)
def test_adjoint():
    """A single pose yields one 3x3 adjoint matrix."""
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    expected_shape = (3, 3)
    assert pose.adjoint().shape == expected_shape
def test_inv_batch():
    """Composing each pose with its inverse gives identity across the batch."""
    poses = SE2.exp(0.1 * torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    composed = poses.dot(poses.inv())
    identity = SE2.identity(poses.trans.shape[0])
    assert utils.allclose(composed.as_matrix(), identity.as_matrix())
def test_inv():
    """A pose composed with its inverse is the 3x3 identity."""
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    product = pose.dot(pose.inv()).as_matrix()
    assert utils.allclose(product, torch.eye(3))
def test_normalize():
    """normalize() restores a valid pose after its rotation is corrupted."""
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    pose.rot.mat.add_(0.1)  # corrupt the rotation block in place
    pose.normalize()
    is_valid = SE2.is_valid_matrix(pose.as_matrix())
    assert is_valid.all()
def test_perturb():
    """perturb(xi) is equivalent to left-multiplying by exp(xi)."""
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    original = copy.deepcopy(pose)
    xi = torch.Tensor([0.3, 0.2, 0.1])
    pose.perturb(xi)
    expected = SE2.exp(xi).dot(original)
    assert utils.allclose(pose.as_matrix(), expected.as_matrix())
def test_exp_log_batch():
    """exp(log(T)) reproduces a batch of poses."""
    poses = SE2.exp(0.1 * torch.Tensor([[1, 2, 3], [4, 5, 6]]))
    round_trip = SE2.exp(SE2.log(poses))
    assert utils.allclose(round_trip.as_matrix(), poses.as_matrix())
def test_exp_log():
    """exp(log(T)) reproduces a single pose."""
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    recovered = SE2.exp(SE2.log(pose))
    assert utils.allclose(recovered.as_matrix(), pose.as_matrix())
def test_exp_log():
    """exp(log(T)) reproduces a single pose.

    NOTE(review): if this module defines ``test_exp_log`` more than once,
    Python keeps only the last definition, so pytest collects just one of
    them. Rename or delete the duplicate.
    """
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    recovered = SE2.exp(SE2.log(pose))
    assert utils.allclose(recovered.as_matrix(), pose.as_matrix())