Пример #1
0
def test_dot_batch():
    """Check batched ``SO2.dot`` against explicit matrix products.

    ``C1`` wraps a batch of 5 identical rotation matrices (built via ``expand``)
    and ``C2`` wraps a single rotation matrix. Every ``dot`` result is compared
    in shape and value against ``torch.bmm`` / ``torch.matmul`` computed
    directly on the underlying ``.mat`` tensors. Each condition is its own
    ``assert`` so a failure pinpoints the exact check.
    """
    C1 = SO2(torch.Tensor([[0, -1],
                           [1, 0]]).expand(5, 2, 2))
    C2 = SO2(torch.Tensor([[-1, 0],
                           [0, -1]]))
    pt1 = torch.Tensor([1, 2])
    pt2 = torch.Tensor([4, 5])
    pt3 = torch.Tensor([7, 8])
    single_pts = (pt1, pt2, pt3)
    pts = torch.stack(single_pts, dim=0)  # 3x2, one point per row
    ptsbatch = pts.unsqueeze(dim=0).expand(5, 3, 2)  # same 3 points per batch item

    # Rotation composed with a rotation batch of the same size.
    C1C1 = torch.bmm(C1.mat, C1.mat)
    C1C1_SO2 = C1.dot(C1).mat
    assert C1C1_SO2.shape == C1.mat.shape
    assert utils.allclose(C1C1_SO2, C1C1)

    # Batched rotation composed with a single rotation (broadcast).
    C1C2 = torch.matmul(C1.mat, C2.mat)
    C1C2_SO2 = C1.dot(C2).mat
    assert C1C2_SO2.shape == C1.mat.shape
    assert utils.allclose(C1C2_SO2, C1C2)

    # Batched rotation applied to each single point: one 2-vector per rotation.
    # The expected results are kept to cross-check the multi-point cases below.
    C1_single = []
    for pt in single_pts:
        C1pt = torch.matmul(C1.mat, pt)
        C1pt_SO2 = C1.dot(pt)
        assert C1pt_SO2.shape == (C1.mat.shape[0], pt.shape[0])
        assert utils.allclose(C1pt_SO2, C1pt)
        C1_single.append(C1pt)

    # Batched rotation applied to a (3, 2) point set shared by every batch item.
    C1pts = torch.matmul(C1.mat, pts.transpose(1, 0)).transpose(2, 1)
    C1pts_SO2 = C1.dot(pts)
    assert C1pts_SO2.shape == (C1.mat.shape[0], pts.shape[0], pts.shape[1])
    assert utils.allclose(C1pts_SO2, C1pts)
    for idx, C1pt in enumerate(C1_single):
        assert utils.allclose(C1pt, C1pts[:, idx, :])

    # Batched rotation applied to a batched point set (one point set per rotation).
    C1ptsbatch = torch.bmm(C1.mat, ptsbatch.transpose(2, 1)).transpose(2, 1)
    C1ptsbatch_SO2 = C1.dot(ptsbatch)
    assert C1ptsbatch_SO2.shape == ptsbatch.shape
    assert utils.allclose(C1ptsbatch_SO2, C1ptsbatch)
    for idx, C1pt in enumerate(C1_single):
        assert utils.allclose(C1pt, C1ptsbatch[:, idx, :])

    # Single rotation applied to a batched point set (broadcast over the batch).
    C2ptsbatch = torch.matmul(C2.mat, ptsbatch.transpose(2, 1)).transpose(2, 1)
    C2ptsbatch_SO2 = C2.dot(ptsbatch)
    assert C2ptsbatch_SO2.shape == ptsbatch.shape
    assert utils.allclose(C2ptsbatch_SO2, C2ptsbatch)
    for idx, pt in enumerate(single_pts):
        assert utils.allclose(C2.dot(pt), C2ptsbatch[:, idx, :])
Пример #2
0
def test_dot():
    """``SO2.dot`` on a single rotation agrees with plain matrix products.

    Composing the rotation with itself must match the 2x2 matrix product, and
    applying it to a point must match the matrix-vector product.

    NOTE(review): another ``def test_dot`` appears later in this listing; if both
    live in one module, the later definition rebinds the name and pytest only
    collects that one.
    """
    rot = SO2(torch.Tensor([[0, -1], [1, 0]]))
    point = torch.Tensor([1, 2])

    # Rotation-rotation composition.
    expected_rot = torch.mm(rot.mat, rot.mat)
    assert utils.allclose(rot.dot(rot).mat, expected_rot)

    # Rotation acting on a 2D point.
    expected_point = torch.matmul(rot.mat, point)
    assert utils.allclose(rot.dot(point), expected_point)
Пример #3
0
def test_dot():
    """Applying the matrix [[0, -1], [1, 0]] to (1, 2) must give (-2, 1).

    That matrix maps (x, y) to (-y, x), so the expected value is computed by
    hand rather than via another matrix product.
    """
    quarter_turn = SO2(torch.Tensor([[0, -1],
                                     [1, 0]]))
    expected = torch.tensor([-2, 1], dtype=torch.float32)
    rotated = quarter_turn.dot(torch.Tensor([1, 2]))
    assert torch.allclose(expected, rotated)