Esempio n. 1
0
def test_jax_BatchedDot():
    """Exercise `BatchedDot` through the JAX linker.

    Two graphs are built (3-tensor inputs, then matrix inputs) and each is
    handed to `compare_jax_and_py`.  In between, the 3-tensor graph is
    compiled with the JAX linker and called with a first input whose leading
    dimension is one shorter than the second's, which must raise `TypeError`.
    """

    def ramp(start, stop, shape):
        # Evenly spaced values cast to `config.floatX` and laid out as `shape`.
        flat = np.linspace(start, stop, int(np.prod(shape)))
        return flat.astype(config.floatX).reshape(shape)

    # tensor3 . tensor3
    lhs = tensor3("a")
    lhs.tag.test_value = ramp(-1, 1, (10, 5, 3))
    rhs = tensor3("b")
    rhs.tag.test_value = ramp(1, -1, (10, 3, 2))
    batched_graph = FunctionGraph([lhs, rhs], [aet_blas.BatchedDot()(lhs, rhs)])
    compare_jax_and_py(
        batched_graph, [get_test_value(var) for var in batched_graph.inputs]
    )

    # A dimension mismatch should raise a TypeError for compatibility
    mismatched = [get_test_value(lhs)[:-1], get_test_value(rhs)]
    query = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
    mode_for_jax = Mode(JAXLinker(), query)
    compiled = function(
        batched_graph.inputs, batched_graph.outputs, mode=mode_for_jax
    )
    with pytest.raises(TypeError):
        compiled(*mismatched)

    # matrix . matrix
    lhs = matrix("a")
    lhs.tag.test_value = ramp(-1, 1, (5, 3))
    rhs = matrix("b")
    rhs.tag.test_value = ramp(1, -1, (5, 3))
    matrix_graph = FunctionGraph([lhs, rhs], [aet_blas.BatchedDot()(lhs, rhs)])
    compare_jax_and_py(
        matrix_graph, [get_test_value(var) for var in matrix_graph.inputs]
    )
Esempio n. 2
0
def compare_jax_and_py(
    fgraph,
    inputs,
    assert_fn=None,
    must_be_device_array=True,
):
    """Compile a graph with the JAX and Python linkers and compare the results.

    The graph's inputs and outputs are compiled twice -- once with a
    `JAXLinker` mode and once with the ``"py"`` mode -- and both compiled
    functions are called with the same `inputs`.  The two results are then
    checked against each other with `assert_fn`.

    Parameters
    ----------
    fgraph: FunctionGraph
        Aesara function graph whose ``inputs`` and ``outputs`` are compiled.
    inputs: iter
        Concrete values passed positionally to both compiled functions.
    assert_fn: func, opt
        Callable taking ``(jax_value, py_value)`` that raises on mismatch.
        Defaults to ``np.testing.assert_allclose`` with ``rtol=1e-4``.
    must_be_device_array: Bool
        When true, assert that the JAX result (or every element of it, when
        it is a list) is a ``jax.interpreters.xla.DeviceArray``, as evidence
        that the value was actually produced by JAX.

    Returns
    -------
    jax_res
        The value returned by the JAX-compiled function.

    """
    if assert_fn is None:
        assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

    # Both modes share one optimization query so only the linker differs.
    query = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    mode_for_jax = Mode(JAXLinker(), query)
    mode_for_py = Mode("py", query)

    jax_callable = function(fgraph.inputs, fgraph.outputs, mode=mode_for_jax)
    jax_res = jax_callable(*inputs)

    if must_be_device_array:
        # Treat a lone result as a one-element collection so a single
        # check covers both the list and the non-list case.
        candidates = jax_res if isinstance(jax_res, list) else (jax_res,)
        assert all(
            isinstance(item, jax.interpreters.xla.DeviceArray)
            for item in candidates
        )

    py_callable = function(fgraph.inputs, fgraph.outputs, mode=mode_for_py)
    py_res = py_callable(*inputs)

    if len(fgraph.outputs) <= 1:
        assert_fn(jax_res, py_res)
    else:
        # Multiple outputs are compared pairwise, in output order.
        for from_jax, from_py in zip(jax_res, py_res):
            assert_fn(from_jax, from_py)

    return jax_res
Esempio n. 3
0
# Module-level logger, registered under the name "aesara.compile.mode".
_logger = logging.getLogger("aesara.compile.mode")

# Registry mapping linker names to linker instances.  When a string is
# passed as the linker argument to the `Mode` constructor, it is used as
# the key to look up the actual linker object in this dictionary.
predefined_linkers = {
    "py": PerformLinker(),  # Uses the `allow_gc` Aesara flag
    "c": CLinker(),  # Does not support gc, so `allow_gc` is not checked
    "c|py": OpWiseCLinker(),  # Uses the `allow_gc` Aesara flag
    "c|py_nogc": OpWiseCLinker(allow_gc=False),
    "vm": VMLinker(use_cloop=False),  # Uses the `allow_gc` Aesara flag
    "cvm": VMLinker(use_cloop=True),  # Uses the `allow_gc` Aesara flag
    "vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
    "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
    "jax": JAXLinker(),
}


def register_linker(name, linker):
    """Add a `Linker` which can be referred to by `name` in `Mode`.

    Parameters
    ----------
    name
        Key under which `linker` is stored in `predefined_linkers`.
    linker
        The linker object to register.

    Raises
    ------
    ValueError
        If `name` is already a key of `predefined_linkers`; the existing
        entry is left untouched.
    """
    if name not in predefined_linkers:
        predefined_linkers[name] = linker
        return
    raise ValueError(f"Linker name already taken: {name}")


# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
Esempio n. 4
0
    dscalar,
    dvector,
    iscalar,
    ivector,
    lscalar,
    matrix,
    scalar,
    tensor,
    tensor3,
    vector,
)

# Import JAX through pytest so that the tests in this module are skipped,
# rather than erroring at import time, when JAX is not installed.
jax = pytest.importorskip("jax")

# Compilation modes shared by the tests: one optimization query (excluding
# the "cxx_only" and "BlasOpt" tags) paired with the JAX linker and with the
# "py" linker respectively, so that only the linker differs between them.
opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)


@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
    """Run every test in this module under temporary Aesara flag overrides.

    The fixture is module-scoped and auto-used: it enters
    `config.change_flags` with ``cxx=""`` and ``compute_test_value="ignore"``,
    yields while the tests run, and leaves the context manager afterwards.
    """
    flag_overrides = {"cxx": "", "compute_test_value": "ignore"}
    with config.change_flags(**flag_overrides):
        yield


def compare_jax_and_py(
    fgraph: FunctionGraph,
    test_inputs: iter,
    assert_fn: Optional[callable] = None,
    must_be_device_array: bool = True,
):