Ejemplo n.º 1
0
 def mean_position(self, mean, logvar):
     """Fuse the per-head pose estimates into a single x/y position.

     ``mean`` and ``logvar`` each hold ``self._head_count`` heads side by side
     along dim 1. Every head is turned into a pose (``self.make_matrix``) and an
     information matrix ``S_i = A_i^T A_i``, where ``A_i`` comes from
     ``self.get_inverse_sigma_matrix``. The heads are then combined in the
     SE(2) log space:

         Sigma = (sum_i S_i)^-1
         xi    = sum_i Sigma S_i log(T_i)

     The translation of ``SE2.exp(xi)`` is returned.

     Args:
         mean: 2-D tensor (batch, heads * k). Per head, columns ``:2`` are the
             translation argument and columns ``2:`` the rotation argument of
             ``make_matrix``. k is presumably 4 (TODO confirm).
         logvar: 2-D tensor (batch, heads * m) with each head's parameters
             for ``get_inverse_sigma_matrix``.

     Returns:
         Tensor (batch, 2) with the fused translation. It is created with
         ``torch.zeros`` (default dtype) on ``mean``'s device.
     """
     rows = mean.shape[0]
     heads = self._head_count

     # One row per (batch item, head) pair.
     per_head_mean = mean.view(rows * heads, mean.shape[1] // heads)
     head_poses = self.make_matrix(per_head_mean[:, :2], per_head_mean[:, 2:])
     head_log = SE2.log(SE2.from_matrix(head_poses, normalize=False))
     if head_log.dim() < 2:
         # A single pose comes back unbatched; restore the batch axis.
         head_log = head_log.unsqueeze(0)

     per_head_logvar = logvar.view(rows * heads, logvar.shape[1] // heads)
     inv_sigma = self.get_inverse_sigma_matrix(per_head_logvar)
     head_information = torch.bmm(inv_sigma.transpose(1, 2), inv_sigma)
     fused_information = head_information.reshape(-1, heads, 3, 3).sum(dim=1)
     fused_covariance = torch.inverse(fused_information)

     # Weight each head's tangent vector by Sigma_fused @ S_i, then sum heads.
     weights = torch.bmm(fused_covariance.repeat_interleave(heads, 0),
                         head_information)
     weighted_log = torch.bmm(weights, head_log.unsqueeze(2)).squeeze(2)
     fused_log = weighted_log.reshape(-1, heads, 3).sum(dim=1)

     fused_pose = SE2.exp(fused_log).as_matrix()
     if fused_pose.dim() < 3:
         fused_pose = fused_pose.unsqueeze(0)

     positions = torch.zeros(rows, 2, device=per_head_mean.device)
     # Row 0/1 of the homogeneous matrix's last column hold x and y.
     positions[:, :] = fused_pose[:, :2, 2]
     return positions
Ejemplo n.º 2
0
    def sample(self, mean, logvar):
        """Draw one pose sample per batch item from the fused head distribution.

        The ``self._head_count`` heads packed along dim 1 of ``mean`` and
        ``logvar`` are fused in SE(2) log space. Information matrices
        ``S_i = A_i^T A_i`` (with ``A_i`` from ``get_inverse_sigma_matrix``) are
        summed and inverted to give the fused covariance. The fused mean is
        ``exp(sum_i Sigma S_i log(T_i))``.

        A sample is then ``T_mean @ exp(L @ eps)``, where ``eps ~ N(0, I_3)``
        and ``L`` is the Cholesky factor of the fused covariance plus
        ``1e-4 * I``. If the factorisation fails, a warning is logged and
        ``L = 1e4 * I`` is used instead.

        Args:
            mean: 2-D tensor (batch, heads * k). Per head, columns ``:2`` feed
                the translation and ``2:`` the rotation argument of
                ``make_matrix``.
            logvar: 2-D tensor (batch, heads * m) with each head's parameters
                for ``get_inverse_sigma_matrix``.

        Returns:
            NumPy array (batch, 3). Its columns are x, y (translation column
            of the sampled matrix) and the heading
            ``atan2(R[1, 0], R[0, 0])``.
        """
        import logging

        batch_size = mean.shape[0]
        heads = self._head_count

        # --- Fuse the heads (information-weighted average in log space). ---
        mean = mean.view(batch_size * heads, mean.shape[1] // heads)
        mean_matrix = self.make_matrix(mean[:, :2], mean[:, 2:])
        log_mean = SE2.log(SE2.from_matrix(mean_matrix, normalize=False))
        if log_mean.dim() < 2:
            log_mean = log_mean[None]
        logvar = logvar.view(batch_size * heads, logvar.shape[1] // heads)
        inverse_sigma_matrix = self.get_inverse_sigma_matrix(logvar)
        inverse_covariance_matrix = torch.bmm(
            inverse_sigma_matrix.transpose(1, 2), inverse_sigma_matrix)
        result_inverse_covariance_matrix = torch.sum(
            inverse_covariance_matrix.reshape(-1, heads, 3, 3), dim=1)
        result_covariance_matrix = torch.inverse(
            result_inverse_covariance_matrix)
        factors = torch.bmm(
            result_covariance_matrix.repeat_interleave(heads, 0),
            inverse_covariance_matrix)
        scaled_log_mean = torch.bmm(factors, log_mean[:, :, None])[:, :, 0]
        result_log_mean = torch.sum(
            scaled_log_mean.reshape(-1, heads, 3), dim=1)
        mean_matrix = SE2.exp(result_log_mean).as_matrix()
        if mean_matrix.dim() < 3:
            mean_matrix = mean_matrix[None]

        # --- Factor the fused covariance. ---
        # Build helper tensors with the covariance's dtype/device. Otherwise
        # float64 inputs hit a Double/Float mismatch in torch.bmm below.
        factory = dict(device=mean.device,
                       dtype=result_covariance_matrix.dtype)
        # Small diagonal jitter keeps the Cholesky factorisation stable.
        jitter = torch.eye(3, **factory) * 1e-4
        try:
            # NOTE(review): torch.cholesky is deprecated in favour of
            # torch.linalg.cholesky; switch once the minimum torch version
            # allows it.
            sigma_matrix = torch.cholesky(result_covariance_matrix + jitter)
        except RuntimeError as msg:
            logging.getLogger(__name__).warning(
                "Cholesky error %s\ninverse covariance:\n%s\n"
                "fused inverse covariance:\n%s\nfused covariance:\n%s",
                msg, inverse_covariance_matrix,
                result_inverse_covariance_matrix, result_covariance_matrix)
            # Best-effort fallback: a very wide isotropic spread.
            sigma_matrix = (torch.eye(3, **factory) * 1e4).expand(
                batch_size, 3, 3)

        # --- Perturb the fused mean on the right by a tangent-space sample. ---
        epsilon = torch.randn(batch_size, 3, **factory)
        delta = torch.bmm(sigma_matrix, epsilon[:, :, None])[:, :, 0]
        delta_matrix = SE2.exp(delta).as_matrix()
        if delta_matrix.dim() < 3:
            delta_matrix = delta_matrix[None]
        position_matrix = torch.bmm(mean_matrix, delta_matrix)

        # Assemble on the source device, then transfer to the host once.
        positions = torch.stack(
            (position_matrix[:, 0, 2],
             position_matrix[:, 1, 2],
             torch.atan2(position_matrix[:, 1, 0], position_matrix[:, 0, 0])),
            dim=1)
        return positions.detach().cpu().numpy()
Ejemplo n.º 3
0
    def log_prob(self, value, mean, logvar):
        """Score ``value`` against a Gaussian defined in SE(2) tangent space.

        The residual ``T_mean^-1 @ T_value`` is mapped to the tangent space with
        ``SE2.log`` and whitened by ``get_inverse_sigma_matrix(logvar)``. The
        result is

            0.5 * ||whitened||^2 + 0.5 * get_logvar_determinant(logvar)
            + 3 * log(sqrt(2 * pi))

        NOTE(review): all three terms are *added*. If
        ``get_logvar_determinant`` returns log|Sigma|, this is a negative
        log-likelihood (lower = more likely), not a log-probability. Confirm
        the sign convention callers expect.

        Args:
            value: indexable pair. ``value[0]`` is the translation argument of
                ``make_matrix``. ``value[1]`` is L2-normalised along dim 1 and
                used as the rotation argument.
            mean: 2-D tensor. Columns ``0:2`` are the translation. Columns
                ``2:4`` are L2-normalised and used as the rotation.
            logvar: 1-D (shared across the batch, tiled to ``mean.shape[0]``)
                or 2-D per-item parameters for the sigma helpers.

        Returns:
            1-D tensor with one score per residual.
        """
        if logvar.dim() < 2:
            # One shared logvar vector: tile it across the batch.
            logvar = logvar.unsqueeze(0).expand(mean.shape[0], logvar.shape[0])

        normalize = torch.nn.functional.normalize
        target = self.make_matrix(value[0], normalize(value[1]))
        center = self.make_matrix(mean[:, 0:2],
                                  normalize(mean[:, 2:4])).expand_as(target)

        residual = torch.bmm(self.pose_matrix_inverse(center), target)
        tangent = SE2.log(SE2.from_matrix(residual, normalize=False))
        if tangent.dim() < 2:
            tangent = tangent.unsqueeze(0)

        count, dim = tangent.shape[0], tangent.shape[1]
        whitening = self.get_inverse_sigma_matrix(logvar).expand(
            count, dim, dim)
        whitened = torch.bmm(whitening, tangent.unsqueeze(-1)).squeeze(-1)
        log_determinant = self.get_logvar_determinant(logvar)

        half_mahalanobis = (whitened ** 2 / 2.).sum(dim=1)
        return (half_mahalanobis + 0.5 * log_determinant
                + 3 * math.log(math.sqrt(2 * math.pi)))
Ejemplo n.º 4
0
def test_exp_log_batch():
    """SE2.log followed by SE2.exp must reproduce a batch of two poses."""
    twists = torch.Tensor([[1, 2, 3], [4, 5, 6]]) * 0.1
    poses = SE2.exp(twists)
    recovered = SE2.exp(SE2.log(poses))
    assert utils.allclose(recovered.as_matrix(), poses.as_matrix())
Ejemplo n.º 5
0
def test_exp_log():
    """SE2.log followed by SE2.exp must reproduce a single pose."""
    pose = SE2.exp(torch.Tensor([1, 2, 3]))
    recovered = SE2.exp(SE2.log(pose))
    assert utils.allclose(recovered.as_matrix(), pose.as_matrix())
Ejemplo n.º 6
0
def test_exp_log():
    """Round-trip a single pose through SE2.log / SE2.exp.

    NOTE(review): another ``test_exp_log`` is defined with identical setup.
    If both end up in the same module, only the last definition is
    collected; consider deduplicating or renaming.
    """
    # Debug prints of T.trans / T.rot.to_angle() removed: they only produced
    # noise in test output and asserted nothing.
    T = SE2.exp(torch.Tensor([1, 2, 3]))
    assert utils.allclose(SE2.exp(SE2.log(T)).as_matrix(), T.as_matrix())