def test_dot_batch():
    """Check ``SO2.dot`` against explicit ``torch`` matrix products in batched mode.

    A batch of 5 rotations is composed with itself and with a single
    rotation. It is then applied to single points, to a 3x2 stack of points,
    and to a 5x3x2 batch of point stacks. A single rotation is also applied
    to the batched stack. Each result is compared by shape and by value.
    """
    # Five copies of the +90 degree rotation, and one 180 degree rotation.
    rot_batch = SO2(torch.Tensor([[0, -1], [1, 0]]).expand(5, 2, 2))
    rot_single = SO2(torch.Tensor([[-1, 0], [0, -1]]))

    p1 = torch.Tensor([1, 2])
    p2 = torch.Tensor([4, 5])
    p3 = torch.Tensor([7, 8])
    points = torch.stack([p1, p2, p3], dim=0)  # 3x2
    points_batch = points.unsqueeze(dim=0).expand(5, 3, 2)  # 5x3x2

    # Batched rotation composed with itself.
    ref = torch.bmm(rot_batch.mat, rot_batch.mat)
    out = rot_batch.dot(rot_batch).mat
    assert out.shape == rot_batch.mat.shape
    assert utils.allclose(out, ref)

    # Batched rotation composed with a single (broadcast) rotation.
    ref = torch.matmul(rot_batch.mat, rot_single.mat)
    out = rot_batch.dot(rot_single).mat
    assert out.shape == rot_batch.mat.shape
    assert utils.allclose(out, ref)

    # Batched rotation applied to single points: one rotated point per batch entry.
    rot_p1 = torch.matmul(rot_batch.mat, p1)
    out = rot_batch.dot(p1)
    assert out.shape == (rot_batch.mat.shape[0], p1.shape[0])
    assert utils.allclose(out, rot_p1)

    rot_p2 = torch.matmul(rot_batch.mat, p2)
    out = rot_batch.dot(p2)
    assert out.shape == (rot_batch.mat.shape[0], p2.shape[0])
    assert utils.allclose(out, rot_p2)

    # Batched rotation applied to one shared 3x2 stack of points.
    ref = torch.matmul(rot_batch.mat, points.transpose(1, 0)).transpose(2, 1)
    out = rot_batch.dot(points)
    assert out.shape == (rot_batch.mat.shape[0], points.shape[0], points.shape[1])
    assert utils.allclose(out, ref)
    assert utils.allclose(rot_p1, ref[:, 0, :])
    assert utils.allclose(rot_p2, ref[:, 1, :])

    # Batched rotation applied to a 5x3x2 batch of point stacks.
    ref = torch.bmm(rot_batch.mat, points_batch.transpose(2, 1)).transpose(2, 1)
    out = rot_batch.dot(points_batch)
    assert out.shape == points_batch.shape
    assert utils.allclose(out, ref)
    assert utils.allclose(rot_p1, ref[:, 0, :])
    assert utils.allclose(rot_p2, ref[:, 1, :])

    # Single rotation applied to the 5x3x2 batch of point stacks.
    ref = torch.matmul(rot_single.mat, points_batch.transpose(2, 1)).transpose(2, 1)
    out = rot_single.dot(points_batch)
    assert out.shape == points_batch.shape
    assert utils.allclose(out, ref)
    assert utils.allclose(rot_single.dot(p1), ref[:, 0, :])
    assert utils.allclose(rot_single.dot(p2), ref[:, 1, :])
# Named distinctly from ``test_dot``: this module also defines a ``test_dot``
# further down, which would rebind the name so pytest never collected this test.
def test_dot_compose_and_point():
    """Check single-rotation ``SO2.dot`` for composition and for rotating a point."""
    C = SO2(torch.Tensor([[0, -1], [1, 0]]))
    pt = torch.Tensor([1, 2])

    # Composing the rotation with itself must match the plain 2x2 matrix product.
    CC = C.mat.mm(C.mat)
    assert utils.allclose(C.dot(C).mat, CC)

    # Applying the rotation to a point must match matrix-vector multiplication.
    Cpt = C.mat.matmul(pt)
    assert utils.allclose(C.dot(pt), Cpt)
def test_dot():
    """Rotating the point (1, 2) by the +90 degree rotation matrix gives (-2, 1)."""
    rot90 = SO2(torch.Tensor([[0, -1], [1, 0]]))
    point = torch.Tensor([1, 2])
    expected = torch.tensor([-2, 1], dtype=torch.float32)
    # Argument order kept as (expected, actual): torch.allclose's tolerance is asymmetric.
    assert torch.allclose(expected, rot90.dot(point))