def test_matches_wildcards():
    """``*`` entries match any size, but each one still stands for one axis."""
    guard = ShapeGuard()
    tensor = torch.ones([1, 2, 4, 8])
    # Specs with the right number of entries match, wildcards or not.
    for spec in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(tensor, spec)
    # Specs with too few entries must not match this 4-d tensor,
    # even when every entry is a wildcard.
    for spec in ("*", "*, *, *"):
        assert not guard.matches(tensor, spec)
def test_matches_basic_numerical():
    """Numeric specs must agree with the tensor shape in both rank and sizes."""
    sg = ShapeGuard()
    a = torch.ones([1, 2, 3])
    assert sg.matches(a, "1, 2, 3")
    # Integral float literals are accepted as sizes.
    assert sg.matches(a, "1, 2.0, 3.0")
    # The test expects a non-integral size to raise ShapeError rather than
    # return False. Call it bare: an ``assert`` here would never see a
    # result and only obscures what is being tested.
    with pytest.raises(ShapeError):
        sg.matches(a, "1, 2.0, 3.1")
    # Wrong size, too many entries, too few entries.
    assert not sg.matches(a, "1, 2, 4")
    assert not sg.matches(a, "1, 2, 3, 4")
    assert not sg.matches(a, "1, 2")
def test_matches_named_dims():
    """Named entries in a spec resolve to the sizes bound via ``dims``."""
    guard = ShapeGuard(dims={"N": 24, "Z": 16})
    tensor = torch.ones([24, 16])
    assert guard.matches(tensor, "N, Z")
    # A literal size and a name may be mixed within one spec.
    assert guard.matches(tensor, "24, Z")
    # N is bound to 24, so it cannot describe the size-16 trailing axis.
    assert not guard.matches(tensor, "N, N")
def test_matches_ignores_spaces():
    """Whitespace around the comma-separated entries of a spec is ignored."""
    guard = ShapeGuard()
    tensor = torch.ones([1, 2, 3])
    # No spaces, spaces before commas, and trailing spaces all match.
    for spec in ["1,2,3", "1 , 2, 3 ", "1, 2,3 "]:
        assert guard.matches(tensor, spec)