示例#1
0
def test_dot_batch():
    """Check ``SO3.dot`` against plain torch products for batched inputs.

    ``C1`` wraps a batch of 5 copies of the rotation matrix
    ``[[0, -1, 0], [1, 0, 0], [0, 0, 1]]`` (+90 degrees about z). ``C3``
    wraps the same matrix unbatched. Every composition and point transform
    performed through ``SO3.dot`` is compared, in both shape and value,
    against the equivalent ``torch.bmm`` / ``torch.matmul`` result. Three
    distinct points are used so that each row of a stacked point set is
    checked independently.
    """
    C1 = SO3(torch.Tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]]).expand(5, 3, 3))
    C3 = SO3(torch.Tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]]))
    pt1 = torch.Tensor([1, 2, 3])
    pt2 = torch.Tensor([4, 5, 6])
    pt3 = torch.Tensor([7, 8, 9])
    points = (pt1, pt2, pt3)
    # One point per row: pts is 3x3, ptsbatch is 5x3x3 (the same 3 points
    # repeated for every element of the rotation batch).
    pts = torch.stack(points, dim=0)
    ptsbatch = pts.unsqueeze(dim=0).expand(5, 3, 3)

    # Batched rotation composed with itself.
    C1C1 = torch.bmm(C1.mat, C1.mat)
    C1C1_SO3 = C1.dot(C1).mat
    assert C1C1_SO3.shape == C1.mat.shape
    assert utils.allclose(C1C1_SO3, C1C1)

    # Batched rotation composed with a single (unbatched) rotation.
    C1C3 = torch.matmul(C1.mat, C3.mat)
    C1C3_SO3 = C1.dot(C3).mat
    assert C1C3_SO3.shape == C1.mat.shape
    assert utils.allclose(C1C3_SO3, C1C3)

    # Batched rotation applied to each single point separately; the
    # expected results are reused below to check individual rows.
    C1_points = [torch.matmul(C1.mat, pt) for pt in points]
    for pt, expected in zip(points, C1_points):
        C1pt_SO3 = C1.dot(pt)
        assert C1pt_SO3.shape == (C1.mat.shape[0], pt.shape[0])
        assert utils.allclose(C1pt_SO3, expected)

    # Batched rotation applied to an unbatched set of points (rows of pts).
    C1pts = torch.matmul(C1.mat, pts.transpose(1, 0)).transpose(2, 1)
    C1pts_SO3 = C1.dot(pts)
    assert C1pts_SO3.shape == (C1.mat.shape[0], pts.shape[0], pts.shape[1])
    assert utils.allclose(C1pts_SO3, C1pts)
    # Row i of the stacked result must equal transforming point i alone.
    for i, expected in enumerate(C1_points):
        assert utils.allclose(expected, C1pts[:, i, :])

    # Batched rotation applied to a batch of point sets.
    C1ptsbatch = torch.bmm(C1.mat, ptsbatch.transpose(2, 1)).transpose(2, 1)
    C1ptsbatch_SO3 = C1.dot(ptsbatch)
    assert C1ptsbatch_SO3.shape == ptsbatch.shape
    assert utils.allclose(C1ptsbatch_SO3, C1ptsbatch)
    for i, expected in enumerate(C1_points):
        assert utils.allclose(expected, C1ptsbatch[:, i, :])

    # Single rotation applied to a batch of point sets.
    C3ptsbatch = torch.matmul(C3.mat, ptsbatch.transpose(2, 1)).transpose(2, 1)
    C3ptsbatch_SO3 = C3.dot(ptsbatch)
    assert C3ptsbatch_SO3.shape == ptsbatch.shape
    assert utils.allclose(C3ptsbatch_SO3, C3ptsbatch)
    for i, pt in enumerate(points):
        # NOTE(review): C3.dot(pt) is a single 3-vector compared with a 5x3
        # slice; this relies on utils.allclose broadcasting, as the original
        # test did -- confirm.
        assert utils.allclose(C3.dot(pt), C3ptsbatch[:, i, :])
示例#2
0
def test_dot():
    """``SO3.dot`` must agree with raw matrix products for one rotation.

    Composing the rotation with itself must match the explicit
    matrix-matrix product. Applying it to a point must match the explicit
    matrix-vector product.
    """
    rotation = SO3(torch.Tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]]))
    point = torch.Tensor([1, 2, 3])

    # Composition: rotation.dot(rotation) vs. mat @ mat.
    expected_square = torch.mm(rotation.mat, rotation.mat)
    assert utils.allclose(rotation.dot(rotation).mat, expected_square)

    # Action on a point: rotation.dot(point) vs. mat @ point.
    expected_point = torch.matmul(rotation.mat, point)
    assert utils.allclose(rotation.dot(point), expected_point)