def test_dot_batch():
    """Check batched ``SO3.dot`` against explicit torch matrix products.

    Covers composing a batched rotation with itself and with a single
    rotation, and applying batched / single rotations to single points,
    a 3x3 stack of points, and a batch of point stacks.
    """
    # C1 is the same matrix expanded to a batch of 5; C2 is it unbatched.
    C1 = SO3(torch.Tensor([[0, -1, 0],
                           [1, 0, 0],
                           [0, 0, 1]]).expand(5, 3, 3))
    C2 = SO3(torch.Tensor([[0, -1, 0],
                           [1, 0, 0],
                           [0, 0, 1]]))
    # Three *distinct* points, so that row-ordering mistakes are detectable.
    pt1 = torch.Tensor([1, 2, 3])
    pt2 = torch.Tensor([4, 5, 6])
    pt3 = torch.Tensor([7, 8, 9])
    pts = torch.cat([pt1.unsqueeze(dim=0),
                     pt2.unsqueeze(dim=0),
                     pt3.unsqueeze(dim=0)], dim=0)  # 3x3, one point per row
    ptsbatch = pts.unsqueeze(dim=0).expand(5, 3, 3)

    # Batched rotation composed with itself.
    C1C1 = torch.bmm(C1.mat, C1.mat)
    C1C1_SO3 = C1.dot(C1).mat
    assert C1C1_SO3.shape == C1.mat.shape and utils.allclose(C1C1_SO3, C1C1)

    # Batched rotation composed with a single (unbatched) rotation.
    C1C2 = torch.matmul(C1.mat, C2.mat)
    C1C2_SO3 = C1.dot(C2).mat
    assert C1C2_SO3.shape == C1.mat.shape and utils.allclose(C1C2_SO3, C1C2)

    # Batched rotation applied to single points: one result row per batch entry.
    C1pt1 = torch.matmul(C1.mat, pt1)
    C1pt1_SO3 = C1.dot(pt1)
    assert C1pt1_SO3.shape == (C1.mat.shape[0], pt1.shape[0]) \
        and utils.allclose(C1pt1_SO3, C1pt1)

    C1pt2 = torch.matmul(C1.mat, pt2)
    C1pt2_SO3 = C1.dot(pt2)
    assert C1pt2_SO3.shape == (C1.mat.shape[0], pt2.shape[0]) \
        and utils.allclose(C1pt2_SO3, C1pt2)

    # Batched rotation applied to a stack of points (points stored as rows).
    C1pts = torch.matmul(C1.mat, pts.transpose(1, 0)).transpose(2, 1)
    C1pts_SO3 = C1.dot(pts)
    assert C1pts_SO3.shape == (C1.mat.shape[0], pts.shape[0], pts.shape[1]) \
        and utils.allclose(C1pts_SO3, C1pts) \
        and utils.allclose(C1pt1, C1pts[:, 0, :]) \
        and utils.allclose(C1pt2, C1pts[:, 1, :])

    # Batched rotation applied to a batch of point stacks.
    C1ptsbatch = torch.bmm(C1.mat, ptsbatch.transpose(2, 1)).transpose(2, 1)
    C1ptsbatch_SO3 = C1.dot(ptsbatch)
    assert C1ptsbatch_SO3.shape == ptsbatch.shape \
        and utils.allclose(C1ptsbatch_SO3, C1ptsbatch) \
        and utils.allclose(C1pt1, C1ptsbatch[:, 0, :]) \
        and utils.allclose(C1pt2, C1ptsbatch[:, 1, :])

    # Single rotation applied to a batch of point stacks.
    C2ptsbatch = torch.matmul(C2.mat, ptsbatch.transpose(2, 1)).transpose(2, 1)
    C2ptsbatch_SO3 = C2.dot(ptsbatch)
    assert C2ptsbatch_SO3.shape == ptsbatch.shape \
        and utils.allclose(C2ptsbatch_SO3, C2ptsbatch) \
        and utils.allclose(C2.dot(pt1), C2ptsbatch[:, 0, :]) \
        and utils.allclose(C2.dot(pt2), C2ptsbatch[:, 1, :])
def test_dot():
    """Check unbatched ``SO3.dot`` against plain torch matrix products.

    Composing the rotation with itself must equal ``mat @ mat``, and
    applying it to a single point must equal ``mat @ point``.
    """
    rot = SO3(torch.Tensor([[0, -1, 0],
                            [1, 0, 0],
                            [0, 0, 1]]))
    point = torch.Tensor([1, 2, 3])

    # Rotation composed with itself.
    expected_composed = torch.mm(rot.mat, rot.mat)
    composed = rot.dot(rot)
    assert utils.allclose(composed.mat, expected_composed)

    # Rotation applied to a single point.
    expected_rotated = torch.matmul(rot.mat, point)
    rotated = rot.dot(point)
    assert utils.allclose(rotated, expected_rotated)