def test_guard_dynamic_shape_tensorflow():
    """Exercise guards against shapes whose leading dimension is unknown (``None``).

    A plain named dimension (``C``) is expected to reject the unknown size,
    while the ``?`` wildcard and the ``C?`` form are expected to accept it.
    NOTE(review): despite the test name, the shapes passed here are plain
    Python lists rather than TensorFlow tensors.
    """
    checker = TensorGuard()

    with pytest.raises(ShapeError):
        checker.guard([None, 2, 3], "C, B, A")

    # None of the following calls is expected to raise.
    accepted = [
        ([None, 2, 3], "?, B, A"),
        ([1, 2, 3], "C?, B, A"),
        ([None, 2, 3], "C?, B, A"),
    ]
    for shape, spec in accepted:
        checker.guard(shape, spec)
def test_guard_raises_inferred_tensorflow():
    """A later guard must agree with dimensions bound by an earlier one.

    The first call sees shape [1, 2, 3] with spec "A, B, C". The second call
    checks shape [3, 2, 5] against "C, B, A", which puts 5 where A already
    holds 1, so ShapeError is expected.
    """
    checker = TensorGuard()

    first = tf.ones([1, 2, 3])
    checker.guard(first, "A, B, C")

    second = tf.ones([3, 2, 5])
    with pytest.raises(ShapeError):
        checker.guard(second, "C, B, A")
def test_guard_ellipsis_tensorflow():
    """Check ``...`` handling in specs against a tensor of shape [1, 2, 3, 4, 5]."""
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3, 4, 5])

    # Every spec here is expected to match. The last one lists all five sizes
    # around "...", so "..." is accepted even when it stands for no axes.
    for spec in (
        "...",
        "..., 5",
        "..., 4, 5",
        "1, ...",
        "1, 2, ...",
        "1, 2, ..., 4, 5",
        "1, 2, 3, ..., 4, 5",
    ):
        checker.guard(tensor, spec)

    # Specs listing six explicit sizes are expected to be rejected for this
    # five-axis tensor, whichever side the "..." is on.
    with pytest.raises(ShapeError):
        checker.guard(tensor, "1, 2, 3, 4, 5, 6,...")
    with pytest.raises(ShapeError):
        checker.guard(tensor, "..., 1, 2, 3, 4, 5, 6")
def test_guard_ignores_wildcard_tensorflow():
    """``*`` entries are accepted and leave ``dims`` empty (TensorFlow input)."""
    checker = TensorGuard()
    checker.guard(tf.ones([1, 2, 3]), "*, *, 3")
    assert checker.dims == {}
def test_guard_infers_dimensions_operator_priority_tensorflow():
    """Infer a dimension from an expression that mixes ``+`` and ``*``.

    The shape is [1, 2, 8], so A=1 and the last axis (8) must equal A+C*2+1.
    The expected C=3 reflects ``*`` binding tighter than ``+``
    (1 + 3*2 + 1 == 8).
    """
    checker = TensorGuard()
    checker.guard(tf.ones([1, 2, 8]), "A, B, A+C*2+1")
    expected = dict(A=1, B=2, C=3)
    assert checker.dims == expected
def test_guard_raises_complex_tensorflow():
    """Reusing one name for two different sizes raises ShapeError (TensorFlow).

    On shape [1, 2, 3], the spec "A, B, B" would need B to be both 2 and 3.
    """
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        checker.guard(tensor, "A, B, B")
def test_guard_ellipsis_infer_dims_numpy():
    """Named dimensions around ``...`` bind to the matching ends of the shape."""
    checker = TensorGuard()
    arr = np.ones([1, 2, 3, 4, 5])
    checker.guard(arr, "A, B, ..., C")
    # "..." takes the middle axes (3, 4), so C binds to the last axis.
    assert checker.dims == {"C": 5, "A": 1, "B": 2}
def test_guard_infers_dimensions_complex_tensorflow():
    """Infer dimensions from simple arithmetic expressions in the spec.

    Shape [1, 2, 3] with "A, B*2, A+C" gives A=1, then B*2 == 2 gives B=1,
    then A+C == 3 gives C=2.
    """
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3])
    checker.guard(tensor, "A, B*2, A+C")
    assert checker.dims == dict(A=1, B=1, C=2)
def test_guard_raises_complex_numpy():
    """Reusing one name for two different sizes raises ShapeError (NumPy).

    On shape [1, 2, 3], the spec "A, B, B" would need B to be both 2 and 3.
    """
    checker = TensorGuard()
    arr = np.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        checker.guard(arr, "A, B, B")
def test_guard_ignores_wildcard_numpy():
    """``*`` entries are accepted and leave ``dims`` empty (NumPy input)."""
    checker = TensorGuard()
    checker.guard(np.ones([1, 2, 3]), "*, *, 3")
    assert checker.dims == {}
def test_guard_raises_tensorflow():
    """Literal sizes that disagree with the shape raise ShapeError (TensorFlow).

    The spec lists the sizes of [1, 2, 3] in reverse order.
    """
    checker = TensorGuard()
    tensor = tf.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        checker.guard(tensor, "3, 2, 1")
def test_guard_infers_dimensions_numpy():
    """Plain names bind to the corresponding axis sizes of a NumPy array."""
    checker = TensorGuard()
    checker.guard(np.ones([1, 2, 3]), "A, B, C")
    assert checker.dims == dict(A=1, B=2, C=3)
def test_guard_raises_numpy():
    """Literal sizes that disagree with the shape raise ShapeError (NumPy).

    The spec lists the sizes of [1, 2, 3] in reverse order.
    """
    checker = TensorGuard()
    arr = np.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        checker.guard(arr, "3, 2, 1")
def test_guard_ignores_wildcard_pytorch():
    """``*`` entries are accepted and leave ``dims`` empty (PyTorch input)."""
    checker = TensorGuard()
    checker.guard(torch.ones([1, 2, 3]), "*, *, 3")
    assert checker.dims == {}
def test_guard_raises_pytorch():
    """Literal sizes that disagree with the shape raise ShapeError (PyTorch).

    The spec lists the sizes of [1, 2, 3] in reverse order.
    """
    checker = TensorGuard()
    tensor = torch.ones([1, 2, 3])
    with pytest.raises(ShapeError):
        checker.guard(tensor, "3, 2, 1")