示例#1
0
def test_matches_wildcards():
    """A ``*`` entry matches any size, but the template still needs one entry per dim.

    The 4-D tensor matches templates with four entries (wildcards or exact
    sizes) and does not match templates with fewer entries.
    """
    guard = ShapeGuard()
    tensor = torch.ones([1, 2, 4, 8])
    cases = [
        ("1, 2, 4, *", True),
        ("*, *, *, 8", True),
        ("*", False),
        ("*, *, *", False),
    ]
    for template, should_match in cases:
        assert bool(guard.matches(tensor, template)) == should_match, template
示例#2
0
def test_matches_basic_numerical():
    """Numeric templates match exact sizes; integral floats are accepted.

    A float entry with a fractional part (``3.1``) is rejected with
    ShapeError rather than simply failing to match. Templates that differ
    in one size, or in the number of entries, do not match.
    """
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    assert sg.matches(a, "1, 2, 3")
    assert sg.matches(a, "1, 2.0, 3.0")
    # Call bare: the expected outcome is the exception, not a return value.
    # Wrapping the call in ``assert`` would also drop it entirely under
    # ``python -O``, making ``pytest.raises`` fail for the wrong reason.
    with pytest.raises(ShapeError):
        sg.matches(a, "1, 2.0, 3.1")

    assert not sg.matches(a, "1, 2, 4")
    assert not sg.matches(a, "1, 2, 3, 4")
    assert not sg.matches(a, "1, 2")
示例#3
0
def test_matches_named_dims():
    """Dimension names supplied via ``dims=`` are resolved to their sizes.

    Names and literal sizes can be mixed in one template; reusing a name
    whose size differs from the actual dimension does not match.
    """
    guard = ShapeGuard(dims={"N": 24, "Z": 16})
    tensor = torch.ones([24, 16])
    for template, should_match in (
        ("N, Z", True),
        ("24, Z", True),
        ("N, N", False),
    ):
        assert bool(guard.matches(tensor, template)) == should_match, template
示例#4
0
def test_matches_ignores_spaces():
    """Extra whitespace around template entries does not affect matching."""
    guard = ShapeGuard()
    tensor = torch.ones([1, 2, 3])
    templates = ("1,2,3", "1 ,  2, 3   ", "1,  2,3 ")
    for template in templates:
        assert guard.matches(tensor, template), repr(template)