Example #1
0
def get_jaxified_logp(model: Model) -> Callable:
    """Turn ``model.logpt`` into a JAX-callable potential function.

    The log-probability graph first has its shared variables replaced, is
    then optimized in place with the ``fast_run`` rewrites (minus the
    ``cxx_only`` and ``BlasOpt`` ones) and is finally converted with
    ``jax_funcify``.

    Parameters
    ----------
    model
        Model whose ``logpt`` graph is compiled.

    Returns
    -------
    Callable
        Function taking a sequence ``x`` that is unpacked positionally into
        the jaxified graph and returning the *negated* log-probability.
    """
    model_logpt = replace_shared_variables([model.logpt])[0]

    # clone=False: the optimization below rewrites this very graph in place.
    fgraph = FunctionGraph(outputs=[model_logpt], clone=False)
    optimize_graph(
        fgraph,
        include=["fast_run"],
        exclude=["cxx_only", "BlasOpt"],
    )

    jaxified = jax_funcify(fgraph)
    # The new JAX backend always hands back a tuple of functions; keep the
    # first one in that case.
    if isinstance(jaxified, (list, tuple)):
        jaxified = jaxified[0]

    def logp_fn_wrap(x):
        value = jaxified(*x)
        # The new JAX backend always returns a tuple of outputs as well.
        logp = value[0] if isinstance(value, (list, tuple)) else value
        # Jax expects a potential, i.e. the opposite sign of model.logpt.
        return -logp

    return logp_fn_wrap
Example #2
0
def _get_log_likelihood(model, samples):
    """Compute the log-likelihood of every observed variable for all samples.

    For each variable in ``model.observed_RVs`` the ``logpt`` graph is built
    over ``model.value_vars``, optimized, jaxified and evaluated on
    ``samples``.

    Parameters
    ----------
    model
        Model providing ``observed_RVs`` and ``value_vars``.
    samples
        Sequence of sampled values, unpacked positionally into the compiled
        function. The function is vmapped twice, so it is mapped over the
        two leading axes of each entry.

    Returns
    -------
    dict
        Maps each observed variable's name to its evaluated log-likelihood.
    """

    def _evaluate_logp(rv):
        # Build, optimize and jaxify the log-probability graph of one
        # observed variable, then evaluate it on every sample.
        outputs = replace_shared_variables([logpt(rv)])
        graph = FunctionGraph(model.value_vars, outputs, clone=False)
        optimize_graph(
            graph,
            include=["fast_run"],
            exclude=["cxx_only", "BlasOpt"],
        )
        batched_fn = jax.vmap(jax.vmap(jax_funcify(graph)))
        # The jaxified graph returns a sequence of outputs; only one was
        # requested, so take the first.
        return jax.jit(batched_fn)(*samples)[0]

    log_likelihood = {}
    for rv in model.observed_RVs:
        log_likelihood[rv.name] = _evaluate_logp(rv)
    return log_likelihood
Example #3
0
    def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int:
        """Compute ar_order given inputs

        If ar_order is not specified we do constant folding on the shape of rhos
        to retrieve it. For example, this will detect that
        Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.

        Raises
        ------
        ValueError
            If inferred ar_order cannot be inferred from rhos or if it is less than 1
        """
        if ar_order is None:
            shape_fg = FunctionGraph(
                outputs=[rhos.shape[-1]],
                features=[ShapeFeature()],
                clone=True,
            )
            (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
            folded_shape = getattr(folded_shape, "data", None)
            if folded_shape is None:
                raise ValueError(
                    "Could not infer ar_order from last dimension of rho. Pass it "
                    "explictily or make sure rho have a static shape"
                )
            ar_order = int(folded_shape) - int(constant)
            if ar_order < 1:
                raise ValueError(
                    "Inferred ar_order is smaller than 1. Increase the last dimension "
                    "of rho or remove constant_term"
                )

        return ar_order
Example #4
0
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    var_names=None,
    progress_bar=True,
    keep_untransformed=False,
    chain_method="parallel",
):
    """Sample a model with NumPyro's NUTS kernel and collect the results.

    The model's log-probability is jaxified and used as the potential
    function of a NumPyro ``NUTS`` kernel. After sampling, the raw draws of
    the value variables are mapped to the requested variables and everything
    is assembled with ``az.from_dict``.

    Progress and timing messages are printed to ``sys.stdout``.

    Parameters
    ----------
    draws
        Number of samples per chain (``num_samples`` of ``MCMC``).
    tune
        Number of warmup iterations per chain (``num_warmup`` of ``MCMC``).
    chains
        Number of chains. With a single chain the unbatched initial point
        and the unsplit PRNG key are used.
    target_accept
        Passed to ``NUTS`` as ``target_accept_prob``.
    random_seed
        Seed used to create the JAX PRNG key.
    model
        Model to sample; resolved through ``modelcontext``.
    var_names
        Variables to return. Defaults to ``model.unobserved_value_vars``.
    progress_bar
        Passed to ``MCMC`` as ``progress_bar``.
    keep_untransformed
        Passed to ``get_default_varnames`` as ``include_transformed``.
    chain_method
        Passed to ``MCMC`` as ``chain_method``.

    Returns
    -------
    The object built by ``az.from_dict`` from the posterior draws,
    log-likelihood, observed data and sample statistics.
    """
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names,
                             include_transformed=keep_untransformed))

    # Coordinates without values are dropped; tuples become numpy arrays.
    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items() if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        # Unnamed (None) dimensions are left out of the dims mapping.
        dims = {
            var_name: [dim for dim in var_dims if dim is not None]
            for var_name, var_dims in model.RV_dims.items()
        }
    else:
        dims = {}

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]

    if chains == 1:
        init_params = init_state
    else:
        # Give every chain its own copy of the initial point by adding a
        # leading chain axis. ``jax.tree_util.tree_map`` is the stable
        # spelling; the top-level ``jax.tree_map`` alias was deprecated and
        # removed from newer JAX releases.
        init_params = jax.tree_util.tree_map(
            lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    seed = jax.random.PRNGKey(random_seed)
    # A single chain takes the key as is; several chains get one key each.
    map_seed = seed if chains == 1 else jax.random.split(seed, chains)

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        # Express the requested variable as a function of the value
        # variables and evaluate it on every draw of every chain.
        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
        optimize_graph(fgraph,
                       include=["fast_run"],
                       exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        # The double vmap maps over the two leading axes of the samples
        # (grouped by chain above); [0] picks the single requested output.
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=_get_log_likelihood(model, raw_mcmc_samples),
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
    )

    return az_trace