Example #1
import jax
import pymc3 as pm
import theano

from theano.link.jax.jax_dispatch import jax_funcify


def _transform_samples(samples, model, keep_untransformed=False):

    # Find out which RVs we need to compute:
    free_rv_names = {x.name for x in model.free_RVs}
    unobserved_names = {x.name for x in model.unobserved_RVs}

    names_to_compute = unobserved_names - free_rv_names
    ops_to_compute = [
        x for x in model.unobserved_RVs if x.name in names_to_compute
    ]

    # Create function graph for these:
    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)

    # Jaxify, which returns a list of functions, one for each op
    jax_fns = jax_funcify(fgraph)

    # Put together the inputs
    inputs = [samples[x.name] for x in model.free_RVs]

    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):

        # vmap twice: each sample array carries both a chain and a draw
        # dimension, while `cur_jax_fn` expects unbatched inputs.
        result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)

        # Add to sample dict
        samples[cur_op.name] = result

    # Discard unwanted transformed variables, if desired:
    vars_to_keep = set(
        pm.util.get_default_varnames(list(samples.keys()),
                                     include_transformed=keep_untransformed))
    samples = {x: y for x, y in samples.items() if x in vars_to_keep}

    return samples
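
For context, a minimal usage sketch of `_transform_samples` (hypothetical, assuming PyMC3 with the Theano-PyMC JAX backend installed; the sample arrays are shaped `(chains, draws)`, which is why the function vmaps twice):

import numpy as np
import pymc3 as pm

# A toy model with one transformed variable: `sigma` is sampled on the
# log scale, so the raw trace holds `sigma_log__` rather than `sigma`.
with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", 1.0)

# Stand-in for draws produced by a JAX sampler: one (chains, draws)
# array per free (transformed) RV.
samples = {"sigma_log__": np.random.randn(4, 1000)}

samples = _transform_samples(samples, model, keep_untransformed=False)
print(samples.keys())  # dict_keys(['sigma'])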
Example #2
    def create_jax_thunks(self, compute_map, storage_map):
        """Create a thunk for each output of the `Linker`s `FunctionGraph`.

        This is differs from the other thunk-making function in that it only
        produces thunks for the `FunctionGraph` output nodes.

        Parameters
        ----------
        compute_map: dict
            The compute map dictionary.
        storage_map: dict
            The storage map dictionary.

        Returns
        -------
        thunks: list
            A tuple containing the thunks.
        output_nodes: list and their
            A tuple containing the output nodes.

        """
        from collections.abc import Sequence

        import jax

        from theano.graph.basic import Constant
        from theano.link.jax.jax_dispatch import jax_funcify

        output_nodes = [o.owner for o in self.fgraph.outputs]

        # Create a JAX-compilable function from our `FunctionGraph`
        jaxed_fgraph_outputs = jax_funcify(self.fgraph)

        assert len(jaxed_fgraph_outputs) == len(output_nodes)

        # I suppose we can consider `Constant`s to be "static" according to
        # JAX.
        static_argnums = [
            n for n, i in enumerate(self.fgraph.inputs)
            if isinstance(i, Constant)
        ]

        thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]

        thunks = []

        for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):

            thunk_outputs = [storage_map[n] for n in node.outputs]

            if not isinstance(jax_funcs, Sequence):
                jax_funcs = [jax_funcs]

            jax_impl_jits = [
                jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
            ]

            def thunk(node=node,
                      jax_impl_jits=jax_impl_jits,
                      thunk_outputs=thunk_outputs):
                outputs = [
                    jax_impl_jit(*[x[0] for x in thunk_inputs])
                    for jax_impl_jit in jax_impl_jits
                ]

                if len(jax_impl_jits) < len(node.outputs):
                    # In this case, the JAX function will output a single
                    # output that contains the other outputs.
                    # This happens for multi-output `Op`s that directly
                    # correspond to multi-output JAX functions (e.g. `SVD` and
                    # `jax.numpy.linalg.svd`).
                    outputs = outputs[0]

                for o_node, o_storage, o_val in zip(node.outputs,
                                                    thunk_outputs, outputs):
                    compute_map[o_node][0] = True
                    if len(o_storage) > 1:
                        assert len(o_storage) == len(o_val)
                        for i, o_sub_val in enumerate(o_val):
                            o_storage[i] = o_sub_val
                    else:
                        o_storage[0] = o_val
                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            thunks.append(thunk)

        return thunks, output_nodes
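
The `static_argnums` handling above is plain `jax.jit` usage: arguments marked static are baked into the compiled trace, which is why `Constant` inputs qualify. A self-contained sketch of the same pattern (the function and names are illustrative, not part of the linker):

import jax
import jax.numpy as jnp

def scale_and_pad(x, pad_width):
    # `pad_width` determines the output shape, so JAX must treat it as
    # static and retrace when it changes.
    return jnp.pad(2.0 * x, pad_width)

# Mark argument 1 as static, mirroring the `Constant` inputs above.
scale_and_pad_jit = jax.jit(scale_and_pad, static_argnums=(1,))

print(scale_and_pad_jit(jnp.arange(3.0), 2))  # zero-padded by 2 on each side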
Example #3
import arviz as az
import numpy as np
import pandas as pd
import theano

from pymc3 import modelcontext
from theano.link.jax.jax_dispatch import jax_funcify


def sample_tfp_mhrw(
    draws=1000,
    tune=5000,
    burnin=1000,
    thin=0,
    chains=4,
    random_seed=10,
    model=None,
    num_tuning_epoch=5,
    step_size=0.1,
):
    import jax

    from tensorflow_probability.substrates import jax as tfp

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    fns = jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    @jax.vmap
    def _sample(init_state, seed, step_size):
        def trace_is_accepted(states, previous_kernel_results):
            return previous_kernel_results.is_accepted

        def gen_kernel(step_size):
            kernel_ = tfp.mcmc.RandomWalkMetropolis(
                target_log_prob_fn=logp_fn_jax,
                new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=step_size))
            return kernel_

        for i in range(num_tuning_epoch - 1):
            tuning_mhrw = gen_kernel(step_size)
            samples, stats = tfp.mcmc.sample_chain(
                num_results=burnin // num_tuning_epoch,
                current_state=init_state,
                kernel=tuning_mhrw,
                num_burnin_steps=burnin,
                num_steps_between_results=thin,
                trace_fn=trace_is_accepted,
                seed=seed)

            # Adapt the proposal scale to the observed acceptance rate;
            # `vtune` is a step-size tuning helper assumed to be defined
            # elsewhere in this module.
            accept_rate = jax.numpy.ravel(stats).mean()
            step_size = vtune(step_size, accept_rate)

        # Run inference
        sample_kernel = gen_kernel(step_size)
        mcmc_samples, mcmc_stats = tfp.mcmc.sample_chain(
            num_results=draws,
            num_burnin_steps=tune // num_tuning_epoch,
            current_state=init_state,
            kernel=sample_kernel,
            trace_fn=trace_is_accepted,
            seed=seed,
        )
        return mcmc_samples, mcmc_stats

    print("Compiling and sampling...")
    tic2 = pd.Timestamp.now()
    map_seed = jax.random.split(seed, chains)
    # One initial step size per chain.
    map_stepsize = step_size * jax.numpy.ones(chains)

    mcmc_samples, accept_rate = _sample(init_state_batched, map_seed,
                                        map_stepsize)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

    az_trace = az.from_dict(posterior=posterior)
    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)
    return az_trace, accept_rate
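
The chain parallelism above is the standard `jax.vmap` pattern: split one PRNG key into per-chain keys, then map a single-chain sampler over the batched state. A stripped-down sketch of just that pattern (the one-step "sampler" is a stand-in):

import jax

@jax.vmap
def one_chain(init, seed):
    # Stand-in for a single-chain sampler: one Gaussian perturbation.
    return init + jax.random.normal(seed)

chains = 4
seed = jax.random.PRNGKey(0)
per_chain_seeds = jax.random.split(seed, chains)  # shape (chains, 2)
init_batched = jax.numpy.zeros(chains)            # one initial value per chain

print(one_chain(init_batched, per_chain_seeds))   # four independent draws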
Example #4
import arviz as az
import numpy as np
import pandas as pd
import theano

from pymc3 import modelcontext
from theano.link.jax.jax_dispatch import jax_funcify


def sample_tfp_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    model=None,
    num_tuning_epoch=2,
    num_compute_step_size=500,
):
    import jax

    from tensorflow_probability.substrates import jax as tfp

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    fns = jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    @jax.pmap
    def _sample(init_state, seed):
        def gen_kernel(step_size):
            hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax,
                                          step_size=step_size)
            return tfp.mcmc.DualAveragingStepSizeAdaptation(
                hmc,
                tune // num_tuning_epoch,
                target_accept_prob=target_accept)

        def trace_fn(_, pkr):
            return pkr.new_step_size

        def get_tuned_stepsize(samples, step_size):
            return step_size[-1] * jax.numpy.std(
                samples[-num_compute_step_size:])

        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
        for i in range(num_tuning_epoch - 1):
            tuning_hmc = gen_kernel(step_size)
            init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
                num_results=tune // num_tuning_epoch,
                current_state=init_state,
                kernel=tuning_hmc,
                trace_fn=trace_fn,
                return_final_kernel_results=True,
                seed=seed,
            )

            step_size = jax.tree_multimap(get_tuned_stepsize,
                                          list(init_samples), tuning_result)
            init_state = [x[-1] for x in init_samples]

        # Run inference
        sample_kernel = gen_kernel(step_size)
        mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
            num_results=draws,
            num_burnin_steps=tune // num_tuning_epoch,
            current_state=init_state,
            kernel=sample_kernel,
            trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
            seed=seed,
        )

        return mcmc_samples, leapfrog_num

    print("Compiling and sampling...")
    tic2 = pd.Timestamp.now()
    map_seed = jax.random.split(seed, chains)
    mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

    az_trace = az.from_dict(posterior=posterior)
    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)
    return az_trace
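
The TFP-on-JAX pieces used here can be exercised in isolation; a minimal sketch sampling a standard normal with plain HMC (the target and tuning values are illustrative):

import jax
from tensorflow_probability.substrates import jax as tfp

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -0.5 * x ** 2,  # standard normal, up to a constant
    step_size=0.5,
    num_leapfrog_steps=3,
)

samples, is_accepted = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=200,
    current_state=jax.numpy.zeros([]),
    kernel=kernel,
    trace_fn=lambda _, pkr: pkr.is_accepted,
    seed=jax.random.PRNGKey(0),
)
print(samples.mean(), is_accepted.mean())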
Example #5
import arviz as az
import jax
import numpy as np
import pandas as pd
import theano

from theano.link.jax.jax_dispatch import jax_funcify


def sample_numpyro_nuts_vmap(draws=1000,
                             tune=1000,
                             chains=4,
                             target_accept=0.8,
                             random_seed=10,
                             model=None,
                             progress_bar=True,
                             chain_method="parallel"):
    from numpyro.infer import MCMC, NUTS

    from pymc3 import modelcontext

    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
    fns = jax_funcify(fgraph)
    logp_fn_jax = fns[0]

    rv_names = [rv.name for rv in model.free_RVs]
    init_state = [model.test_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(
        lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

    # Note: NumPyro's `MCMC.run` manages its own compilation (and drives a
    # Python progress bar), so the sampler function must not be wrapped in
    # `jax.jit` itself.
    def _sample(current_state, seed):
        nuts_kernel = NUTS(
            potential_fn=lambda x: -logp_fn_jax(*x),
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method=chain_method,
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed,
                         init_params=current_state,
                         extra_fields=("num_steps", ))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(
            group_by_chain=True)["num_steps"]
        return samples, leapfrogs_taken

    print("Compiling and sampling...")
    tic2 = pd.Timestamp.now()
    map_seed = jax.random.split(seed, chains)
    mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)

    posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

    az_trace = az.from_dict(posterior=posterior)
    tic3 = pd.Timestamp.now()
    print("Compilation + sampling time = ", tic3 - tic2)
    return az_trace
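
NumPyro's `potential_fn` interface, used above to plug in the PyMC3 log-density, works the same way standalone; a minimal sketch sampling a standard normal (the quadratic potential is illustrative):

import jax
from numpyro.infer import MCMC, NUTS

# The potential is the negative log density, matching `-logp_fn_jax(*x)` above.
kernel = NUTS(potential_fn=lambda x: 0.5 * x ** 2)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
mcmc.run(jax.random.PRNGKey(0), init_params=jax.numpy.zeros([]))
print(mcmc.get_samples().mean())  # approximately 0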