Code example #1
File: checkify_test.py Project: 0x0is1/jax
        def f(xs):
            def scan_body(carry, _):
                # closes over xs
                return carry + 1, xs[carry]

            return lax.scan(scan_body, 1, xs)[1]
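
A self-contained sketch of the pattern above (assuming only public JAX APIs): the body closes over `xs` and indexes it with the running carry, and the out-of-bounds read on the last step is silently clamped under tracing, which is exactly what the checkify test probes for.

import jax.numpy as jnp
from jax import lax

def gather_with_carry(xs):
    def scan_body(carry, _):
        return carry + 1, xs[carry]  # closes over xs
    return lax.scan(scan_body, 1, xs)[1]

print(gather_with_carry(jnp.arange(4.)))  # [1. 2. 3. 3.]: index 4 is clamped to 3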
Code example #2
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    guide = AutoBNAFNormal(
        dual_moon_model,
        hidden_factors=[args.hidden_factor, args.hidden_factor])
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(2), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    neutra_model = neutra.reparam(dual_moon_model)
    nuts_kernel = NUTS(neutra_model)
    mcmc = MCMC(
        nuts_kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(random.PRNGKey(3))
    mcmc.print_summary()
    zs = mcmc.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[0, 2])
    ax6 = fig.add_subplot(gs[1, 2])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
Code example #3
def rnn(xs):
    xs = np.swapaxes(xs, 0, 1)
    _, ys = lax.scan(cell, carry_init(xs.shape[1]), xs)
    return np.swapaxes(ys, 0, 1)
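
The snippet assumes a step function `cell(carry, x_t) -> (new_carry, y_t)` and an initializer `carry_init(batch_size)` defined elsewhere; a minimal hypothetical pair that is shape-compatible with the scan above (with `np` as `jax.numpy` and a made-up hidden size):

import jax.numpy as np

n_hid = 8  # hypothetical hidden size; inputs must have n_hid features

def carry_init(batch_size):
    return np.zeros((batch_size, n_hid))

def cell(carry, x_t):
    new_carry = np.tanh(carry + x_t)  # toy update; a real cell applies weights
    return new_carry, new_carry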
Code example #4
File: control.py Project: warrendeng/cilqr_jax
def scan_fori_loop(lo, hi, loop, init):
    def scan_f(x, t):
        return loop(t, x), ()

    x, _ = lax.scan(scan_f, init, jnp.arange(lo, hi))
    return x
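
Usage sketch: replaying a fori_loop-style body through scan keeps the loop compatible with reverse-mode autodiff; for instance, summing i = 0..4 into the carry:

import jax.numpy as jnp

total = scan_fori_loop(0, 5, lambda t, x: x + t, jnp.float32(0.))
print(total)  # 10.0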
Code example #5
def func(x):
    return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
Code example #6
def f(x):
  def body(carry, x):
    effect_p.bind(effect='foo')
    return carry, x
  return lax.scan(body, x, jnp.arange(4))
Code example #7
def fori_loop(lower, upper, body_fun, init_val):
    f = lambda x, i: (body_fun(i, x), ())
    result, _ = lax.scan(f, init_val, np.arange(lower, upper))
    return result
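
The payoff of the scan formulation is differentiability: jax.grad flows through the scanned body, whereas a while_loop-based loop does not support reverse-mode AD. A quick check (doubling the carry three times; numpy is imported here for the helper's np.arange):

import jax
import numpy as np

g = jax.grad(lambda c: fori_loop(0, 3, lambda i, x: 2. * x, c))(1.)
print(g)  # 8.0, i.e. d/dc of (2**3) * c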
Code example #8
File: layers.py Project: hypostulate/jax-unirep
def mLSTM1900_batch(
    params: Dict[str, np.ndarray], batch: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    LSTM layer implemented according to UniRep,
    found here:
    https://github.com/churchlab/UniRep/blob/master/unirep.py#L43,
    for a batch of data.

    This function processes a single embedded sequence,
    passed in as a two dimensional array,
    with number of rows being number of sequence positions,
    and the number of columns being the embedding of each sequence letter.

    :param params: All weights and biases for a single
        mLSTM1900 RNN cell.
    :param batch: One sequence embedded in an (n, 10) matrix,
        where `n` is the number of sequence positions
    :returns: The final hidden state, the final cell state,
        and the hidden state at every sequence position.
    """
    h_t = np.zeros(params["wmh"].shape[0])
    c_t = np.zeros(params["wmh"].shape[0])

    def mLSTM1900_step(
        carry: Tuple[np.ndarray, np.ndarray],
        x_t: np.ndarray,
    ) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
        """
        Implementation of mLSTMCell from UniRep paper, with weight normalization.

        Exact source code reference:
        https://github.com/churchlab/UniRep/blob/master/unirep.py#L75

        Shapes of parameters:

        - wmx: 10, 1900
        - wmh: 1900, 1900
        - wx: 10, 7600
        - wh: 1900, 7600
        - gmx: 1900
        - gmh: 1900
        - gx: 7600
        - gh: 7600
        - b: 7600

        Shapes of inputs:

        - x_t: (1, 10)
        - carry:
            - h_t: (1, 1900)
            - c_t: (1, 1900)
        """
        h_t, c_t = carry

        # Perform weight normalization first
        # (Corresponds to Line 113).
        # In the original implementation, this is toggled by a boolean flag,
        # but here we are enabling it by default.
        wx = l2_normalize(params["wx"], axis=0) * params["gx"]
        wh = l2_normalize(params["wh"], axis=0) * params["gh"]
        wmx = l2_normalize(params["wmx"], axis=0) * params["gmx"]
        wmh = l2_normalize(params["wmh"], axis=0) * params["gmh"]

        # Shape annotation
        # (:, 10) @ (10, 1900) * (:, 1900) @ (1900, 1900) => (:, 1900)
        m = np.matmul(x_t, wmx) * np.matmul(h_t, wmh)

        # (:, 10) @ (10, 7600) + (:, 1900) @ (1900, 7600) + (7600, ) => (:, 7600)
        z = np.matmul(x_t, wx) + np.matmul(m, wh) + params["b"]

        # Splitting along axis 1, four-ways, gets us (:, 1900) as the shape
        # for each of i, f, o and u
        i, f, o, u = np.split(z, 4, axis=-1)  # input, forget, output, update

        # Elementwise transforms here.
        # Shapes are (:, 1900) for each of the four.
        i = sigmoid(i, version="exp")
        f = sigmoid(f, version="exp")
        o = sigmoid(o, version="exp")
        u = tanh(u)

        # (:, 1900) * (:, 1900) + (:, 1900) * (:, 1900) => (:, 1900)
        c_t = f * c_t + i * u

        # (:, 1900) * (:, 1900) => (:, 1900)
        h_t = o * tanh(c_t)

        # h, c each have shape (:, 1900)
        return (h_t, c_t), h_t

    (h_final, c_final), outputs = lax.scan(
        mLSTM1900_step, init=(h_t, c_t), xs=batch
    )
    return h_final, c_final, outputs
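
A hypothetical usage sketch with dummy weights matching the shapes listed in the step docstring (`l2_normalize`, `sigmoid`, and `tanh` are the jax-unirep helpers this module already uses):

shapes = {"wmx": (10, 1900), "wmh": (1900, 1900), "wx": (10, 7600),
          "wh": (1900, 7600), "gmx": (1900,), "gmh": (1900,),
          "gx": (7600,), "gh": (7600,), "b": (7600,)}
params = {k: np.full(s, 0.01) for k, s in shapes.items()}
h_final, c_final, outputs = mLSTM1900_batch(params, np.ones((25, 10)))
# h_final, c_final: (1900,); outputs: (25, 1900), one hidden state per position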
Code example #9
def scan_enum(f,
              init,
              xs,
              length,
              reverse,
              rng_key=None,
              substitute_stack=None):
    from numpyro.contrib.funsor import config_enumerate, enum, markov
    from numpyro.contrib.funsor import trace as packed_trace

    # XXX: This implementation only works for history size=1 but can be
    # extended to history size > 1 by running `f` `history_size` times
    # for initialization. However, `sequential_sum_product` does not
    # support history size > 1, so we skip supporting it here.
    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
    if reverse:
        x0 = tree_map(lambda x: x[-1], xs)
        xs_ = tree_map(lambda x: x[:-1], xs)
    else:
        x0 = tree_map(lambda x: x[0], xs)
        xs_ = tree_map(lambda x: x[1:], xs)

    carry_shape_at_t1 = None

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = not_jax_tracer(i) and i == 0
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(
            ), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition
                # function `seeded_fn` here. To put the time dimension in the
                # correct position, we promote shapes so that `fn` and `value`
                # at each site have the same batch dims (e.g. if
                # `fn.batch_shape == (2, 3)` and the value's batch shape is
                # (3,), we promote the value's batch shape to (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the shape of new_carry in the nonlocal variable below
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [
                jnp.shape(x) for x in tree_flatten(new_carry)[0]
            ]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

    with markov():
        wrapped_carry = (0, rng_key, init)
        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
        if length == 1:
            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
            return wrapped_carry, (PytreeTrace({}), ys)
        wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry,
                                                     xs_, length - 1, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site['type'] not in ('sample', ):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(len(site['fn'].batch_shape),
                        jnp.ndim(site['value']) - site['fn'].event_dim)
        site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var)

    # similar to the carry, we need to reshape due to shape alternation in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
    # we also need to reshape `carry` to match sequential behavior
    if length % 2 == 0:
        t, rng_key, carry = wrapped_carry
        flatten_carry, treedef = tree_flatten(carry)
        flatten_carry = [
            jnp.reshape(x, t1_shape)
            for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)
        ]
        carry = tree_unflatten(treedef, flatten_carry)
        wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
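
For orientation, this enumeration path is reached from user code through numpyro.contrib.control_flow.scan inside a model with enumerated discrete sites; a hedged sketch of such a model (all names and numbers below are illustrative):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def hmm(ys):
    trans = jnp.array([[0.9, 0.1], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def transition(z_prev, y):
        z = numpyro.sample("z", dist.Categorical(probs=trans[z_prev]),
                           infer={"enumerate": "parallel"})
        numpyro.sample("y", dist.Normal(locs[z], 1.0), obs=y)
        return z, None

    scan(transition, jnp.array(0), ys)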
Code example #10
def iterated_smoother_routine(
        initial_state: MVNormalParameters,
        observations: jnp.ndarray,
        transition_function: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
        transition_covariance: jnp.ndarray,
        observation_function: Callable[[jnp.ndarray, jnp.ndarray],
                                       jnp.ndarray],
        observation_covariance: jnp.ndarray,
        initial_linearization_points: jnp.ndarray = None,
        n_iter: int = 100,
        propagate_first: bool = True):
    """
    Computes the Gauss-Newton iterated extended Kalman smoother

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        array of n observations of dimension K
    transition_function: callable :math:`f(x_t,\epsilon_t)\mapsto x_{t+1}`
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance for each time step; if only one is passed, it is repeated n times
    observation_function: callable :math:`h(x_t,\epsilon_t)\mapsto y_t`
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariance for each time step; if only one is passed, it is repeated n times
    initial_linearization_points: (N, D) array, optional
        points at which to compute the Jacobians during the first pass.
    n_iter: int
        number of times the filter-smoother routine is computed
    propagate_first: bool, optional
        Is the first step a transition or an update? i.e. False if the initial time step has
        an associated observation. Default is True.
    Returns
    -------
    iterated_smoothed_trajectories: MVNormalParameters
        The result of the smoothing routine

    """
    def body(linearization_points, _):
        if linearization_points is not None:
            linearization_points = linearization_points.mean if isinstance(
                linearization_points,
                MVNormalParameters) else linearization_points
        filtered_states = filter_routine(initial_state, observations,
                                         transition_function,
                                         transition_covariance,
                                         observation_function,
                                         observation_covariance,
                                         linearization_points, propagate_first)
        return smoother_routine(transition_function, transition_covariance,
                                filtered_states, linearization_points), None

    if initial_linearization_points is None:
        initial_linearization_points = body(None, None)[0]

    iterated_smoothed_trajectories, _ = lax.scan(body,
                                                 initial_linearization_points,
                                                 jnp.arange(n_iter))
    return iterated_smoothed_trajectories
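
The body-plus-lax.scan structure above is the generic "run a fixed number of refinement passes" pattern; stripped to its essentials (with a hypothetical update function), it is just:

import jax.numpy as jnp
from jax import lax

def iterate(step, x0, n_iter):
    def body(x, _):
        return step(x), None
    x_final, _ = lax.scan(body, x0, jnp.arange(n_iter))
    return x_final

# e.g. Newton's iteration for sqrt(2):
print(iterate(lambda x: 0.5 * (x + 2. / x), 2.0, 20))  # ~1.4142135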
Code example #11
        def f(x):
            def body_fun(carry, x):
                effect_p.bind(effect='while1')
                return carry, x

            return lax.scan(body_fun, x, jnp.arange(5))
Code example #12
def main(args):
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    nuts_kernel = NUTS(potential_fn=dual_moon_pe)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(11), init_params=np.array([2., 0.]))
    vanilla_samples = mcmc.get_samples()

    adam = optim.Adam(0.001)
    rng_init, rng_train = random.split(random.PRNGKey(1), 2)
    guide = AutoIAFNormal(dual_moon_model, hidden_dims=[args.num_hidden], skip_connections=True)
    svi = SVI(dual_moon_model, guide, elbo, adam)
    svi_state = svi.init(rng_init)

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(params)
    unpack_fn = guide.unpack_latent

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    init_params = np.zeros(guide.latent_size)
    print("\nStart NeuTra HMC...")
    # TODO: explore why neutra samples are not good
    # Issue: https://github.com/pyro-ppl/numpyro/issues/256
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(10), init_params=init_params)
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots

    # IAF guide samples (for plotting)
    iaf_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(0), (1000,))
    iaf_trans_samples = vmap(transformed_constrain_fn)(iaf_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0].copy(), guide_samples['x'][:, 1].copy(), n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(iaf_base_samples[:, 0], iaf_base_samples[:, 1], ax=ax3, hue=iaf_trans_samples[:, 0] < 0.)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0].copy(), vanilla_samples[:, 1].copy(), n_levels=30, ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0], zs[:, 1], ax=ax5, hue=samples['x'][:, 0] < 0.,
                    s=30, alpha=0.5, edgecolor="none")
    ax5.set(xlim=[-5, 5], ylim=[-5, 5],
            xlabel='x0', ylabel='x1', title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0].copy(), samples['x'][:, 1].copy(), n_levels=30, ax=ax6)
    ax6.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3], ylim=[-3, 3],
            xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
Code example #13
def f(init):
    return lax.scan(body, init, np.arange(5.))
Code example #14
def f(xs):
  def _body(carry, x):
    debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)
    return carry + 1, x + 1
  return lax.scan(_body, 2, xs)
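
In public JAX this corresponds to jax.debug.print, which stages the print through the scan; a runnable version of the same snippet (with `ordered` fixed to True as one possible setting):

import jax
import jax.numpy as jnp
from jax import lax

def f(xs):
    def _body(carry, x):
        jax.debug.print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=True)
        return carry + 1, x + 1
    return lax.scan(_body, 2, xs)

print(f(jnp.arange(3.)))  # one line printed per scan step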
Code example #15
def rnn(params, inputs):
    init_state = np.zeros(n_hid)
    _, outputs = lax.scan(partial(step, params), init_state, inputs)
    return outputs
Code example #16
def f(xs):
    return lax.scan(scan_body, None, xs)
Code example #17
File: control_flow_ops_test.py Project: stilling/jax
def f_jax(xs, ys):
  body_const = np.ones((2, ), dtype=np.float32)  # Test constant capture
  def body(res0, inputs):
    x, y = inputs
    return res0 + x * y, body_const
  return lax.scan(body, 0., (xs, ys))
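
Usage sketch: the carry accumulates the running dot product, while the stacked per-step outputs are copies of the captured constant:

import numpy as np
from jax import lax

tot, consts = f_jax(np.arange(3.), np.arange(3.))
print(tot)           # 5.0 == 0*0 + 1*1 + 2*2
print(consts.shape)  # (3, 2): body_const stacked once per step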
Code example #18
def f(carry, xs):
    return lax.scan(scan_body, carry, xs)
Code example #19
File: scan.py Project: pyro-ppl/numpyro
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = not_jax_tracer(i) and i in range(unroll_steps)
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        # we need to tell the unconstrained messenger in the potential energy
        # computation that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # handle the name to match the pattern of the sarkka_bilmes product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition
            # function `seeded_fn` here. To put the time dimension in the
            # correct position, we promote shapes so that `fn` and `value` at
            # each site have the same batch dims (e.g. if
            # `fn.batch_shape == (2, 3)` and the value's batch shape is (3,),
            # we promote the value's batch shape to (1, 3)).
            # We promote the `fn` shape first; the `value` shape will be
            # promoted after the scan. We don't promote the `value` shape here
            # because we need to store the carry shape at this step; reshaping
            # the `value` now could give the output carry the wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the shape of new_carry in the enclosing carry_shapes list
            if len(carry_shapes) < (history + 1):
                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
            )
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
        hide_fn=lambda site: not site["name"].startswith("_PREV_")
    ), enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 iterations, where the last one rolls out the remainder with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(
                    wrapped_carry, tree_map(lambda z: z[i], x0)
                )
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
                    )
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret the last carry
                # shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]]
                    )
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(
                    body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
                )

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site["type"] not in ("sample",):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted the value shapes during `lax.scan` yet, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(
            len(site["fn"].batch_shape), jnp.ndim(site["value"]) - site["fn"].event_dim
        )
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to the carry, we need to reshape due to shape alternation in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
    )
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
Code example #20
File: name_stack_test.py Project: xueeinstein/jax
        def f(x):
            @jax.named_scope('scan_body')
            def body(carry, x):
                return carry * x, carry + x

            return lax.scan(body, x, jnp.arange(8.))[0]
Code example #21
def g(x):
    return lax.scan(
        lambda carry, inp: (carry + f(inp), 0.),
        np.full(x.shape[1:], 0.),  # Like x w/o leading dim
        x)[0]
Code example #22
File: masking_test.py Project: qiuminxu/jax
def cumsum(arr):
    out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
    return out
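
Despite the name, this returns only the final carry (the total), not the running sequence, since the per-step output slot is left empty:

import jax.numpy as jnp

print(cumsum(jnp.arange(4)))  # 6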
Code example #23
File: errors_test.py Project: uafpdivad/jax
def f():
    return lax.scan(err, (), (), 3)
Code example #24
File: mv_copula_density_t.py Project: edfong/MP
def update_ptest_single_scan(carry, rng):
    return scan(update_ptest_single, carry, rng)
Code example #25
def execute_single_node(hidden_state, node_embedding):
    carry, _ = lax.scan(lstm, hidden_state, node_embedding)
    return carry
Code example #26
File: mv_copula_density_t.py Project: edfong/MP
def update_pn_scan(carry, rng):
    return scan(update_pn, carry, rng)
Code example #27
def trace(state, fn, num_steps, unroll, **_):
    """Implementation of `trace` operator, without the calling convention."""
    # We need the shapes and dtypes of the outputs of `fn`.
    _, untraced_spec, traced_spec = jax.eval_shape(
        fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
    untraced_init = map_tree(lambda spec: jnp.zeros(spec.shape, spec.dtype),
                             untraced_spec)

    try:
        num_steps = int(num_steps)
        use_scan = True
    except TypeError:
        use_scan = False
        if flatten_tree(traced_spec):
            raise ValueError(
                'Cannot trace values when `num_steps` is not statically known. Pass '
                'False to `trace_mask` or return an empty structure (e.g. `()`) as '
                'the extra output.')
        if unroll:
            raise ValueError(
                'Cannot unroll when `num_steps` is not statically known.')

    if unroll:
        traced_lists = map_tree(lambda _: [], traced_spec)
        untraced = untraced_init
        for _ in range(num_steps):
            state, untraced, traced_element = fn(state)
            map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists,
                           traced_element)
        # Using asarray instead of stack to handle empty arrays correctly.
        traced = map_tree_up_to(traced_spec,
                                lambda l, s: jnp.asarray(l, dtype=s.dtype),
                                traced_lists, traced_spec)
    elif use_scan:

        def wrapper(state_untraced, _):
            state, _ = state_untraced
            state, untraced, traced = fn(state)
            return (state, untraced), traced

        (state, untraced), traced = lax.scan(
            wrapper,
            (state, untraced_init),
            xs=None,
            length=num_steps,
        )
    else:
        trace_arrays = map_tree(
            lambda spec: jnp.zeros((num_steps, ) + spec.shape, spec.dtype),
            traced_spec)

        def wrapper(i, state_untraced_traced):
            state, _, trace_arrays = state_untraced_traced
            state, untraced, traced = fn(state)
            trace_arrays = map_tree(lambda a, e: a.at[i].set(e), trace_arrays,
                                    traced)
            return (state, untraced, trace_arrays)

        state, untraced, traced = lax.fori_loop(
            jnp.asarray(0, num_steps.dtype),
            num_steps,
            wrapper,
            (state, untraced_init, trace_arrays),
        )
    return state, untraced, traced
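
A usage sketch with a hypothetical fn: each call must return (new_state, untraced, traced); with a static integer num_steps and unroll=False, the lax.scan branch runs and stacks the traced outputs:

import jax.numpy as jnp

def step(state):
    new_state = state + 1
    return new_state, (), new_state ** 2  # (state, untraced, traced)

final, _, squares = trace(jnp.asarray(0), step, num_steps=5, unroll=False)
print(final)    # 5
print(squares)  # [ 1  4  9 16 25]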
Code example #28
File: neutra.py Project: ziatdinovmax/numpyro
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    mcmc.run(random.PRNGKey(0))
    mcmc.print_summary()
    vanilla_samples = mcmc.get_samples()['x'].copy()

    adam = optim.Adam(0.01)
    # TODO: it is hard to find good hyperparameters such that IAF guide can learn this model.
    # We will use BNAF instead!
    guide = AutoIAFNormal(dual_moon_model,
                          num_flows=2,
                          hidden_dims=[args.num_hidden, args.num_hidden])
    svi = SVI(dual_moon_model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(1))

    print("Start training guide...")
    last_state, losses = lax.scan(lambda state, i: svi.update(state),
                                  svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(
        random.PRNGKey(0), params,
        sample_shape=(args.num_samples, ))['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    guide_base_samples = dist.Normal(np.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        xlim=[-3, 3],
        ylim=[-3, 3],
        xlabel='x0',
        ylabel='x1',
        title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(xlim=[-3, 3],
            ylim=[-3, 3],
            xlabel='x0',
            ylabel='x1',
            title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
Code example #29
File: jax2tf_test.py Project: frederikwilde/jax
def f1(xs):
  res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
                    np.zeros((256,), dtype=np.float32), xs)
  return res
Code example #30
File: hmm_lib_log.py Project: qiuhuachuan/pyprobml
def hmm_viterbi_log(params, obs_seq, length=None):
    '''
    Computes the most probable sequence of hidden states (the Viterbi path)
    given the history of observations, i.e.
    argmax_z P(z[0], ..., z[num_steps - 1] | x[0], ..., x[num_steps - 1]).
    The computation is carried out in log space for numerical stability.
    It is based on https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py

    Parameters
    ----------
    params : HMM
        Hidden Markov Model
    obs_seq: array(seq_len)
        History of observed states
    length: int, optional
        Number of valid time steps in `obs_seq`; defaults to seq_len
    Returns
    -------
    * array(seq_len)
        The most likely sequence of hidden states
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_dist, obs_dist, init_dist = params.trans_dist, params.obs_dist, params.init_dist

    trans_log_probs = log_softmax(trans_dist.logits)
    init_log_probs = log_softmax(init_dist.logits)

    n_states = obs_dist.batch_shape[0]

    first_log_prob = init_log_probs + obs_dist.log_prob(obs_seq[0])

    if seq_len == 1:
        return jnp.expand_dims(jnp.argmax(first_log_prob), axis=0)

    def viterbi_forward(prev_logp, t):
        obs_logp = obs_dist.log_prob(obs_seq[t])

        logp = jnp.where(
            t <= length,
            prev_logp[..., None] + trans_log_probs + obs_logp[..., None, :],
            -jnp.inf + jnp.zeros_like(trans_log_probs))

        max_logp_given_successor = jnp.where(
            t <= length, jnp.max(logp, axis=-2), prev_logp)
        most_likely_given_successor = jnp.where(
            t <= length, jnp.argmax(logp, axis=-2), -1)

        return max_logp_given_successor, most_likely_given_successor

    ts = jnp.arange(1, seq_len)
    final_log_prob, most_likely_sources = lax.scan(viterbi_forward,
                                                   first_log_prob, ts)

    most_likely_initial_given_successor = jnp.argmax(
        trans_log_probs + first_log_prob, axis=-2)

    most_likely_sources = jnp.concatenate([
        jnp.expand_dims(most_likely_initial_given_successor, axis=0),
        most_likely_sources
    ], axis=0)

    def viterbi_backward(state, t):
        state = jnp.where(
            t <= length,
            jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(
                jnp.int64), state)
        most_likely = jnp.where(t <= length, state, -1)
        return state, most_likely

    final_state = jnp.argmax(final_log_prob)
    _, most_likely_path = lax.scan(viterbi_backward,
                                   final_state,
                                   ts,
                                   reverse=True)

    final_state = jnp.where(length == seq_len, final_state, -1)

    return jnp.append(most_likely_path, final_state)
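
A hypothetical usage sketch, assuming the source file's HMM container and distrax-style distributions (the field names follow the attributes accessed above; all numbers are illustrative):

import distrax
import jax.numpy as jnp

params = HMM(
    init_dist=distrax.Categorical(logits=jnp.zeros(2)),
    trans_dist=distrax.Categorical(
        logits=jnp.log(jnp.array([[0.9, 0.1], [0.1, 0.9]]))),
    obs_dist=distrax.Normal(loc=jnp.array([-1., 1.]), scale=jnp.ones(2)),
)
path = hmm_viterbi_log(params, jnp.array([-1.2, -0.9, 1.1]))
print(path)  # length-3 array of the most likely hidden states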