コード例 #1
0
def test_jax_shape_ops():
    """Check the `Shape` and `Shape_i` ops via `compare_jax_and_py`.

    Each op is applied to a constant (20, 3) array of zeros, wrapped in a
    `FunctionGraph` that has no inputs, and handed to `compare_jax_and_py`
    with ``must_be_device_array=False``.
    """
    zeros = np.zeros((20, 3))

    # The full-shape op first, then the op selecting index 1 of the shape.
    for shape_op in (Shape(), Shape_i(1)):
        out = shape_op(aet.as_tensor_variable(zeros))
        graph = FunctionGraph([], [out])

        # NOTE(review): must_be_device_array=False presumably relaxes the
        # output-type check for shape results -- confirm against the helper.
        compare_jax_and_py(graph, [], must_be_device_array=False)
コード例 #2
0
def test_jax_shape_ops():
    """Check the `Shape`, `Shape_i` and `SpecifyShape` ops via `compare_jax_and_py`.

    Every case applies an op to a constant (20, 3) array of zeros, wraps the
    result in a `FunctionGraph` with no inputs, and passes that graph to
    `compare_jax_and_py`. The last case specifies a shape that does not match
    the data and expects an `AssertionError`.
    """
    zeros = np.zeros((20, 3))

    # The full-shape op first, then the op selecting index 1 of the shape.
    for shape_op in (Shape(), Shape_i(1)):
        out = shape_op(aet.as_tensor_variable(zeros))
        graph = FunctionGraph([], [out])

        # NOTE(review): must_be_device_array=False presumably relaxes the
        # output-type check for shape results -- confirm against the helper.
        compare_jax_and_py(graph, [], must_be_device_array=False)

    # Specified shape equals the real shape of the data.
    matching = SpecifyShape()(aet.as_tensor_variable(zeros), (20, 3))
    compare_jax_and_py(FunctionGraph([], [matching]), [])

    # Specified shape (2, 3) differs from the real (20, 3). The graph is built
    # with compute_test_value turned off, and only the comparison itself is
    # expected to raise.
    with config.change_flags(compute_test_value="off"):
        mismatched = SpecifyShape()(aet.as_tensor_variable(zeros), (2, 3))
        bad_graph = FunctionGraph([], [mismatched])

        with pytest.raises(AssertionError):
            compare_jax_and_py(bad_graph, [])