def test_matches_wildcards_tensorflow():
    """`*` matches any single axis size, but the template rank must still agree."""
    guard = TensorGuard()
    tensor = tf.ones([1, 2, 4, 8])
    # Same rank as the tensor: the wildcard stands in for one concrete size.
    for template in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(tensor, template)
    # Fewer axes than the tensor has: rejected even though every entry is `*`.
    for template in ("*", "*, *, *"):
        assert not guard.matches(tensor, template)
def test_matches_wildcards_pytorch():
    """`*` matches any single axis size, but the template rank must still agree."""
    guard = TensorGuard()
    tensor = torch.ones([1, 2, 4, 8])
    # Same rank as the tensor: the wildcard stands in for one concrete size.
    for template in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(tensor, template)
    # Fewer axes than the tensor has: rejected even though every entry is `*`.
    for template in ("*", "*, *, *"):
        assert not guard.matches(tensor, template)
def test_matches_wildcards_numpy():
    """`*` matches any single axis size, but the template rank must still agree."""
    guard = TensorGuard()
    arr = np.ones([1, 2, 4, 8])
    # Same rank as the array: the wildcard stands in for one concrete size.
    for template in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(arr, template)
    # Fewer axes than the array has: rejected even though every entry is `*`.
    for template in ("*", "*, *, *"):
        assert not guard.matches(arr, template)
def test_matches_basic_numerical_numpy():
    """A purely numeric template matches only the exact shape it spells out."""
    guard = TensorGuard()
    arr = np.ones([1, 2, 3])
    assert guard.matches(arr, "1, 2, 3")
    # A wrong size, an extra axis, and a missing axis must all be rejected.
    for template in ("1, 2, 4", "1, 2, 3, 4", "1, 2"):
        assert not guard.matches(arr, template)
def reset():
    """Rebind the module-level guard `__tg` to a newly constructed TensorGuard."""
    global __tg
    fresh_guard = TensorGuard()
    __tg = fresh_guard
from typing import Optional, List, Any, Union, Dict

from tensorguard import tools
from tensorguard.exception import ShapeError
from tensorguard.guard import TensorGuard
from tensorguard.tools import ShapedTensor

__version__ = "1.0.0"
__author__ = "Michele De Vita"
__author_email__ = "*****@*****.**"
# NOTE(review): the URL names a "shapeguard" repository while the package is
# "tensorguard" -- confirm the link is still the canonical one.
__url__ = "https://github.com/Michedev/shapeguard"

# Module-wide guard used by the functional helpers below; replaced by reset().
__tg = TensorGuard()


def reset():
    """Rebind the module-wide guard `__tg` to a newly constructed TensorGuard."""
    global __tg
    fresh_guard = TensorGuard()
    __tg = fresh_guard


def matches(tensor: Union[ShapedTensor, List[int]], template: str) -> bool:
    """Check a tensor's shape against a template using the module-wide guard.

    Args:
        tensor: a shaped tensor, or a plain list of dimension sizes.
        template: comma-separated shape template, e.g. "1, 2, *" or "N, Z".

    Returns:
        True if the shape matches the template, otherwise False.
    """
    # The guard's `dims` mapping is what lets templates reference named axes.
    known_dims = __tg.dims
    return tools.matches(tensor, template, known_dims)
def test_matches_named_dims_pytorch():
    """Named axes resolve through the guard's `dims` mapping."""
    guard = TensorGuard(dims={"N": 24, "Z": 16})
    tensor = torch.ones([24, 16])
    # Names alone, or names mixed with literal sizes, both resolve correctly.
    for template in ("N, Z", "24, Z"):
        assert guard.matches(tensor, template)
    # "N, N" would require the second axis to be 24 as well, but it is 16.
    assert not guard.matches(tensor, "N, N")
def test_matches_ignores_spaces_pytorch():
    """Whitespace placement inside a template does not affect matching."""
    guard = TensorGuard()
    tensor = torch.ones([1, 2, 3])
    # No spaces, a space before a comma, and trailing spaces are all accepted.
    for template in ("1,2,3", "1 , 2, 3 ", "1, 2,3 "):
        assert guard.matches(tensor, template)
def test_matches_ignores_spaces_tensorflow():
    """Whitespace placement inside a template does not affect matching."""
    guard = TensorGuard()
    tensor = tf.ones([1, 2, 3])
    # No spaces, a space before a comma, and trailing spaces are all accepted.
    for template in ("1,2,3", "1 , 2, 3 ", "1, 2,3 "):
        assert guard.matches(tensor, template)
def test_matches_ignores_spaces_numpy():
    """Whitespace placement inside a template does not affect matching."""
    guard = TensorGuard()
    arr = np.ones([1, 2, 3])
    # No spaces, a space before a comma, and trailing spaces are all accepted.
    for template in ("1,2,3", "1 , 2, 3 ", "1, 2,3 "):
        assert guard.matches(arr, template)