Ejemplo n.º 1
0
def test_guard_dynamic_shape_tensorflow():
    tg = TensorGuard()
    with pytest.raises(ShapeError):
        tg.guard([None, 2, 3], "C, B, A")

    tg.guard([None, 2, 3], "?, B, A")
    tg.guard([1, 2, 3], "C?, B, A")
    tg.guard([None, 2, 3], "C?, B, A")
Ejemplo n.º 2
0
def test_guard_raises_inferred_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 3])
    b = tf.ones([3, 2, 5])
    tg.guard(a, "A, B, C")
    with pytest.raises(ShapeError):
        tg.guard(b, "C, B, A")
Ejemplo n.º 3
0
def test_guard_ellipsis_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 3, 4, 5])
    tg.guard(a, "...")
    tg.guard(a, "..., 5")
    tg.guard(a, "..., 4, 5")
    tg.guard(a, "1, ...")
    tg.guard(a, "1, 2, ...")
    tg.guard(a, "1, 2, ..., 4, 5")
    tg.guard(a, "1, 2, 3, ..., 4, 5")

    with pytest.raises(ShapeError):
        tg.guard(a, "1, 2, 3, 4, 5, 6,...")

    with pytest.raises(ShapeError):
        tg.guard(a, "..., 1, 2, 3, 4, 5, 6")
Ejemplo n.º 4
0
def test_guard_ignores_wildcard_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 3])
    tg.guard(a, "*, *, 3")
    assert tg.dims == {}
Ejemplo n.º 5
0
def test_guard_infers_dimensions_operator_priority_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 8])
    tg.guard(a, "A, B, A+C*2+1")
    assert tg.dims == {"A": 1, "B": 2, "C": 3}
Ejemplo n.º 6
0
def test_guard_raises_complex_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        tg.guard(a, "A, B, B")
Ejemplo n.º 7
0
def test_guard_ellipsis_infer_dims_numpy():
    tg = TensorGuard()
    a = np.ones([1, 2, 3, 4, 5])
    tg.guard(a, "A, B, ..., C")
    assert tg.dims == {"A": 1, "B": 2, "C": 5}
Ejemplo n.º 8
0
def test_guard_infers_dimensions_complex_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 3])
    tg.guard(a, "A, B*2, A+C")
    assert tg.dims == {"A": 1, "B": 1, "C": 2}
Ejemplo n.º 9
0
def test_guard_raises_complex_numpy():
    tg = TensorGuard()
    a = np.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        tg.guard(a, "A, B, B")
Ejemplo n.º 10
0
def test_guard_ignores_wildcard_numpy():
    tg = TensorGuard()
    a = np.ones([1, 2, 3])
    tg.guard(a, "*, *, 3")
    assert tg.dims == {}
Ejemplo n.º 11
0
def test_guard_raises_tensorflow():
    tg = TensorGuard()
    a = tf.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        tg.guard(a, "3, 2, 1")
Ejemplo n.º 12
0
def test_guard_infers_dimensions_numpy():
    tg = TensorGuard()
    a = np.ones([1, 2, 3])
    tg.guard(a, "A, B, C")
    assert tg.dims == {"A": 1, "B": 2, "C": 3}
Ejemplo n.º 13
0
def test_guard_raises_numpy():
    tg = TensorGuard()
    a = np.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        tg.guard(a, "3, 2, 1")
Ejemplo n.º 14
0
def test_guard_ignores_wildcard_pytorch():
    tg = TensorGuard()
    a = torch.ones([1, 2, 3])
    tg.guard(a, "*, *, 3")
    assert tg.dims == {}
Ejemplo n.º 15
0
def test_guard_raises_pytorch():
    tg = TensorGuard()
    a = torch.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        tg.guard(a, "3, 2, 1")