Пример #1
0
def test_odot_batch():
    """Check that a batched ``SE2.odot`` call agrees with per-point calls.

    Two 2-element points are stacked along a new leading dimension, and
    slice ``i`` of the batched result must equal ``SE2.odot`` applied to
    point ``i`` on its own.
    """
    points = [torch.Tensor([1, 2]), torch.Tensor([2, 3])]
    # Stacking on dim 0 gives one row per point.
    batch = torch.stack(points, dim=0)

    # Per-point results are computed first, then the batched one.
    singles = [SE2.odot(point) for point in points]
    batched = SE2.odot(batch)

    for index, single in enumerate(singles):
        assert torch.all(single == batched[index, :, :])
Пример #2
0
def test_odot():
    """Compare ``SE2.odot`` on 3-element inputs against padded 2-element results.

    * ``SE2.odot([1, 2, 1])`` must equal ``SE2.odot([1, 2])`` with a row of
      three zeros appended.
    * ``SE2.odot([1, 2, 0])`` must equal ``SE2.odot([1, 2], directional=True)``
      with a row of three zeros appended.
    """
    two_elem = torch.Tensor([1, 2])
    last_one = torch.Tensor([1, 2, 1])
    last_zero = torch.Tensor([1, 2, 0])

    def append_zero_row(matrix):
        # A fresh (1, 3) row of zeros is concatenated below the matrix.
        return torch.cat([matrix, torch.zeros(1, 3)], dim=0)

    expected_default = append_zero_row(SE2.odot(two_elem))
    expected_directional = append_zero_row(
        SE2.odot(two_elem, directional=True))

    actual_last_one = SE2.odot(last_one)
    actual_last_zero = SE2.odot(last_zero)

    assert torch.all(expected_default == actual_last_one)
    assert torch.all(expected_directional == actual_last_zero)