def f(xs):
    def scan_body(carry, _):  # closes over xs
        return carry + 1, xs[carry]
    return lax.scan(scan_body, 1, xs)[1]
def main(args): print("Start vanilla HMC...") nuts_kernel = NUTS(dual_moon_model) mcmc = MCMC( nuts_kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains, progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True) mcmc.run(random.PRNGKey(0)) mcmc.print_summary() vanilla_samples = mcmc.get_samples()['x'].copy() guide = AutoBNAFNormal( dual_moon_model, hidden_factors=[args.hidden_factor, args.hidden_factor]) svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO()) svi_state = svi.init(random.PRNGKey(1)) print("Start training guide...") last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, jnp.zeros(args.num_iters)) params = svi.get_params(last_state) print("Finish training guide. Extract samples...") guide_samples = guide.sample_posterior( random.PRNGKey(2), params, sample_shape=(args.num_samples, ))['x'].copy() print("\nStart NeuTra HMC...") neutra = NeuTraReparam(guide, params) neutra_model = neutra.reparam(dual_moon_model) nuts_kernel = NUTS(neutra_model) mcmc = MCMC( nuts_kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains, progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True) mcmc.run(random.PRNGKey(3)) mcmc.print_summary() zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"] print("Transform samples into unwarped space...") samples = neutra.transform_sample(zs) print_summary(samples) zs = zs.reshape(-1, 2) samples = samples['x'].reshape(-1, 2).copy() # make plots # guide samples (for plotting) guide_base_samples = dist.Normal(jnp.zeros(2), 1.).sample(random.PRNGKey(4), (1000, )) guide_trans_samples = neutra.transform_sample(guide_base_samples)['x'] x1 = jnp.linspace(-3, 3, 100) x2 = jnp.linspace(-3, 3, 100) X1, X2 = jnp.meshgrid(x1, x2) P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1))) fig = plt.figure(figsize=(12, 8), constrained_layout=True) gs = GridSpec(2, 3, figure=fig) ax1 = fig.add_subplot(gs[0, 0]) ax2 = fig.add_subplot(gs[1, 0]) ax3 = fig.add_subplot(gs[0, 1]) ax4 = fig.add_subplot(gs[1, 1]) ax5 = fig.add_subplot(gs[0, 2]) ax6 = fig.add_subplot(gs[1, 2]) ax1.plot(losses[1000:]) ax1.set_title('Autoguide training loss\n(after 1000 steps)') ax2.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2) ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using\nAutoBNAFNormal guide') sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1], ax=ax3, hue=guide_trans_samples[:, 0] < 0.) ax3.set( xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)' ) ax4.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4) ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5) ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using\nvanilla HMC sampler') sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0., s=30, alpha=0.5, edgecolor="none") ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1', title='Samples from the\nwarped posterior - p(z)') ax6.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6) ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2) ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using\nNeuTra HMC sampler') plt.savefig("neutra.pdf")
def rnn(xs):
    # Move the time axis to the front so scan iterates over time steps.
    xs = np.swapaxes(xs, 0, 1)
    _, ys = lax.scan(cell, carry_init(xs.shape[1]), xs)
    return np.swapaxes(ys, 0, 1)
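# `cell` and `carry_init` above are assumed helpers that are not shown.
# A minimal sketch of what they might look like for a vanilla RNN
# (hypothetical names and sizes, not from the original source; `np` is
# assumed to be jax.numpy, as in the snippet above):
n_in, n_hid = 8, 16
W_x = np.zeros((n_in, n_hid))   # input-to-hidden weights (placeholder values)
W_h = np.zeros((n_hid, n_hid))  # hidden-to-hidden weights (placeholder values)

def carry_init(batch_size):
    # One hidden-state row per example in the batch.
    return np.zeros((batch_size, n_hid))

def cell(h, x_t):
    # h: (batch, n_hid) carry; x_t: (batch, n_in) slice at one time step.
    h_new = np.tanh(x_t @ W_x + h @ W_h)
    return h_new, h_new  # new carry, per-step output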
def scan_fori_loop(lo, hi, loop, init):
    def scan_f(x, t):
        return loop(t, x), ()
    x, _ = lax.scan(scan_f, init, jnp.arange(lo, hi))
    return x
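# A quick usage sketch of scan_fori_loop above (values assumed, not from
# the original source): summing the integers 0..9.
total = scan_fori_loop(0, 10, lambda i, acc: acc + i, 0)
# total == 45, matching lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)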
def func(x):
    # `y` is captured from the enclosing scope; hcb.id_print passes its
    # argument through unchanged while printing it on the host.
    return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
def f(x):
    def body(carry, x):
        effect_p.bind(effect='foo')
        return carry, x
    return lax.scan(body, x, jnp.arange(4))
def fori_loop(lower, upper, body_fun, init_val):
    f = lambda x, i: (body_fun(i, x), ())
    result, _ = lax.scan(f, init_val, np.arange(lower, upper))
    return result
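# lax.while_loop-based loops are not reverse-mode differentiable, which is
# the usual motivation for this scan-based fori_loop pattern. A usage
# sketch with assumed values:
x0 = np.ones(3)
y = fori_loop(0, 5, lambda i, x: 2.0 * x, x0)  # == x0 * 2**5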
def mLSTM1900_batch(
    params: Dict[str, np.ndarray], batch: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    LSTM layer implemented according to UniRep, found here:
    https://github.com/churchlab/UniRep/blob/master/unirep.py#L43,
    for a batch of data.

    This function processes a single embedded sequence, passed in as a
    two-dimensional array, with the number of rows being the number of
    sequence positions and the number of columns being the embedding of
    each sequence letter.

    :param params: All weights and biases for a single mLSTM1900 RNN cell.
    :param batch: One sequence embedded in an (n, 10) matrix, where `n`
        is the number of sequence positions.
    :returns: The final hidden state, the final cell state, and the
        hidden state at every sequence position.
    """
    h_t = np.zeros(params["wmh"].shape[0])
    c_t = np.zeros(params["wmh"].shape[0])

    def mLSTM1900_step(
        carry: Tuple[np.ndarray, np.ndarray],
        x_t: np.ndarray,
    ) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
        """
        Implementation of mLSTMCell from the UniRep paper, with weight
        normalization.

        Exact source code reference:
        https://github.com/churchlab/UniRep/blob/master/unirep.py#L75

        Shapes of parameters:

        - wmx: 10, 1900
        - wmh: 1900, 1900
        - wx: 10, 7600
        - wh: 1900, 7600
        - gmx: 1900
        - gmh: 1900
        - gx: 7600
        - gh: 7600
        - b: 7600

        Shapes of inputs:

        - x_t: (1, 10)
        - carry:
            - h_t: (1, 1900)
            - c_t: (1, 1900)
        """
        h_t, c_t = carry

        # Perform weight normalization first (corresponds to line 113).
        # In the original implementation, this is toggled by a boolean flag,
        # but here we are enabling it by default.
        wx = l2_normalize(params["wx"], axis=0) * params["gx"]
        wh = l2_normalize(params["wh"], axis=0) * params["gh"]
        wmx = l2_normalize(params["wmx"], axis=0) * params["gmx"]
        wmh = l2_normalize(params["wmh"], axis=0) * params["gmh"]

        # Shape annotation:
        # (:, 10) @ (10, 1900) * (:, 1900) @ (1900, 1900) => (:, 1900)
        m = np.matmul(x_t, wmx) * np.matmul(h_t, wmh)

        # (:, 10) @ (10, 7600) * (:, 1900) @ (1900, 7600) + (7600,) => (:, 7600)
        z = np.matmul(x_t, wx) + np.matmul(m, wh) + params["b"]

        # Splitting along the last axis, four ways, gets us (:, 1900) as the
        # shape for each of i, f, o and u.
        i, f, o, u = np.split(z, 4, axis=-1)  # input, forget, output, update

        # Elementwise transforms here.
        # Shapes are (:, 1900) for each of the four.
        i = sigmoid(i, version="exp")
        f = sigmoid(f, version="exp")
        o = sigmoid(o, version="exp")
        u = tanh(u)

        # (:, 1900) * (:, 1900) + (:, 1900) * (:, 1900) => (:, 1900)
        c_t = f * c_t + i * u

        # (:, 1900) * (:, 1900) => (:, 1900)
        h_t = o * tanh(c_t)

        # h, c each have shape (:, 1900)
        return (h_t, c_t), h_t

    (h_final, c_final), outputs = lax.scan(
        mLSTM1900_step, init=(h_t, c_t), xs=batch
    )

    return h_final, c_final, outputs
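# `l2_normalize` above is an assumed helper; a minimal sketch consistent
# with tf.math.l2_normalize, which the referenced UniRep source uses
# (`np` assumed to be jax.numpy, as in the snippet above):
def l2_normalize(arr, axis, epsilon=1e-12):
    # Divide by the L2 norm along `axis`, guarding against division by zero.
    square_sum = np.sum(np.square(arr), axis=axis, keepdims=True)
    return arr / np.sqrt(np.maximum(square_sum, epsilon))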
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    from numpyro.contrib.funsor import config_enumerate, enum, markov
    from numpyro.contrib.funsor import trace as packed_trace

    # XXX: This implementation only works for history size=1 but can be
    # extended to history size > 1 by running `f` `history_size` times
    # for initialization. However, `sequential_sum_product` does not
    # support history size > 1, so we skip supporting it here.
    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
    if reverse:
        x0 = tree_map(lambda x: x[-1], xs)
        xs_ = tree_map(lambda x: x[:-1], xs)
    else:
        x0 = tree_map(lambda x: x[0], xs)
        xs_ = tree_map(lambda x: x[1:], xs)

    carry_shape_at_t1 = None

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
            trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), \
                    enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition
                # function `seeded_fn` here. To put the time dimension in the
                # correct position, we need to promote shapes so that `fn` and
                # `value` at each site have the same batch dims (e.g. if
                # `fn.batch_shape = (2, 3)` and value's batch_shape is (3,),
                # then we promote the shape of value so that its batch shape
                # is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the shape of new_carry in a nonlocal variable
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

    with markov():
        wrapped_carry = (0, rng_key, init)
        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
        if length == 1:
            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
            return wrapped_carry, (PytreeTrace({}), ys)
        wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_,
                                                     length - 1, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic sites in the trace;
        # we don't need to adjust `dim_to_name` for deterministic sites
        if site['type'] not in ('sample',):
            continue
        # add the `time` dimension; the name will be
        # '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name
        # XXX: site['infer']['dim_to_name'] is not enough to determine the
        # leftmost dimension because we don't record 1-size dimensions there
        time_dim = -min(len(site['fn'].batch_shape),
                        jnp.ndim(site['value']) - site['fn'].event_dim)
        site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
    # we also need to reshape `carry` to match sequential behavior
    if length % 2 == 0:
        t, rng_key, carry = wrapped_carry
        flatten_carry, treedef = tree_flatten(carry)
        flatten_carry = [
            jnp.reshape(x, t1_shape)
            for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)
        ]
        carry = tree_unflatten(treedef, flatten_carry)
        wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def iterated_smoother_routine(
        initial_state: MVNormalParameters,
        observations: jnp.ndarray,
        transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
        transition_covariance: jnp.ndarray,
        observation_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
        observation_covariance: jnp.ndarray,
        initial_linearization_points: jnp.ndarray = None,
        n_iter: int = 100,
        propagate_first: bool = True):
    """
    Computes the Gauss-Newton iterated extended Kalman smoother.

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        array of n observations of dimension K
    transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t+1}`
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariances for each time step; if only one is passed,
        it is repeated n times
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariances for each time step; if only one is
        passed, it is repeated n times
    initial_linearization_points: (N, D) array, optional
        points at which to compute the Jacobians during the first pass
    n_iter: int
        number of times the filter-smoother routine is computed
    propagate_first: bool, optional
        Is the first step a transition or an update? i.e. False if the
        initial time step has an associated observation. Default is True.

    Returns
    -------
    iterated_smoothed_trajectories: MVNormalParameters
        The result of the smoothing routine
    """

    def body(linearization_points, _):
        if linearization_points is not None:
            linearization_points = linearization_points.mean if isinstance(
                linearization_points, MVNormalParameters) else linearization_points
        filtered_states = filter_routine(initial_state, observations,
                                         transition_function, transition_covariance,
                                         observation_function, observation_covariance,
                                         linearization_points, propagate_first)
        return smoother_routine(transition_function, transition_covariance,
                                filtered_states, linearization_points), None

    if initial_linearization_points is None:
        initial_linearization_points = body(None, None)[0]

    iterated_smoothed_trajectories, _ = lax.scan(body,
                                                 initial_linearization_points,
                                                 jnp.arange(n_iter))
    return iterated_smoothed_trajectories
def f(x):
    def body_fun(carry, x):
        effect_p.bind(effect='while1')
        return carry, x
    return lax.scan(body_fun, x, jnp.arange(5))
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    nuts_kernel = NUTS(potential_fn=dual_moon_pe)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(11), init_params=np.array([2., 0.]))
    vanilla_samples = mcmc.get_samples()

    adam = optim.Adam(0.001)
    rng_init, rng_train = random.split(random.PRNGKey(1), 2)
    guide = AutoIAFNormal(dual_moon_model, hidden_dims=[args.num_hidden],
                          skip_connections=True)
    svi = SVI(dual_moon_model, guide, elbo, adam)
    svi_state = svi.init(rng_init)

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state,
                                  np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(params)
    unpack_fn = guide.unpack_latent

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
                                                     dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform,
                                                   unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    # TODO: explore why neutra samples are not good
    # Issue: https://github.com/pyro-ppl/numpyro/issues/256
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(10), init_params=init_params)
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0),
                                                           (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(),
                n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3,
                    hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(),
                n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(),
                n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
def f(init):
    return lax.scan(body, init, np.arange(5.))
def f(xs):
    def _body(carry, x):
        debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)
        return carry + 1, x + 1
    return lax.scan(_body, 2, xs)
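# `debug_print` and `ordered` come from the enclosing test scope; the
# public equivalent is jax.debug.print. A self-contained sketch:
import jax
import jax.numpy as jnp

def g(xs):
    def _body(carry, x):
        jax.debug.print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=True)
        return carry + 1, x + 1
    return jax.lax.scan(_body, 2, xs)

# g(jnp.arange(3)) prints one line per scan step: "carry: 2, x: 0", etc.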
def rnn(params, inputs):
    init_state = np.zeros(n_hid)
    _, outputs = lax.scan(partial(step, params), init_state, inputs)
    return outputs
def f(xs):
    return lax.scan(scan_body, None, xs)
def f_jax(xs, ys):
    body_const = np.ones((2,), dtype=np.float32)  # Test constant capture

    def body(res0, inputs):
        x, y = inputs
        return res0 + x * y, body_const

    return lax.scan(body, 0., (xs, ys))
def f(carry, xs):
    return lax.scan(scan_body, carry, xs)
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # the number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # prefix the names to match the pattern of the sarkka_bilmes product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
            trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension in the correct position,
            # we need to promote shapes so that `fn` and `value` at each site have
            # the same batch dims (e.g. if `fn.batch_shape = (2, 3)` and value's
            # batch_shape is (3,), then we promote the shape of value so that its
            # batch shape is (1, 3)).
            # Here we promote the `fn` shape first; the `value` shape will be
            # promoted after being scanned. We don't promote the `value` shape
            # here because we need to store the carry shape at this step. If we
            # reshaped the `value` here, the output carry might get a wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the shape of new_carry
            if len(carry_shapes) < (history + 1):
                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
            )
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
        hide_fn=lambda site: not site["name"].startswith("_PREV_")
    ), enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 steps, where the last step is used for
        # rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(
                    wrapped_carry, tree_map(lambda z: z[i], x0)
                )
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
                    )
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to
                # interpret the last carry shape, so we don't record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]]
                    )
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length == unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(
                    body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
                )

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic sites in the trace;
        # we don't need to adjust `dim_to_name` for deterministic sites
        if site["type"] not in ("sample",):
            continue
        # add the `time` dimension; the name will be
        # '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values yet during `lax.scan`, so do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine the
        # leftmost dimension because we don't record 1-size dimensions there
        time_dim = -min(
            len(site["fn"].batch_shape), jnp.ndim(site["value"]) - site["fn"].event_dim
        )
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
    )
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def f(x):
    @jax.named_scope('scan_body')
    def body(carry, x):
        return carry * x, carry + x
    return lax.scan(body, x, jnp.arange(8.))[0]
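# jax.named_scope only annotates the traced operations (useful in profiler
# traces and error messages); the numerics are unchanged. A quick sketch:
# the first scanned value is 0., which zeroes out the multiplicative carry.
assert f(2.0) == 0.0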
def g(x):
    return lax.scan(
        lambda carry, inp: (carry + f(inp), 0.),
        np.full(x.shape[1:], 0.),  # Like x w/o leading dim
        x)[0]
def cumsum(arr):
    # The carry accumulates the running total; only the final sum is returned.
    out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
    return out
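# Despite the name, the function above returns only the final total. To get
# true cumulative sums, emit the carry as the per-step output as well -- a
# sketch (hypothetical helper, not from the original source):
def running_sums(arr):
    _, ys = lax.scan(lambda c, x: (c + x, c + x), 0, arr)
    return ys  # equivalent to jnp.cumsum(arr)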
def f():
    # xs=() with length=3 runs `err` three times with no scanned inputs.
    return lax.scan(err, (), (), 3)
def update_ptest_single_scan(carry, rng):
    return scan(update_ptest_single, carry, rng)
def execute_single_node(hidden_state, node_embedding):
    carry, _ = lax.scan(lstm, hidden_state, node_embedding)
    return carry
def update_pn_scan(carry, rng):
    return scan(update_pn, carry, rng)
def trace(state, fn, num_steps, unroll, **_):
    """Implementation of `trace` operator, without the calling convention."""
    # We need the shapes and dtypes of the outputs of `fn`.
    _, untraced_spec, traced_spec = jax.eval_shape(
        fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
    untraced_init = map_tree(lambda spec: jnp.zeros(spec.shape, spec.dtype),
                             untraced_spec)

    try:
        num_steps = int(num_steps)
        use_scan = True
    except TypeError:
        use_scan = False
        if flatten_tree(traced_spec):
            raise ValueError(
                'Cannot trace values when `num_steps` is not statically known. Pass '
                'False to `trace_mask` or return an empty structure (e.g. `()`) as '
                'the extra output.')
        if unroll:
            raise ValueError(
                'Cannot unroll when `num_steps` is not statically known.')

    if unroll:
        traced_lists = map_tree(lambda _: [], traced_spec)
        untraced = untraced_init
        for _ in range(num_steps):
            state, untraced, traced_element = fn(state)
            map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists,
                           traced_element)
        # Using asarray instead of stack to handle empty arrays correctly.
        traced = map_tree_up_to(traced_spec,
                                lambda l, s: jnp.asarray(l, dtype=s.dtype),
                                traced_lists, traced_spec)
    elif use_scan:
        def wrapper(state_untraced, _):
            state, _ = state_untraced
            state, untraced, traced = fn(state)
            return (state, untraced), traced

        (state, untraced), traced = lax.scan(
            wrapper,
            (state, untraced_init),
            xs=None,
            length=num_steps,
        )
    else:
        trace_arrays = map_tree(
            lambda spec: jnp.zeros((num_steps,) + spec.shape, spec.dtype),
            traced_spec)

        def wrapper(i, state_untraced_traced):
            state, _, trace_arrays = state_untraced_traced
            state, untraced, traced = fn(state)
            trace_arrays = map_tree(lambda a, e: a.at[i].set(e), trace_arrays,
                                    traced)
            return (state, untraced, trace_arrays)

        state, untraced, traced = lax.fori_loop(
            jnp.asarray(0, num_steps.dtype),
            num_steps,
            wrapper,
            (state, untraced_init, trace_arrays),
        )
    return state, untraced, traced
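# A sketch of calling `trace` (hypothetical usage; assumes the fun_mc-style
# map_tree helpers used above are in scope). `fn` must return the new state,
# an untraced extra, and the value to trace at each step:
def step(x):
    x = x + 1.0
    return x, (), x  # new state, nothing untraced, trace the new state

state, untraced, traced = trace(jnp.asarray(0.0), step, num_steps=5, unroll=False)
# state == 5.0 and traced == [1., 2., 3., 4., 5.]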
def main(args): print("Start vanilla HMC...") nuts_kernel = NUTS(dual_moon_model) mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples) mcmc.run(random.PRNGKey(0)) mcmc.print_summary() vanilla_samples = mcmc.get_samples()['x'].copy() adam = optim.Adam(0.01) # TODO: it is hard to find good hyperparameters such that IAF guide can learn this model. # We will use BNAF instead! guide = AutoIAFNormal(dual_moon_model, num_flows=2, hidden_dims=[args.num_hidden, args.num_hidden]) svi = SVI(dual_moon_model, guide, adam, AutoContinuousELBO()) svi_state = svi.init(random.PRNGKey(1)) print("Start training guide...") last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters)) params = svi.get_params(last_state) print("Finish training guide. Extract samples...") guide_samples = guide.sample_posterior( random.PRNGKey(0), params, sample_shape=(args.num_samples, ))['x'].copy() transform = guide.get_transform(params) _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), dual_moon_model) transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform) transformed_constrain_fn = lambda x: constrain_fn(transform(x) ) # noqa: E731 print("\nStart NeuTra HMC...") nuts_kernel = NUTS(potential_fn=transformed_potential_fn) mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples) init_params = np.zeros(guide.latent_size) mcmc.run(random.PRNGKey(3), init_params=init_params) mcmc.print_summary() zs = mcmc.get_samples() print("Transform samples into unwarped space...") samples = vmap(transformed_constrain_fn)(zs) print_summary(tree_map(lambda x: x[None, ...], samples)) samples = samples['x'].copy() # make plots # guide samples (for plotting) guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(4), (1000, )) guide_trans_samples = vmap(transformed_constrain_fn)( guide_base_samples)['x'] x1 = np.linspace(-3, 3, 100) x2 = np.linspace(-3, 3, 100) X1, X2 = np.meshgrid(x1, x2) P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1))) fig = plt.figure(figsize=(12, 16), constrained_layout=True) gs = GridSpec(3, 2, figure=fig) ax1 = fig.add_subplot(gs[0, 0]) ax2 = fig.add_subplot(gs[0, 1]) ax3 = fig.add_subplot(gs[1, 0]) ax4 = fig.add_subplot(gs[1, 1]) ax5 = fig.add_subplot(gs[2, 0]) ax6 = fig.add_subplot(gs[2, 1]) ax1.plot(np.log(losses[1000:])) ax1.set_title('Autoguide training log loss (after 1000 steps)') ax2.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2) ax2.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide') sns.scatterplot(guide_base_samples[:, 0], guide_base_samples[:, 1], ax=ax3, hue=guide_trans_samples[:, 0] < 0.) 
ax3.set( xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)') ax4.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], n_levels=30, ax=ax4) ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5) ax4.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler') sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples[:, 0] < 0., s=30, alpha=0.5, edgecolor="none") ax5.set(xlim=[-5, 5], ylim=[-5, 5], xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)') ax6.contourf(X1, X2, P, cmap='OrRd') sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6) ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2) ax6.set(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler') plt.savefig("neutra.pdf") plt.close()
def f1(xs):
    res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
                      np.zeros((256,), dtype=np.float32), xs)
    return res
def hmm_viterbi_log(params, obs_seq, length=None):
    '''
    Computes the most likely sequence of hidden states (the Viterbi path)
    of the Hidden Markov Model given the observations that were made at
    each time step, working in log space.

    It is based on
    https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py

    Parameters
    ----------
    params : HMM
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observed states
    length: int, optional
        Effective length of the sequence; steps beyond it are masked out

    Returns
    -------
    * array(seq_len)
        The most likely sequence of hidden states
    '''
    seq_len = len(obs_seq)
    if length is None:
        length = seq_len

    trans_dist, obs_dist, init_dist = (params.trans_dist, params.obs_dist,
                                       params.init_dist)

    trans_log_probs = log_softmax(trans_dist.logits)
    init_log_probs = log_softmax(init_dist.logits)

    n_states = obs_dist.batch_shape[0]

    first_log_prob = init_log_probs + obs_dist.log_prob(obs_seq[0])

    if seq_len == 1:
        return jnp.expand_dims(jnp.argmax(first_log_prob), axis=0)

    def viterbi_forward(prev_logp, t):
        obs_logp = obs_dist.log_prob(obs_seq[t])
        logp = jnp.where(
            t <= length,
            prev_logp[..., None] + trans_log_probs + obs_logp[..., None, :],
            -jnp.inf + jnp.zeros_like(trans_log_probs))
        max_logp_given_successor = jnp.where(t <= length,
                                             jnp.max(logp, axis=-2), prev_logp)
        most_likely_given_successor = jnp.where(t <= length,
                                                jnp.argmax(logp, axis=-2), -1)
        return max_logp_given_successor, most_likely_given_successor

    ts = jnp.arange(1, seq_len)
    final_log_prob, most_likely_sources = lax.scan(viterbi_forward,
                                                   first_log_prob, ts)

    most_likely_initial_given_successor = jnp.argmax(trans_log_probs +
                                                     first_log_prob, axis=-2)
    most_likely_sources = jnp.concatenate([
        jnp.expand_dims(most_likely_initial_given_successor, axis=0),
        most_likely_sources
    ], axis=0)

    def viterbi_backward(state, t):
        state = jnp.where(
            t <= length,
            jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(
                jnp.int64),
            state)
        most_likely = jnp.where(t <= length, state, -1)
        return state, most_likely

    final_state = jnp.argmax(final_log_prob)
    _, most_likely_path = lax.scan(viterbi_backward, final_state, ts,
                                   reverse=True)

    final_state = jnp.where(length == seq_len, final_state, -1)

    return jnp.append(most_likely_path, final_state)