Example no. 1
0
    def test_layers(self):
        """A Binomial whose probability is itself Uniform(0, 1) should average 0.5."""
        with pm.Model(rng_seeder=232093) as model:
            a = pm.Uniform("a", lower=0, upper=1, size=10)
            b = pm.Binomial("b", n=1, p=a, size=10)

        # Compile a sampler that draws one realization of `b` per call.
        sample_b = compile_rv_inplace([], b, mode="FAST_RUN")
        draws = [sample_b() for _ in range(10000)]
        empirical_mean = np.stack(draws).mean(axis=0)
        npt.assert_array_almost_equal(empirical_mean, 0.5 * np.ones((10,)), decimal=2)
Example no. 2
0
def delta_logp(point, logp, vars, shared):
    """Compile a function of two raveled input arrays returning their logp difference.

    The returned function takes (proposal_array, current_array) and evaluates
    ``logp(proposal) - logp(current)``.
    """
    # Join all non-shared inputs into a single flat array input.
    [logp_joined], current = pm.join_nonshared_inputs(point, [logp], vars, shared)

    # A second input of the same flat type, representing the proposal point.
    proposal = current.type("inarray1")
    logp_at_proposal = pm.CallableTensor(logp_joined)(proposal)

    f = compile_rv_inplace([proposal, current], logp_at_proposal - logp_joined)
    f.trust_input = True
    return f
Example no. 3
0
def _logp_forw(point, out_vars, vars, shared):
    """Compile Aesara function of the model and the input and output variables.

    Parameters
    ----------
    point: dict
        mapping of value-variable names to initial values, forwarded to
        :func:`join_nonshared_inputs`
    out_vars: List
        containing :class:`pymc.Distribution` for the output variables
    vars: List
        containing :class:`pymc.Distribution` for the input variables
    shared: List
        containing :class:`aesara.tensor.Tensor` for depended shared data

    Returns
    -------
    A compiled function taking a single flat input array (the joined inputs)
    and returning the first output variable; ``trust_input`` is enabled, so
    callers must pass correctly-typed arrays.
    """

    # Convert expected input of discrete variables to (rounded) floats
    if any(var.dtype in discrete_types for var in vars):
        replace_int_to_float = {}
        replace_float_to_round = {}
        new_vars = []
        for var in vars:
            if var.dtype in discrete_types:
                # Stand-in float input that replaces the discrete variable.
                float_var = at.TensorType("floatX",
                                          var.broadcastable)(var.name)
                replace_int_to_float[var] = float_var
                new_vars.append(float_var)

                # Round the float back so downstream graph still sees
                # integer-valued inputs (same name as the original var).
                round_float_var = at.round(float_var)
                round_float_var.name = var.name
                replace_float_to_round[float_var] = round_float_var
            else:
                new_vars.append(var)

        replace_int_to_float.update(shared)
        replace_float_to_round.update(shared)
        # Two-stage replacement: first swap discrete inputs for floats,
        # then insert the rounding on top of those floats. The order of
        # these clone_replace calls matters.
        out_vars = clone_replace(out_vars, replace_int_to_float, strict=False)
        out_vars = clone_replace(out_vars, replace_float_to_round)
        vars = new_vars

    out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
    f = compile_rv_inplace([inarray0], out_list[0])
    f.trust_input = True
    return f
Example no. 4
0
def make_initial_point_fn(
    *,
    model,
    overrides: Optional[StartDict] = None,
    jitter_rvs: Optional[Set[TensorVariable]] = None,
    default_strategy: str = "prior",
    return_transformed: bool = True,
) -> Callable:
    """Create seeded function that computes initial values for all free model variables.

    Parameters
    ----------
    model
        The model whose free RVs are assigned initial values.
    overrides : dict
        Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
    jitter_rvs : set
        The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be
        added to the initial value. Only available for variables that have a transform or real-valued support.
    default_strategy : str
        Which of { "moment", "prior" } to prefer if the initval setting for an RV is None.
    return_transformed : bool
        If `True` the returned variables will correspond to transformed initial values.

    Returns
    -------
    A function ``f(seed, *args, **kwargs)`` returning a dict mapping variable
    names to initial values, reproducible for a given ``seed``.
    """

    def find_rng_nodes(variables):
        # Collect every RNG shared variable that feeds into `variables`.
        return [
            node
            for node in graph_inputs(variables)
            if isinstance(
                node,
                (
                    at.random.var.RandomStateSharedVariable,
                    at.random.var.RandomGeneratorSharedVariable,
                ),
            )
        ]

    overrides = convert_str_to_rv_dict(model, overrides or {})

    initial_values = make_initial_point_expression(
        free_rvs=model.free_RVs,
        rvs_to_values=model.rvs_to_values,
        # `overrides` is always a dict at this point, so no `or {}` needed;
        # entries in `overrides` take precedence over `model.initial_values`.
        initval_strategies={**model.initial_values, **overrides},
        jitter_rvs=jitter_rvs,
        default_strategy=default_strategy,
        return_transformed=return_transformed,
    )

    # Replace original rng shared variables so that we don't mess with them
    # when calling the final seeded function
    graph = FunctionGraph(outputs=initial_values, clone=False)
    rng_nodes = find_rng_nodes(graph.outputs)
    new_rng_nodes = []
    for rng_node in rng_nodes:
        # Match the replacement RNG kind to the original node's kind.
        if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
            new_rng = np.random.RandomState(np.random.PCG64())
        else:
            new_rng = np.random.Generator(np.random.PCG64())
        new_rng_nodes.append(aesara.shared(new_rng))
    graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
    func = compile_rv_inplace(
        inputs=[],
        outputs=graph.outputs,
        mode=aesara.compile.mode.FAST_COMPILE,
    )

    # Output names: prefer the transformed name when a transform exists and
    # transformed values were requested.
    varnames = []
    for var in model.free_RVs:
        transform = getattr(model.rvs_to_values[var].tag, "transform", None)
        if transform is not None and return_transformed:
            name = get_transformed_name(var.name, transform)
        else:
            name = var.name
        varnames.append(name)

    def make_seeded_function(func):
        # Wrap `func` so that each call deterministically re-seeds the RNG
        # shared variables before evaluating.
        rngs = find_rng_nodes(func.maker.fgraph.outputs)

        @functools.wraps(func)
        def inner(seed, *args, **kwargs):
            # Fix: the loop variable below used to be named `seed`, which
            # shadowed the `seed` argument; use a distinct name instead.
            bit_generators = [
                np.random.PCG64(sub_seed)
                for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
            ]
            for rng, bit_generator in zip(rngs, bit_generators):
                if isinstance(rng, at.random.var.RandomStateSharedVariable):
                    new_rng = np.random.RandomState(bit_generator)
                else:
                    new_rng = np.random.Generator(bit_generator)
                rng.set_value(new_rng, True)
            values = func(*args, **kwargs)
            return dict(zip(varnames, values))

        return inner

    return make_seeded_function(func)
Example no. 5
0
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    progress_bar=True,
    keep_untransformed=False,
):
    """Sample a PyMC model with NumPyro's NUTS via a JAX-compiled logp.

    Parameters
    ----------
    draws : int
        Number of posterior samples per chain.
    tune : int
        Number of warmup (adaptation) iterations.
    chains : int
        Number of chains, run with NumPyro's "parallel" chain method.
    target_accept : float
        Target acceptance probability for step-size adaptation.
    random_seed : int
        Seed for the JAX PRNG key.
    model : optional
        Model to sample; taken from the context if not given.
    progress_bar : bool
        Whether to show NumPyro's progress bar.
    keep_untransformed : bool
        If True, also keep raw (untransformed) samples of transformed variables.

    Returns
    -------
    An ArviZ InferenceData-like object built from the posterior samples.
    """
    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    tic1 = pd.Timestamp.now()
    print("Compiling...", file=sys.stdout)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    # Broadcast the single initial point across all chains.
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method="parallel",
        progress_bar=progress_bar,
    )

    tic2 = pd.Timestamp.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # One PRNG key per chain.
    seed = jax.random.PRNGKey(random_seed)
    map_seed = jax.random.split(seed, chains)

    pmap_numpyro.run(map_seed,
                     init_params=init_state_batched,
                     extra_fields=("num_steps", ))
    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = pd.Timestamp.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = []
    # Fix: dropped the unused `enumerate` index from this loop.
    for value_var, raw_samples in zip(model.value_vars, raw_mcmc_samples):
        raw_samples = at.constant(np.asarray(raw_samples))

        rv = model.values_to_rvs[value_var]
        transform = getattr(value_var.tag, "transform", None)

        if transform is not None:
            # TODO: This will fail when the transformation depends on another variable
            #  such as in interval transform with RVs as edges
            trans_samples = transform.backward(raw_samples, *rv.owner.inputs)
            trans_samples.name = rv.name
            mcmc_samples.append(trans_samples)

            if keep_untransformed:
                raw_samples.name = value_var.name
                mcmc_samples.append(raw_samples)
        else:
            raw_samples.name = rv.name
            mcmc_samples.append(raw_samples)

    mcmc_varnames = [var.name for var in mcmc_samples]
    # Evaluate all back-transformations in a single JAX-compiled call.
    mcmc_samples = compile_rv_inplace(
        [],
        mcmc_samples,
        mode="JAX",
    )()

    tic4 = pd.Timestamp.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Fix: identity dict comprehension replaced by dict(zip(...)).
    posterior = dict(zip(mcmc_varnames, mcmc_samples))
    az_trace = az.from_dict(posterior=posterior)

    return az_trace