コード例 #1
0
def test_guard_dynamic_shape_tensorflow():
    """Unknown (``None``) sizes are only accepted by ``?``-style specs.

    A ``None`` entry checked against the plain name "C" must raise
    ``ShapeError``. The bare "?" and the suffixed "C?" forms accept it, and
    "C?" also accepts a concrete size.
    """
    checker = TensorGuard()

    with pytest.raises(ShapeError):
        checker.guard([None, 2, 3], "C, B, A")

    accepted_cases = (
        ([None, 2, 3], "?, B, A"),
        ([1, 2, 3], "C?, B, A"),
        ([None, 2, 3], "C?, B, A"),
    )
    for shape, spec in accepted_cases:
        checker.guard(shape, spec)
コード例 #2
0
def test_guard_raises_inferred_tensorflow():
    """Names bound by one ``guard`` call are enforced on later calls."""
    checker = TensorGuard()
    first = tf.ones([1, 2, 3])
    second = tf.ones([3, 2, 5])

    checker.guard(first, "A, B, C")

    # "A" was bound to 1 above, so the trailing 5 of `second` must be rejected.
    with pytest.raises(ShapeError):
        checker.guard(second, "C, B, A")
コード例 #3
0
def test_guard_ellipsis_tensorflow():
    """``...`` absorbs any run of dimensions, including an empty one.

    An ellipsis cannot make room for more explicit sizes than the tensor has
    dimensions.
    """
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3, 4, 5])

    matching_specs = (
        "...",
        "..., 5",
        "..., 4, 5",
        "1, ...",
        "1, 2, ...",
        "1, 2, ..., 4, 5",
        "1, 2, 3, ..., 4, 5",
    )
    for spec in matching_specs:
        checker.guard(tensor, spec)

    # Six explicit sizes cannot fit a rank-5 tensor on either side of "...".
    overlong_specs = (
        "1, 2, 3, 4, 5, 6,...",
        "..., 1, 2, 3, 4, 5, 6",
    )
    for spec in overlong_specs:
        with pytest.raises(ShapeError):
            checker.guard(tensor, spec)
コード例 #4
0
def test_guard_ignores_wildcard_tensorflow():
    """``*`` matches any size and records no named dimension."""
    checker = TensorGuard()
    checker.guard(tf.ones([1, 2, 3]), "*, *, 3")

    no_bindings = {}
    assert checker.dims == no_bindings
コード例 #5
0
def test_guard_infers_dimensions_operator_priority_tensorflow():
    """Expressions follow normal precedence when solving for a name.

    With A = 1, the size 8 must equal A + (C * 2) + 1, which gives C = 3.
    """
    checker = TensorGuard()
    checker.guard(tf.ones([1, 2, 8]), "A, B, A+C*2+1")

    expected = {"A": 1, "B": 2, "C": 3}
    assert checker.dims == expected
コード例 #6
0
def test_guard_raises_complex_tensorflow():
    """A name repeated within one spec must match the same size each time."""
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3])

    # "B" would have to be both 2 and 3.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "A, B, B")
コード例 #7
0
def test_guard_ellipsis_infer_dims_numpy():
    """Names around ``...`` bind to the dimensions at their positions.

    Dimensions absorbed by the ellipsis are not recorded in ``dims``.
    """
    checker = TensorGuard()
    array = np.ones([1, 2, 3, 4, 5])
    checker.guard(array, "A, B, ..., C")

    expected = {"A": 1, "B": 2, "C": 5}
    assert checker.dims == expected
コード例 #8
0
def test_guard_infers_dimensions_complex_tensorflow():
    """Arithmetic specs are solved for the single unknown name.

    B * 2 == 2 gives B = 1; A + C == 3 with A = 1 gives C = 2.
    """
    checker = TensorGuard()
    checker.guard(tf.ones([1, 2, 3]), "A, B*2, A+C")

    expected = {"A": 1, "B": 1, "C": 2}
    assert checker.dims == expected
コード例 #9
0
def test_guard_raises_complex_numpy():
    """NumPy variant: a repeated name cannot bind two different sizes."""
    checker = TensorGuard()
    array = np.ones([1, 2, 3])

    # The spec forces "B" to equal both 2 and 3.
    with pytest.raises(ShapeError):
        checker.guard(array, "A, B, B")
コード例 #10
0
def test_guard_ignores_wildcard_numpy():
    """NumPy variant: ``*`` positions leave ``dims`` empty."""
    checker = TensorGuard()
    checker.guard(np.ones([1, 2, 3]), "*, *, 3")

    no_bindings = {}
    assert checker.dims == no_bindings
コード例 #11
0
def test_guard_raises_tensorflow():
    """Literal sizes in a spec must match the tensor's shape exactly."""
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3])

    # The shape is [1, 2, 3]; the spec lists it reversed.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "3, 2, 1")
コード例 #12
0
def test_guard_infers_dimensions_numpy():
    """Each plain name in the spec is bound to the size at its position."""
    checker = TensorGuard()
    checker.guard(np.ones([1, 2, 3]), "A, B, C")

    expected = {"A": 1, "B": 2, "C": 3}
    assert checker.dims == expected
コード例 #13
0
def test_guard_raises_numpy():
    """NumPy variant: literal sizes that disagree with the shape raise."""
    checker = TensorGuard()
    array = np.ones([1, 2, 3])

    # The shape is [1, 2, 3]; the spec lists it reversed.
    with pytest.raises(ShapeError):
        checker.guard(array, "3, 2, 1")
コード例 #14
0
def test_guard_ignores_wildcard_pytorch():
    """PyTorch variant: ``*`` positions leave ``dims`` empty."""
    checker = TensorGuard()
    checker.guard(torch.ones([1, 2, 3]), "*, *, 3")

    no_bindings = {}
    assert checker.dims == no_bindings
コード例 #15
0
def test_guard_raises_pytorch():
    """PyTorch variant: literal sizes that disagree with the shape raise."""
    checker = TensorGuard()
    tensor = torch.ones([1, 2, 3])

    # The shape is [1, 2, 3]; the spec lists it reversed.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "3, 2, 1")