Example #1
def test_jax_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    )
    b = tensor3("b")
    b.tag.test_value = (
        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
    )
    out = aet_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # A dimension mismatch should raise a TypeError, for compatibility with
    # the non-JAX backends
    inputs = [get_test_value(a)[:-1], get_test_value(b)]
    opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_mode = Mode(JAXLinker(), opts)
    aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    with pytest.raises(TypeError):
        aesara_jax_fn(*inputs)

    # matrix . matrix
    a = matrix("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
    )
    b = matrix("b")
    b.tag.test_value = (
        np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
    )
    out = aet_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Example #2
def compare_jax_and_py(
    fgraph,
    inputs,
    assert_fn=None,
    must_be_device_array=True,
):
    """Function to compare python graph output and jax compiled output for testing equality

    In the tests below computational graphs are defined in Aesara. These graphs are then passed to
    this function which then compiles the graphs in both jax and python, runs the calculation
    in both and checks if the results are the same

    Parameters
    ----------
    fgraph: FunctionGraph
        Aesara function Graph object
    inputs: iter
        Inputs for function graph
    assert_fn: func, opt
        Assert function used to check for equality between python and jax. If not
        provided uses np.testing.assert_allclose
    must_be_device_array: Bool
        Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
        if this device array is found it indicates if the result was computed by jax

    Returns
    -------
    jax_res

    """
    if assert_fn is None:
        assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

    opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_mode = Mode(JAXLinker(), opts)
    py_mode = Mode("py", opts)

    aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    jax_res = aesara_jax_fn(*inputs)

    if must_be_device_array:
        if isinstance(jax_res, list):
            assert all(
                isinstance(res, jax.interpreters.xla.DeviceArray)
                for res in jax_res)
        else:
            assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)

    aesara_py_fn = function(fgraph.inputs, fgraph.outputs, mode=py_mode)
    py_res = aesara_py_fn(*inputs)

    if len(fgraph.outputs) > 1:
        for j, p in zip(jax_res, py_res):
            assert_fn(j, p)
    else:
        assert_fn(jax_res, py_res)

    return jax_res
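
A minimal usage sketch (the graph is hypothetical; `vector`, `config`, `np`, `FunctionGraph`, and `get_test_value` are assumed to be imported as in the test module shown in Example #7):

# Build a tiny graph and check that the JAX and Python backends agree.
x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
out = x * 2
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])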
Example #3
    def test_inplace(self):
        """Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
        a = aet.zeros((5,))
        d = aet.vector("d")
        c = aet.set_subtensor(a[np.r_[0, 1, 3]], d)
        b = broadcast_to(c, (5,))
        q = b[np.r_[0, 1, 3]]
        e = aet.set_subtensor(q, np.r_[0, 0, 0])

        opts = Query(include=["inplace"])
        py_mode = Mode("py", opts)
        e_fn = function([d], e, mode=py_mode)

        advincsub_node = e_fn.maker.fgraph.outputs[0].owner
        assert isinstance(advincsub_node.op, AdvancedIncSubtensor1)
        assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)

        assert advincsub_node.op.inplace is False
Example #4

def register_linker(name, linker):
    """Add a `Linker` which can be referred to by `name` in `Mode`."""
    if name in predefined_linkers:
        raise ValueError(f"Linker name already taken: {name}")
    predefined_linkers[name] = linker


# If a string is passed as the optimizer argument in the `Mode`
# constructor, it is used as the key to look up the actual optimizer
# in the `predefined_optimizers` dictionary.
exclude = []
if not config.cxx:
    exclude = ["cxx_only"]
OPT_NONE = Query(include=[], exclude=exclude)
# Even if the merge optimizer ends up being called multiple times, this
# shouldn't impact performance.
OPT_MERGE = Query(include=["merge"], exclude=exclude)
OPT_FAST_RUN = Query(include=["fast_run"], exclude=exclude)
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
# We need fast_compile_gpu here.  On the GPU, we don't have all of the
# operations that exist in fast_compile, but we do have some that only get
# introduced in fast_run; we want those optimizations to also run under
# fast_compile+gpu.  We can't tag them just as 'gpu', since that would
# exclude them whenever 'gpu' is excluded.
OPT_FAST_COMPILE = Query(include=["fast_compile", "fast_compile_gpu"],
                         exclude=exclude)
OPT_STABILIZE = Query(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
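
The `predefined_optimizers` dictionary referenced by the first comment is not shown in this excerpt; a plausible sketch based on the queries defined above (the exact keys are assumptions):

# Sketch of the optimizer-name lookup table; keys are assumptions
# inferred from the `OPT_*` queries defined above.
predefined_optimizers = {
    None: OPT_NONE,
    "None": OPT_NONE,
    "merge": OPT_MERGE,
    "fast_run": OPT_FAST_RUN,
    "fast_run_stable": OPT_FAST_RUN_STABLE,
    "fast_compile": OPT_FAST_COMPILE,
    "stabilize": OPT_STABILIZE,
}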
Example #5
    def _set_row_mappings(self, Gamma, dir_priors, model):
        """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

        These maps are needed when a transition matrix isn't simply composed
        of Dirichlet prior rows, but--instead--of slices of Dirichlet priors.

        Consider the following:

        .. code-block:: python

            with pm.Model():
                d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
                d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

                p_0_rv = tt.as_tensor([0, 0, 1])
                p_1_rv = tt.zeros(3)
                p_1_rv = tt.set_subtensor(p_1_rv[[0, 2]], d_0_rv)
                p_2_rv = tt.zeros(3)
                p_2_rv = tt.set_subtensor(p_2_rv[[1, 2]], d_1_rv)

                P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])

        The transition matrix `P_tt` has Dirichlet priors in only two of its
        three rows, and--even then--they're only present in parts of those rows.

        In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
        is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we
        need to know that prior 0 fills columns 0 and 2 in row 1, and prior 1
        fills columns 1 and 2 in row 2.

        These mappings allow one to embed Dirichlet priors in larger transition
        matrices with--for instance--fixed transition behavior.

        """  # noqa: E501

        # Remove unimportant `Op`s from the transition matrix graph
        Gamma = pre_greedy_local_optimizer(
            FunctionGraph([], []),
            [
                OpRemove(Elemwise(aes.Cast(aes.float32))),
                OpRemove(Elemwise(aes.Cast(aes.float64))),
                OpRemove(Elemwise(aes.identity)),
            ],
            Gamma,
        )

        # Canonicalize the transition matrix graph
        fg = FunctionGraph(
            list(graph_inputs([Gamma] + self.dir_priors_untrans)),
            [Gamma] + self.dir_priors_untrans,
            clone=True,
        )
        canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
        canonicalize_opt.optimize(fg)
        Gamma = fg.outputs[0]
        dir_priors_untrans = fg.outputs[1:]
        fg.disown()

        Gamma_DimShuffle = Gamma.owner

        if not isinstance(Gamma_DimShuffle.op, DimShuffle):
            raise TypeError("The transition matrix should be non-time-varying")

        Gamma_Join = Gamma_DimShuffle.inputs[0].owner

        if not isinstance(Gamma_Join.op, at.basic.Join):
            raise TypeError(
                "The transition matrix should be comprised of stacked row vectors"
            )

        Gamma_rows = Gamma_Join.inputs[1:]

        self.n_rows = len(Gamma_rows)

        # Loop through the rows in the transition matrix's graph and determine
        # how our transformed Dirichlet RVs map to this transition matrix.
        self.row_remaps = {}
        self.row_slices = {}
        for i, dim_row in enumerate(Gamma_rows):
            if not dim_row.owner:
                continue

            # Bypass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
            # `Op`s in which we're actually interested
            gamma_row = dim_row.owner.inputs[0]

            if gamma_row in dir_priors_untrans:
                # This is a row that's simply a `Dirichlet`
                j = dir_priors_untrans.index(gamma_row)
                self.row_remaps[j] = i
                self.row_slices[j] = slice(None)
                continue

            # Skip rows that don't have a `Dirichlet` on the right-hand side
            # of a `*Subtensor*` assignment (this also guards constant rows,
            # which have no `owner`)
            if not gamma_row.owner or gamma_row.owner.inputs[1] not in dir_priors_untrans:
                continue

            # Parts of a row set by a `*Subtensor*` `Op` using a full
            # `Dirichlet` e.g. `P_row[idx] = dir_rv`
            j = dir_priors_untrans.index(gamma_row.owner.inputs[1])
            untrans_dirich = dir_priors_untrans[j]

            if (gamma_row.owner
                    and isinstance(gamma_row.owner.op, AdvancedIncSubtensor1)
                    and gamma_row.owner.inputs[1] == untrans_dirich):
                self.row_remaps[j] = i

                rhand_val = gamma_row.owner.inputs[2]
                if not isinstance(rhand_val, TensorConstant):
                    # TODO: We could allow more types of `idx` (e.g. slices)
                    # Currently, `idx` can't be something like `2:5`
                    raise TypeError("Only array indexing allowed for mixed"
                                    " Dirichlet/non-Dirichlet rows")
                self.row_slices[j] = rhand_val.data
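
For the docstring's example transition matrix, the mappings this method builds would look like the following (an illustration of the intended result, not output from running the code):

# Prior 0 (`d_0_rv`) fills columns 0 and 2 of row 1; prior 1 (`d_1_rv`)
# fills columns 1 and 2 of row 2.
# self.row_remaps == {0: 1, 1: 2}
# self.row_slices == {0: np.r_[0, 2], 1: np.r_[1, 2]}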
Example #6
    multivariate_normal,
    normal,
    poisson,
    uniform,
)
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
    lift_rv_shapes,
    local_dimshuffle_rv_lift,
    local_subtensor_rv_lift,
)
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import iscalar, vector


inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[]))
no_mode = Mode("py", Query(include=[], exclude=[]))


def test_inplace_optimization():

    out = normal(0, 1)

    assert out.owner.op.inplace is False

    f = function(
        [],
        out,
        mode=inplace_mode,
    )
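
The excerpt ends at the compile call; presumably the test then verifies that the `random_make_inplace` rewrite replaced the op with an in-place variant, roughly as follows (a sketch, not the verbatim continuation):

    # Inspect the compiled graph, as done via `maker.fgraph` in Example #3.
    (new_out,) = f.maker.fgraph.outputs
    assert new_out.owner.op.inplace is True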
Example #7
from aesara.tensor.type import (
    dscalar,
    dvector,
    iscalar,
    ivector,
    lscalar,
    matrix,
    scalar,
    tensor,
    tensor3,
    vector,
)

jax = pytest.importorskip("jax")

opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)


@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
    with config.change_flags(cxx="", compute_test_value="ignore"):
        yield


def compare_jax_and_py(
    fgraph,
    inputs,
    assert_fn=None,
    must_be_device_array=True,
Example #8
def set_aesara_flags():
    opts = Query(include=[None], exclude=[])
    py_mode = Mode("py", opts)
    with config.change_flags(mode=py_mode, compute_test_value="warn"):
        yield
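
As in Example #7, this generator is presumably registered as a pytest fixture; the decorator elided from this excerpt would look something like this (the scope and autouse settings are assumptions, mirroring Example #7):

@pytest.fixture(scope="module", autouse=True)  # assumed, as in Example #7
def set_aesara_flags():
    ...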