Exemple #1
0
def test_matches_wildcards_tensorflow():
    """Check ``*`` wildcard templates against a rank-4 TensorFlow tensor."""
    guard = TensorGuard()
    tensor = tf.ones([1, 2, 4, 8])

    # Four-entry templates mixing literals and wildcards are expected to match.
    for accepted in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(tensor, accepted)

    # Templates with fewer entries than the tensor has axes must be rejected.
    for rejected in ("*", "*, *, *"):
        assert not guard.matches(tensor, rejected)
Exemple #2
0
def test_matches_wildcards_pytorch():
    """Check ``*`` wildcard templates against a rank-4 PyTorch tensor."""
    guard = TensorGuard()
    tensor = torch.ones([1, 2, 4, 8])

    # Four-entry templates mixing literals and wildcards are expected to match.
    for accepted in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(tensor, accepted)

    # Templates with fewer entries than the tensor has axes must be rejected.
    for rejected in ("*", "*, *, *"):
        assert not guard.matches(tensor, rejected)
Exemple #3
0
def test_matches_wildcards_numpy():
    """Check ``*`` wildcard templates against a rank-4 NumPy array."""
    guard = TensorGuard()
    array = np.ones([1, 2, 4, 8])

    # Four-entry templates mixing literals and wildcards are expected to match.
    for accepted in ("1, 2, 4, *", "*, *, *, 8"):
        assert guard.matches(array, accepted)

    # Templates with fewer entries than the array has axes must be rejected.
    for rejected in ("*", "*, *, *"):
        assert not guard.matches(array, rejected)
Exemple #4
0
def test_matches_basic_numerical_numpy():
    """Check purely numeric templates against a NumPy array of shape (1, 2, 3)."""
    guard = TensorGuard()
    array = np.ones([1, 2, 3])

    # The exact shape is the only template expected to match.
    assert guard.matches(array, "1, 2, 3")

    # Wrong last size, one entry too many, one entry too few.
    for mismatched in ("1, 2, 4", "1, 2, 3, 4", "1, 2"):
        assert not guard.matches(array, mismatched)
Exemple #5
0
def reset():
    """
    Replace the module-level guard with a newly constructed TensorGuard.
    """
    global __tg
    fresh_guard = TensorGuard()
    __tg = fresh_guard
Exemple #6
0
from typing import Optional, List, Any, Union, Dict

from tensorguard import tools
from tensorguard.exception import ShapeError
from tensorguard.guard import TensorGuard

# Package metadata.
__version__ = "1.0.0"

__author__ = "Michele De Vita"
__author_email__ = "*****@*****.**"

__url__ = "https://github.com/Michedev/shapeguard"

from tensorguard.tools import ShapedTensor

# Module-level guard instance: reset() rebinds it and matches() reads its dims.
__tg = TensorGuard()


def reset():
    """
    Replace the module-level guard with a newly constructed TensorGuard.
    """
    global __tg
    fresh_guard = TensorGuard()
    __tg = fresh_guard


def matches(tensor: Union[ShapedTensor, List[int]], template: str) -> bool:
    """
    Tell whether the shape of ``tensor`` matches ``template``.

    The check is delegated to ``tools.matches``, which is given the ``dims``
    of the module-level guard alongside the tensor and the template.

    :param tensor: a shaped tensor object or a list of ints
    :param template: the shape template string to compare against
    :return: the result of ``tools.matches``
    """
    known_dims = __tg.dims
    return tools.matches(tensor, template, known_dims)
Exemple #7
0
def test_matches_named_dims_pytorch():
    """Check templates using named dims against a (24, 16) PyTorch tensor."""
    named_sizes = {"N": 24, "Z": 16}
    guard = TensorGuard(dims=named_sizes)
    tensor = torch.ones([24, 16])

    # Names alone, or a literal mixed with a name, are expected to match.
    for accepted in ("N, Z", "24, Z"):
        assert guard.matches(tensor, accepted)

    # N is 24 but the second axis has size 16, so "N, N" must be rejected.
    assert not guard.matches(tensor, "N, N")
Exemple #8
0
def test_matches_ignores_spaces_pytorch():
    """Check that spacing inside the template does not affect matching (PyTorch)."""
    guard = TensorGuard()
    tensor = torch.ones([1, 2, 3])

    # Same shape written with no, irregular and trailing whitespace.
    spacing_variants = ("1,2,3", "1 ,  2, 3   ", "1,  2,3 ")
    for template in spacing_variants:
        assert guard.matches(tensor, template)
Exemple #9
0
def test_matches_ignores_spaces_tensorflow():
    """Check that spacing inside the template does not affect matching (TensorFlow)."""
    guard = TensorGuard()
    tensor = tf.ones([1, 2, 3])

    # Same shape written with no, irregular and trailing whitespace.
    spacing_variants = ("1,2,3", "1 ,  2, 3   ", "1,  2,3 ")
    for template in spacing_variants:
        assert guard.matches(tensor, template)
Exemple #10
0
def test_matches_ignores_spaces_numpy():
    """Check that spacing inside the template does not affect matching (NumPy)."""
    guard = TensorGuard()
    array = np.ones([1, 2, 3])

    # Same shape written with no, irregular and trailing whitespace.
    spacing_variants = ("1,2,3", "1 ,  2, 3   ", "1,  2,3 ")
    for template in spacing_variants:
        assert guard.matches(array, template)