# Assumed imports for this snippet: `jax_ravel_pytree` is JAX's `ravel_pytree`,
# and `np` is taken to be `jax.numpy` (plain NumPy would behave the same here).
# `sample_forward` is MCX's prior/forward sampler, defined elsewhere in the library.
import jax.numpy as np
from jax.flatten_util import ravel_pytree as jax_ravel_pytree


def get_initial_position(rng_key, model, num_chains, **kwargs):
    conditioning_vars = set(kwargs.keys())
    model_randvars = set(model.random_variables)
    to_sample_vars = model_randvars.difference(conditioning_vars)

    samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs)
    initial_positions = dict((var, samples[var]) for var in to_sample_vars)

    # A naive way to flatten the positions would be to transform the
    # dictionary of arrays that contain the parameter values into a list of
    # dictionaries, one per position, and then unravel each dictionary.
    # However, this approach takes more time than getting the samples in the
    # first place.
    #
    # Luckily, JAX sorts dictionaries by key
    # (https://github.com/google/jax/blob/master/jaxlib/pytree.cc) when
    # raveling pytrees. We can thus ravel and stack the parameter values in an
    # array, sorting by key; this gives our flattened positions. We then build
    # a single dictionary that contains one value per parameter and use it to
    # get the unraveling function from `ravel_pytree`.
    positions = np.stack(
        [np.ravel(samples[s]) for s in sorted(initial_positions.keys())],
        axis=1,
    )

    # `np.atleast_1d` is necessary to handle the single-chain case.
    sample_position_dict = {
        parameter: np.atleast_1d(values)[0]
        for parameter, values in initial_positions.items()
    }
    _, unravel_fn = jax_ravel_pytree(sample_position_dict)

    return positions, unravel_fn
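The key-sorting behaviour that the comment above relies on is easy to check in isolation. The following standalone sketch (not MCX code; the parameter names and toy samples are made up) stacks per-variable samples in sorted-key order and verifies that `jax.flatten_util.ravel_pytree` unravels a row of the resulting array back into a per-chain dictionary:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

num_chains = 4
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical prior samples for two scalar parameters, one draw per chain.
samples = {
    "b": jax.random.normal(key_b, (num_chains,)),
    "a": jax.random.normal(key_a, (num_chains,)),
}

# Stack the per-variable samples in sorted-key order -> shape (num_chains, 2).
positions = jnp.stack([jnp.ravel(samples[s]) for s in sorted(samples)], axis=1)

# Build the unraveling function from a single chain's position.
single_chain = {name: jnp.atleast_1d(value)[0] for name, value in samples.items()}
_, unravel_fn = ravel_pytree(single_chain)

# The first row of `positions` unravels back into chain 0's values because
# `ravel_pytree` also flattens dictionaries in sorted-key order.
chain_0 = unravel_fn(positions[0])
assert jnp.allclose(chain_0["a"], samples["a"][0])
assert jnp.allclose(chain_0["b"], samples["b"][0])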
# Assumed imports for this snippet: `jax_ravel_pytree` is JAX's `ravel_pytree`.
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree as jax_ravel_pytree

import mcx


def get_initial_position(rng_key, model, model_args, observations, num_chains):
    """Get an initial position for each chain.

    While there surely are smarter ways to initialize the chains, we sample
    the first position from the joint prior distribution of the variables.
    """
    initial_positions = mcx.sample_joint(
        rng_key, model, model_args, num_samples=num_chains
    )
    for observed_var in observations.keys():
        initial_positions.pop(observed_var)

    # MCX's inference algorithms work on flat arrays, so we have to ravel the
    # positions before feeding them to the evaluators. We need to ravel the
    # positions *for each chain separately*. However, if we naively used JAX's
    # `ravel_pytree` function on the dictionary of prior samples we would
    # obtain a single array with all positions for all chains.
    reshaped_positions = jax.tree_util.tree_map(
        lambda x: x.reshape(num_chains, -1), initial_positions
    )
    flattened_positions = jnp.concatenate(
        jax.tree_util.tree_leaves(reshaped_positions), axis=1
    )

    # We will use `jax.vmap` to map the computation over the different chains,
    # so the unraveling function needs to unravel the position of a *single*
    # chain. The `num_chains == 1` branch handles the case where we sample a
    # single chain.
    if num_chains == 1:
        sample_position_dict = initial_positions
    else:
        sample_position_dict = jax.tree_util.tree_map(
            lambda x: x[0], initial_positions
        )
    _, unravel_fn = jax_ravel_pytree(sample_position_dict)

    return flattened_positions, unravel_fn
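For a model with a multi-dimensional parameter, the per-chain raveling and the `jax.vmap`-ed unraveling used above can be sketched as follows (standalone toy example, not MCX code; the parameter names and shapes are made up):

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

num_chains = 3
key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical prior samples: a 2-dimensional parameter and a scalar
# parameter, with the leading axis indexing chains.
initial_positions = {
    "mu": jax.random.normal(key_mu, (num_chains, 2)),
    "sigma": jnp.exp(jax.random.normal(key_sigma, (num_chains,))),
}

# Ravel each chain separately: reshape every leaf to (num_chains, -1) and
# concatenate the leaves along the parameter axis.
reshaped = jax.tree_util.tree_map(
    lambda x: x.reshape(num_chains, -1), initial_positions
)
flattened_positions = jnp.concatenate(jax.tree_util.tree_leaves(reshaped), axis=1)
print(flattened_positions.shape)  # (3, 3): 2 components of mu + 1 sigma per chain

# Build an unraveling function for a *single* chain, then vmap it over chains
# to recover the dictionary structure with a leading chain axis.
single_chain = jax.tree_util.tree_map(lambda x: x[0], initial_positions)
_, unravel_fn = ravel_pytree(single_chain)
positions_dict = jax.vmap(unravel_fn)(flattened_positions)

assert jnp.allclose(positions_dict["mu"], initial_positions["mu"])
assert jnp.allclose(positions_dict["sigma"], initial_positions["sigma"])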