def backward(self, grad_output):
    """Back-propagate through the axis-angle -> rotation-matrix map.

    For each component ``phi[i]`` the derivative ``dR/dphi_i`` (a 3x3
    matrix) is computed, and the incoming gradient is contracted with it:
    ``out[i] = sum_jk grad_output[j, k] * dR/dphi_i[j, k]``.

    Args:
        grad_output: gradient of the loss w.r.t. the 3x3 matrix ``R``
            produced by ``forward``.

    Returns:
        Gradient w.r.t. ``phi``, reshaped to ``phi``'s own shape so that
        autograd accepts it whether ``phi`` was passed as ``(3,)``,
        ``(3, 1)`` or ``(1, 3)``.
    """
    phi, R = self.saved_tensors
    angle = phi.norm()

    # Standard basis vectors e_0, e_1, e_2 as (3, 1) columns, on the same
    # dtype/device as the incoming gradient.
    basis = [
        grad_output.new_tensor(unit).view(3, 1)
        for unit in ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    ]

    # grad[i] holds dR/dphi_i.
    grad = grad_output.new_empty((3, 3, 3))
    if angle < 1e-8:
        # Near zero forward uses R = I + wedge(phi), whose derivative
        # w.r.t. phi_i is simply wedge(e_i).
        for i, e_i in enumerate(basis):
            grad[i] = SO3.wedge(e_i)
    else:
        I = grad_output.new_empty((3, 3))
        torch.nn.init.eye_(I)
        phi_wedge = SO3.wedge(phi)
        I_minus_R = I - R
        # NOTE(review): just above the 1e-8 threshold, (I - R) is formed by
        # cancellation and then scaled by 1 / angle**2, so precision may
        # degrade in low-precision dtypes -- confirm the threshold suits
        # the dtypes in use.
        inv_sq_angle = 1. / angle**2
        # Closed form:
        # dR/dphi_i = (phi_i * [phi]x + [[phi]x (I - R) e_i]x) R / |phi|^2
        for i, e_i in enumerate(basis):
            skew_term = SO3.wedge(phi_wedge.mm(I_minus_R.mm(e_i)))
            grad[i] = inv_sq_angle * (phi[i] * phi_wedge + skew_term).mm(R)

    # Contract the upstream gradient with each dR/dphi_i, then match phi's
    # shape (the previous hard-coded view(3, 1) broke flat (3,) inputs).
    return (grad_output * grad).sum((1, 2)).view_as(phi)
def forward(self, phi):
    """Convert an axis-angle vector ``phi`` into a 3x3 rotation matrix.

    Uses Rodrigues' formula ``cos(t) I + (1 - cos(t)) u u^T + sin(t) [u]x``
    with ``t = |phi|`` and ``u = phi / t``.  When ``t < 1e-8`` the
    first-order approximation ``I + wedge(phi)`` is used instead, which
    avoids dividing by a vanishing angle.  ``phi`` and the resulting matrix
    are saved for ``backward``.
    """
    theta = phi.norm()
    eye = phi.new_empty((3, 3))
    torch.nn.init.eye_(eye)

    if theta < 1e-8:
        rot = eye + SO3.wedge(phi)
    else:
        unit = phi / theta
        sin_t = torch.sin(theta)
        cos_t = torch.cos(theta)
        column = unit.view(3, 1)
        row = unit.view(1, 3)
        projector = column.mm(row)
        rot = cos_t * eye + (1. - cos_t) * projector + sin_t * SO3.wedge(unit)

    self.save_for_backward(phi, rot)
    return rot
def test_wedge_vee_batch():
    """vee must invert wedge for a batch of vectors."""
    batch = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    recovered = SO3.vee(SO3.wedge(batch))
    assert (recovered == batch).all()
def test_wedge_vee():
    """vee must invert wedge for a single vector."""
    vec = torch.Tensor([1, 2, 3])
    recovered = SO3.vee(SO3.wedge(vec))
    assert (recovered == vec).all()
def test_wedge_batch():
    """Batched wedge must agree row-by-row with the single-vector wedge."""
    vectors = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    matrices = SO3.wedge(vectors)
    for idx, vec in enumerate(vectors):
        assert (matrices[idx] == SO3.wedge(vec)).all()
def test_wedge():
    """wedge of a 3-vector must be a skew-symmetric matrix."""
    vec = torch.Tensor([1, 2, 3])
    skew = SO3.wedge(vec)
    assert (skew.t() == -skew).all()