Example #1
def get_jaxified_graph(
    inputs: Optional[List[TensorVariable]] = None,
    outputs: Optional[List[TensorVariable]] = None,
) -> Callable:
    """Compile an Aesara graph into an optimized JAX function"""

    graph = _replace_shared_variables(outputs)

    fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
    # We need to add a Supervisor to the fgraph to be able to run the
    # JAX sequential optimizer without warnings. We made sure there
    # are no mutable input variables, so we only need to check for
    # "destroyers". This should be automatically handled by Aesara
    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
    fgraph.attach_feature(
        Supervisor(
            input
            for input in fgraph.inputs
            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
        )
    )
    mode.JAX.optimizer.optimize(fgraph)

    # We now jaxify the optimized fgraph
    return jax_funcify(fgraph)
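
For reference, the FunctionGraph -> jax_funcify pattern used above can be exercised on a plain Aesara graph without the PyMC-specific pieces (_replace_shared_variables, the Supervisor feature, mode.JAX). A minimal sketch, assuming only aesara, jax and numpy are installed:

import numpy as np
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch import jax_funcify

x = at.vector("x")
y = at.vector("y")
out = at.exp(x) + y ** 2

fgraph = FunctionGraph(inputs=[x, y], outputs=[out], clone=True)
jax_fn = jax_funcify(fgraph)

# Recent Aesara backends return a tuple of outputs, one entry per fgraph output
res = jax_fn(np.arange(3.0), np.ones(3))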
Example #2
def get_jaxified_logp(model: Model) -> Callable:
    """Compile model.logpt into an optimized jax function"""

    logpt = replace_shared_variables([model.logpt])[0]

    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
    optimize_graph(logpt_fgraph,
                   include=["fast_run"],
                   exclude=["cxx_only", "BlasOpt"])

    # We now jaxify the optimized fgraph
    logp_fn = jax_funcify(logpt_fgraph)

    if isinstance(logp_fn, (list, tuple)):
        # This handles the new JAX backend, which always returns a tuple
        logp_fn = logp_fn[0]

    def logp_fn_wrap(x):
        res = logp_fn(*x)

        if isinstance(res, (list, tuple)):
            # This handles the new JAX backend, which always returns a tuple
            res = res[0]

        # Jax expects a potential with the opposite sign of model.logpt
        return -res

    return logp_fn_wrap
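
A minimal usage sketch for get_jaxified_logp, assuming the PyMC version these snippets target (where Model exposes logpt, value_vars and initial_point, as used in Example #7); depending on that version the import may be pymc3 instead of pymc:

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.7])

logp_fn = get_jaxified_logp(model)

# The wrapper takes the sequence of value-variable arrays and returns the
# negative log-probability (a "potential"), which is the sign NumPyro's NUTS expects.
start = [model.initial_point[rv.name] for rv in model.value_vars]
neg_logp = logp_fn(start)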
Example #3
def _get_log_likelihood(model, samples):
    "Compute log-likelihood for all observations"
    data = {}
    for v in model.observed_RVs:
        logp_v = replace_shared_variables([logpt(v)])
        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
        optimize_graph(fgraph,
                       include=["fast_run"],
                       exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
        data[v.name] = result
    return data
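
The nested jax.vmap above maps the pointwise log-likelihood over the (chain, draw) axes of the posterior samples, the same pattern used in the variable-transformation step of Example #7. A self-contained sketch of that double-vmap pattern on a plain jaxified graph, assuming aesara, jax and numpy are installed:

import jax
import numpy as np
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch import jax_funcify

v = at.vector("v")
fg = FunctionGraph([v], [(v ** 2).sum()], clone=False)
jax_fn = jax_funcify(fg)

samples = np.random.normal(size=(4, 100, 3))  # (chains, draws, value dimension)
# Outer vmap runs over chains, inner vmap over draws; [0] picks the single output
out = jax.jit(jax.vmap(jax.vmap(jax_fn)))(samples)[0]  # shape (4, 100)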
Example #4
def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from aesara.link.jax.dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, node, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op, **kwargs):
        def func(*args, op=op):
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)

    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 1
    assert op2.called == 1

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 2
    assert op2.called == 2
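
The test above relies on jax_funcify being a dispatch registry keyed on the Op class. A minimal sketch of registering a JAX implementation for a custom Op through that mechanism, assuming aesara and jax are installed (DoubleOp is made up for illustration):

import numpy as np
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class DoubleOp(Op):
    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = 2 * inputs[0]


@jax_funcify.register(DoubleOp)
def jax_funcify_DoubleOp(op, **kwargs):
    def double(x):
        return 2 * x

    return double


x = at.vector("x")
fg = FunctionGraph([x], [DoubleOp()(x)], clone=False)
jax_fn = jax_funcify(fg)  # picks up the registration above for the DoubleOp node
res = jax_fn(np.arange(3.0))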
Example #5
def test_jax_FunctionGraph_names():
    import inspect

    from aesara.link.jax.dispatch import jax_funcify

    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")

    out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
    out_jx = jax_funcify(out_fg)
    sig = inspect.signature(out_jx)
    assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys())
    assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
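
For context on the assertions above: auto_name is the name Aesara generates automatically for every variable (e.g. "auto_12345"), and the jaxified function's signature falls back to it whenever the user-supplied name is missing or cannot serve as a Python parameter name ("1x" is not a valid identifier, "def" is a keyword); only "_" survives as given. A tiny sketch:

from aesara.tensor import scalar

s = scalar("1x")
print(s.name, s.auto_name)  # the user name "1x" stays on the variable, but the
                            # jaxified signature uses the auto-generated name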
Example #6
    def fgraph_convert(self, fgraph, **kwargs):
        from aesara.link.jax.dispatch import jax_funcify

        return jax_funcify(fgraph, **kwargs)
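
This fgraph_convert method is the hook the JAX-backed linker uses to turn a FunctionGraph into a JAX callable; user code normally reaches it indirectly by compiling with the JAX mode rather than calling it directly. A minimal sketch, assuming aesara with jax installed:

import aesara
import aesara.tensor as at
from aesara.compile import mode  # mode.JAX is the predefined JAX mode used in Example #1

x = at.vector("x")
f = aesara.function([x], at.exp(x).sum(), mode=mode.JAX)  # routed through fgraph_convert / jax_funcify
f([0.0, 1.0, 2.0])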
Example #7
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    var_names=None,
    progress_bar=True,
    keep_untransformed=False,
    chain_method="parallel",
):
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names,
                             include_transformed=keep_untransformed))

    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items() if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in dims if dim is not None]
            for var_name, dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    if chains == 1:
        init_params = init_state
        map_seed = seed
    else:
        init_params = init_state_batched

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
        optimize_graph(fgraph,
                       include=["fast_run"],
                       exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=_get_log_likelihood(model, raw_mcmc_samples),
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
    )

    return az_trace
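
A minimal end-to-end usage sketch for sample_numpyro_nuts, assuming numpyro and jax are installed along with the PyMC version these snippets target (the import may be pymc3 instead of pymc depending on that version):

import numpy as np
import pymc as pm

observed = np.random.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu, sigma, observed=observed)

    idata = sample_numpyro_nuts(draws=500, tune=500, chains=2)

# The return value is the az.from_dict result built above: posterior samples,
# per-observation log-likelihood, sample stats, coords and dims.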
Example #8
def jax_funcify_NumPyroNUTS(op, node, **kwargs):
    from numpyro.infer import MCMC, NUTS

    draws = op.draws
    tune = op.tune
    chains = op.chains
    target_accept = op.target_accept
    progress_bar = op.progress_bar
    seed = op.seed

    # Compile the "inner" log-likelihood function.  This will have extra shared
    # variable inputs as the last arguments
    logp_fn = jax_funcify(op.fgraph, **kwargs)

    if isinstance(logp_fn, (list, tuple)):
        # This handles the new JAX backend, which always returns a tuple
        logp_fn = logp_fn[0]

    def _sample(*inputs):

        if op.nshared > 0:
            current_state = inputs[:-op.nshared]
            shared_inputs = tuple(op.fgraph.inputs[-op.nshared:])
        else:
            current_state = inputs
            shared_inputs = ()

        def log_fn_wrap(x):
            res = logp_fn(*(
                x
                # We manually obtain the shared values and add them
                # as arguments to our compiled "inner" function
                + tuple(
                    v.get_value(borrow=True, return_internal_type=True)
                    for v in shared_inputs)))

            if isinstance(res, (list, tuple)):
                # This handles the new JAX backend, which always returns a tuple
                res = res[0]

            return -res

        nuts_kernel = NUTS(
            potential_fn=log_fn_wrap,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method="parallel",
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed,
                         init_params=current_state,
                         extra_fields=("num_steps", ))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(
            group_by_chain=True)["num_steps"]
        return tuple(samples) + (leapfrogs_taken, )

    return _sample
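
For context, a dispatcher with this name and signature is presumably registered for a NumPyroNUTS Op (defined in the same module this snippet comes from, not shown here) via @jax_funcify.register(NumPyroNUTS), the same mechanism demonstrated in Example #4. Once registered, jaxifying a graph that contains that Op yields _sample, which runs the NumPyro NUTS sampler and returns the per-variable posterior samples followed by the "num_steps" (leapfrog count) extra field.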