def backward(self, grad_output):
        """Backpropagate through the SO(3) exponential map R = exp(wedge(phi)).

        Uses the closed-form derivative of R with respect to each component
        phi_i (the compact formula of Gallego & Yezzi, 2015):

            dR/dphi_i = (phi_i * wedge(phi)
                         + wedge(wedge(phi) @ (I - R) @ e_i)) / ||phi||^2 @ R

        When ||phi|| < 1e-8 it uses the derivative at the identity,
        dR/dphi_i = wedge(e_i), avoiding the 1 / ||phi||^2 blow-up.

        Args:
            grad_output: Gradient of the loss w.r.t. the 3x3 matrix returned
                by ``forward``.

        Returns:
            Gradient of the loss w.r.t. ``phi``, in the same shape as ``phi``.
        """
        phi, R = self.saved_tensors

        # Standard basis vectors as (3, 1) columns; new_tensor inherits
        # grad_output's dtype and device.
        basis = [grad_output.new_tensor(row).view(3, 1)
                 for row in ((1, 0, 0), (0, 1, 0), (0, 0, 1))]

        # grad[i] holds dR/dphi_i.
        grad = grad_output.new_empty((3, 3, 3))

        angle = phi.norm()
        if angle < 1e-8:
            for i, e_i in enumerate(basis):
                grad[i, :, :] = SO3.wedge(e_i)
        else:
            I = torch.eye(3, dtype=grad_output.dtype, device=grad_output.device)
            fact = 1. / (angle ** 2)
            phi_wedge = SO3.wedge(phi)
            ImR = I - R
            for i, e_i in enumerate(basis):
                grad[i, :, :] = fact * (phi[i] * phi_wedge +
                                        SO3.wedge(phi_wedge.mm(ImR.mm(e_i)))).mm(R)

        # Chain rule: dL/dphi_i = <dL/dR, dR/dphi_i> (Frobenius inner product);
        # grad_output broadcasts across the leading axis of grad.
        out = (grad_output * grad).sum((1, 2))

        # Match phi's own shape so autograd accepts the gradient for both flat
        # (3,) and column (3, 1) inputs.
        return out.view_as(phi)
    def forward(self, phi):
        """Map the axis-angle vector ``phi`` to a 3x3 rotation matrix.

        With theta = ||phi|| and unit axis a = phi / theta, this evaluates
        Rodrigues' formula

            R = cos(theta) I + (1 - cos(theta)) a a^T + sin(theta) wedge(a)

        When theta < 1e-8 it uses the first-order approximation
        R = I + wedge(phi) instead. Both ``phi`` and ``R`` are saved for
        ``backward``.
        """
        identity = torch.eye(3, dtype=phi.dtype, device=phi.device)
        theta = phi.norm()

        if theta < 1e-8:
            # Tiny rotation: a linearised map avoids dividing by ~0.
            R = identity + SO3.wedge(phi)
        else:
            unit = phi / theta
            sin_t = torch.sin(theta)
            cos_t = torch.cos(theta)
            column = unit.view(3, 1)
            row = unit.view(1, 3)
            R = (cos_t * identity
                 + (1. - cos_t) * column.mm(row)
                 + sin_t * SO3.wedge(unit))

        self.save_for_backward(phi, R)
        return R
Example #3
0
def test_wedge_vee_batch():
    """A batch of vectors survives a wedge -> vee round trip unchanged."""
    vectors = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    skew_batch = SO3.wedge(vectors)
    recovered = SO3.vee(skew_batch)
    assert (recovered == vectors).all()
Example #4
0
def test_wedge_vee():
    """vee inverts wedge for a single 3-vector."""
    vec = torch.Tensor([1, 2, 3])
    skew = SO3.wedge(vec)
    assert (SO3.vee(skew) == vec).all()
Example #5
0
def test_wedge_batch():
    """Batched wedge agrees with wedging each row on its own."""
    vectors = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    stacked = SO3.wedge(vectors)
    for idx in range(2):
        single = SO3.wedge(vectors[idx])
        assert (stacked[idx, :, :] == single).all()
Example #6
0
def test_wedge():
    """wedge yields a skew-symmetric matrix, i.e. Phi == -Phi^T."""
    mat = SO3.wedge(torch.Tensor([1, 2, 3]))
    negated_transpose = mat.t().neg()
    assert (mat == negated_transpose).all()