コード例 #1
0
    def testing():
        """Check that nested ``markov`` contexts alternate the allocated dims.

        Four ``markov`` contexts are opened, each inside the previous one. At
        each level a funsor ``Tensor`` with a fresh ``Bint[2]`` input (named
        after the level) is converted with ``to_data``. Its shape is printed
        and asserted: levels 1 and 3 expect ``(2,)``, levels 2 and 4 expect
        ``(2, 1)``.
        """
        expected_shapes = {1: (2, ), 2: (2, 1), 3: (2, ), 4: (2, 1)}

        def enter(level):
            # Each recursive call opens one more markov context inside the
            # previous one, mirroring a lexically nested chain of `with`s.
            with markov():
                data = to_data(
                    Tensor(jnp.ones(2), OrderedDict([(str(level), Bint[2])]),
                           "real"))
                print(level, data.shape)  # shapes should alternate
                assert data.shape == expected_shapes[level]
                if level < len(expected_shapes):
                    enter(level + 1)

        enter(1)
コード例 #2
0
    def model(data_x, data_w):
        """Run two independent enumerated chains, one per observation sequence.

        Latent sites are named ``x_{t}`` and ``w{t}`` (the latter without an
        underscore). Observed sites are named ``y_x_{t}`` and ``y_w_{t}``.
        Both chains start from state ``0``.
        """
        state_x = state_w = 0
        for t, obs in markov(enumerate(data_x)):
            cat_x = dist.Categorical(probs_x[state_x])
            state_x = numpyro.sample(f"x_{t}", cat_x)
            numpyro.sample(f"y_x_{t}", dist.Normal(locs_x[state_x], 1), obs=obs)

        for t, obs in markov(enumerate(data_w)):
            cat_w = dist.Categorical(probs_w[state_w])
            state_w = numpyro.sample(f"w{t}", cat_w)
            numpyro.sample(f"y_w_{t}", dist.Normal(locs_w[state_w], 1), obs=obs)
コード例 #3
0
 def model(data):
     """Enumerated HMM: ``init_probs`` on step 0, then rows of ``transition_probs``.

     Returns the latent value sampled at the last step, or ``None`` if
     ``data`` is empty.
     """
     state = None
     for t, obs in markov(enumerate(data)):
         if state is None:
             probs = init_probs
         else:
             probs = transition_probs[state]
         state = numpyro.sample(f"x_{t}", dist.Categorical(probs))
         numpyro.sample(f"y_{t}", dist.Normal(locs[state], 1), obs=obs)
     return state
コード例 #4
0
 def model(data):
     """Enumerated HMM whose per-step sites all sit inside the plate ``D``."""
     state = None
     plate_d = numpyro.plate("D", D, dim=-1)
     for t, obs in markov(enumerate(data)):
         with plate_d:
             if state is None:
                 probs = init_probs
             else:
                 probs = transition_probs[state]
             state = numpyro.sample(f"x_{t}", dist.Categorical(probs))
             numpyro.sample(f"y_{t}", dist.Normal(locs[state], 1), obs=obs)
コード例 #5
0
 def testing():
     """On every fourth step of a 12-step markov loop, check a tensor named ``'a'``.

     The converted array is asserted to have shape ``(2,)``. Its shape and the
     inputs of the round-tripped funsor are printed.
     """
     for step in markov(range(12)):
         if step % 4:
             continue
         value = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
         fvalue = to_funsor(value, reals())
         assert value.shape == (2,)
         print('a', value.shape)
         print('a', fvalue.inputs)
コード例 #6
0
 def model(data):
     """Enumerated chain ``x`` plus a per-step Bernoulli ``w`` batched over plate ``D``.

     Each observation's location is looked up jointly by the chain state and
     the plate-batched Bernoulli value via ``Vindex``.
     """
     state = 0
     plate_d = numpyro.plate("D", D, dim=-1)
     for t, obs in markov(enumerate(data)):
         trans = transition_probs[state]
         state = numpyro.sample(f"x_{t}", dist.Categorical(trans))
         with plate_d:
             coin = numpyro.sample(f"w_{t}", dist.Bernoulli(0.6))
             numpyro.sample(f"y_{t}", dist.Normal(Vindex(locs)[state, coin], 1), obs=obs)
コード例 #7
0
 def testing():
     """On every fourth step of a 12-step markov loop, check a tensor named ``"a"``.

     The converted array is asserted to have shape ``(2,)``. Its shape and the
     inputs of the round-tripped funsor are printed.
     """
     for step in markov(range(12)):
         if step % 4:
             continue
         value = to_data(
             Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real"))
         fvalue = to_funsor(value, Real)
         assert value.shape == (2, )
         print("a", value.shape)
         print("a", fvalue.inputs)
コード例 #8
0
 def model(data1, data2):
     """Enumerated chain ``x`` emitting into two plates that both use ``dim=-1``.

     ``y1_{t}`` is observed under plate ``D1`` and ``y2_{t}`` under plate
     ``D2``; both share the location of the current latent state.
     """
     state = None
     plate1 = numpyro.plate("D1", D1, dim=-1)
     plate2 = numpyro.plate("D2", D2, dim=-1)
     for t, (obs1, obs2) in markov(enumerate(zip(data1, data2))):
         probs = transition_probs[state] if state is not None else init_probs
         state = numpyro.sample(f"x_{t}", dist.Categorical(probs))
         with plate1:
             numpyro.sample(f"y1_{t}", dist.Normal(locs[state], 1), obs=obs1)
         with plate2:
             numpyro.sample(f"y2_{t}", dist.Normal(locs[state], 1), obs=obs2)
コード例 #9
0
    def model(data):
        """HMM with learned transition rows and per-state Gaussian emissions.

        The transition matrix rows and the emission loc/scale are sampled
        under a ``states`` plate. The initial state distribution is a separate
        Dirichlet site ``initialize``.
        """
        with numpyro.plate("states", dim):
            transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
            emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
            emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

        probs = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
        for t, obs in markov(enumerate(data)):
            state = numpyro.sample(f"x_{t}", dist.Categorical(probs))
            numpyro.sample(
                f"y_{t}",
                dist.Normal(emission_loc[state], emission_scale[state]),
                obs=obs,
            )
            # The next step's categorical uses the transition row of this state.
            probs = transition[state]
コード例 #10
0
 def hmm(data, hidden_dim=10):
     """Sticky-transition HMM with unit-scale Gaussian emissions.

     ``data`` is modified in place: ``data[t]`` is replaced by the value
     returned for site ``obs_{t}``. Returns the list of sampled states (which
     starts with the initial state ``0``) together with ``data``.
     """
     transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
     means = jnp.arange(float(hidden_dim))
     states = [0]
     for t in markov(range(len(data))):
         prev_state = states[-1]
         new_state = numpyro.sample(f"states_{t}", dist.Categorical(transition[prev_state]))
         states.append(new_state)
         emission = dist.Normal(means[new_state], 1.0)
         data[t] = numpyro.sample(f"obs_{t}", emission, obs=data[t])
     return states, data
コード例 #11
0
 def model():
     """Bernoulli chain conditioned on the two previous states and a global ``z``.

     Every observation ``y_{t}`` is fixed to ``0``. Returns the last two chain
     values ``(x_prev, x_curr)``.
     """
     p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
     q = numpyro.param("q", 0.25 * jnp.ones(2))
     z = numpyro.sample("z", dist.Bernoulli(0.5))
     x_prev = x_curr = 0
     for t in markov(range(T), history=history):
         x_next = numpyro.sample(f"x_{t}", dist.Bernoulli(p[x_prev, x_curr, z]))
         # shift the two-step window forward by one
         x_prev, x_curr = x_curr, x_next
         numpyro.sample(f"y_{t}", dist.Bernoulli(q[x_curr]), obs=0)
     return x_prev, x_curr
コード例 #12
0
 def testing():
     """Compare a fresh per-step input name with one name shared across steps.

     The per-step tensor is expected to alternate between shapes ``(2,)`` and
     ``(2, 1, 1)``. The tensor that always uses input ``'a'`` is expected to
     keep shape ``(2, 1)``.
     """
     for step in markov(range(5)):
         fresh = to_data(Tensor(jnp.ones(2), OrderedDict([(str(step), bint(2))]), 'real'))
         shared = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
         fresh_f = to_funsor(fresh, reals())
         shared_f = to_funsor(shared, reals())
         print(step, fresh.shape)  # shapes should alternate
         expected_fresh = (2,) if step % 2 == 0 else (2, 1, 1)
         assert fresh.shape == expected_fresh
         assert shared.shape == (2, 1)
         print(step, fresh_f.inputs)
         print('a', shared.shape)  # shapes should stay the same
         print('a', shared_f.inputs)
コード例 #13
0
 def testing():
     """Compare a fresh per-step input name with one name shared across steps.

     The per-step tensor is expected to alternate between shapes ``(2,)`` and
     ``(2, 1, 1)``. The tensor that always uses input ``"a"`` is expected to
     keep shape ``(2, 1)``.
     """
     for step in markov(range(5)):
         fresh = to_data(
             Tensor(jnp.ones(2), OrderedDict([(str(step), Bint[2])]), "real"))
         shared = to_data(
             Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real"))
         fresh_f = to_funsor(fresh, Real)
         shared_f = to_funsor(shared, Real)
         print(step, fresh.shape)  # shapes should alternate
         expected_fresh = (2, ) if step % 2 == 0 else (2, 1, 1)
         assert fresh.shape == expected_fresh
         assert shared.shape == (2, 1)
         print(step, fresh_f.inputs)
         print("a", shared.shape)  # shapes should stay the same
         print("a", shared_f.inputs)
コード例 #14
0
    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of the scanned function ``f`` (history-1 enumerated scan).

        Relies on ``f``, ``substitute_stack``, ``length`` and
        ``carry_shape_at_t1`` from the enclosing scope, which is not shown here.

        When ``i`` is a concrete (non-tracer) value equal to 0, ``f`` runs once
        under an ``_init`` scope prefix and no trace is collected. Otherwise
        ``f`` runs with enumeration under ``block``/``packed_trace``/
        ``promote_shapes``/``enum``/``markov``, and its packed trace is returned.

        Args:
            wrapped_carry: tuple ``(i, rng_key, carry)``.
            x: per-step input passed to ``f``.
            prefix: not used in this body.

        Returns:
            ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
        """
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        # split a fresh subkey for this step only when a key was provided
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        # wrap with condition/substitute handlers that pick out step `i`
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(
            ), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the leaf shapes of new_carry in the enclosing function's
            # `carry_shape_at_t1` variable (declared nonlocal)
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [
                jnp.shape(x) for x in tree_flatten(new_carry)[0]
            ]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
コード例 #15
0
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    """Scan ``f`` over ``xs`` with enumerated discrete latent sites.

    The first ``unroll_steps = min(2 * history - 1, length)`` steps are run
    eagerly inside ``markov(..., history=history)``. Their sites are renamed
    with a ``"_PREV_"`` prefix (repeated ``unroll_steps - i`` times) and are
    the only sites let through the surrounding ``handlers.block``. The
    remaining ``length - unroll_steps`` steps are rolled with
    :func:`jax.lax.scan`. In the resulting packed trace, every sample site
    gets an extra ``_time_<first sample site>`` entry in its ``dim_to_name``.

    :param f: transition function ``f(carry, x) -> (new_carry, y)``.
    :param init: initial carry.
    :param xs: pytree of per-step inputs indexed along the leading axis.
    :param int length: number of steps.
    :param bool reverse: if true, consume ``xs`` from the end.
    :param rng_key: optional PRNG key, split on every step when not ``None``.
    :param substitute_stack: iterable of ``(subs_type, subs_map)`` pairs where
        ``subs_type`` is ``"condition"`` or ``"substitute"``. ``None`` is
        treated as an empty stack.
    :param int history: markov history size (clipped to ``length``).
    :param first_available_dim: forwarded to ``enum``.
    :return: ``((t, rng_key, carry), (pytree_trace, ys))`` where ``ys`` stacks
        the per-step outputs along a new leading axis.
    """
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # The `None` default would otherwise crash when iterated in `body_fn`.
    if substitute_stack is None:
        substitute_stack = []

    # number of steps to unroll eagerly before rolling the rest with `lax.scan`
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    # Leaf shapes of the carry recorded at successive steps, with at most
    # `history + 1` entries. Each entry is a list of shapes. Carry shapes
    # alternate under markov, so the final carry is reshaped to one of these
    # entries at the end.
    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of ``f``; ``wrapped_carry`` is ``(i, rng_key, carry)``.

        Unrolled steps (``i`` concrete and ``< unroll_steps``) rename their
        sites and collect no trace. Scanned steps collect a packed trace and
        reshape the new carry to the incoming carry's shape, so ``lax.scan``
        sees a fixed carry structure.
        """
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i)
                        and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            # Rename sites to match the naming pattern of the sarkka_bilmes
            # product: "_PREV_" is repeated once per step between `i` and the
            # first scanned step.
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i),
                                divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # record the leaf shapes of new_carry in the enclosing `carry_shapes`
            if len(carry_shapes) < (history + 1):
                carry_shapes.append(
                    [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
            hide_fn=lambda site: not site["name"].startswith("_PREV_")), enum(
                first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(wrapped_carry,
                                                 tree_map(lambda z: z[i], x0))
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0],
                        y0)
                y0s.append(y0)
                # Shapes of the first `history - 1` steps are not useful to interpret
                # the last carry shape, so we don't need to record them here.
                # Store a list, not a one-shot generator, so that every entry of
                # `carry_shapes` has the same type and can be read more than once.
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]])
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace,
                                ys) = lax.scan(body_fn, wrapped_carry, xs_,
                                               length - unroll_steps, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site["type"] not in ("sample", ):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values yet during `lax.scan`, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"],
                                                      site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(len(site["fn"].batch_shape),
                        jnp.ndim(site["value"]) - site["fn"].event_dim)
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape)
        for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
コード例 #16
0
 def model(data):
     """Two independent enumerated chains ``x`` and ``w`` that jointly set each observation.

     Both chains start from state ``0``. The location of ``y_{t}`` is
     ``locs[w_t, x_t]``.
     """
     state_x = state_w = 0
     for t, obs in markov(enumerate(data)):
         state_x = numpyro.sample(f"x_{t}", dist.Categorical(probs_x[state_x]))
         state_w = numpyro.sample(f"w_{t}", dist.Categorical(probs_w[state_w]))
         numpyro.sample(f"y_{t}", dist.Normal(locs[state_w, state_x], 1), obs=obs)
コード例 #17
0
ファイル: scan.py プロジェクト: mhashemi0873/numpyro
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """Scan ``f`` over ``xs`` with enumerated latent sites (history size 1 only).

    The first step (``i == 0``) is run directly, outside ``lax.scan``, inside
    a ``markov`` context. Its sites get the ``_init`` scope prefix and no trace
    is collected. The remaining ``length - 1`` steps are rolled with
    :func:`jax.lax.scan`. The packed trace of the scanned transition is
    returned, with an extra ``_time_<first site name>`` entry added to each
    site's ``dim_to_name``.

    :param f: transition function ``f(carry, x) -> (new_carry, y)``.
    :param init: initial carry.
    :param xs: pytree of per-step inputs, indexed along the leading axis.
    :param int length: number of steps.
    :param bool reverse: if true, the first (eager) step uses the last element
        of ``xs`` and the scan runs in reverse.
    :param rng_key: optional PRNG key, split on every step when not ``None``.
    :param substitute_stack: iterable of ``(subs_type, subs_map)`` pairs where
        ``subs_type`` is ``'condition'`` or ``'substitute'``.
        NOTE(review): it is iterated directly, so the ``None`` default would
        raise ``TypeError``; callers presumably always pass a list.
    :return: ``((t, rng_key, carry), (pytree_trace, ys))``.
    """
    from numpyro.contrib.funsor import config_enumerate, enum, markov
    from numpyro.contrib.funsor import trace as packed_trace

    # XXX: This implementation only works for history size=1 but can be
    # extended to history size > 1 by running `f` `history_size` times
    # for initialization. However, `sequential_sum_product` does not
    # support history size > 1, so we skip supporting it here.
    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
    # Split off the single eagerly-run step `x0` from the scanned inputs `xs_`.
    if reverse:
        x0 = tree_map(lambda x: x[-1], xs)
        xs_ = tree_map(lambda x: x[:-1], xs)
    else:
        x0 = tree_map(lambda x: x[0], xs)
        xs_ = tree_map(lambda x: x[1:], xs)

    # Leaf shapes of the carry produced by the scanned body; set by `body_fn`.
    carry_shape_at_t1 = None

    def body_fn(wrapped_carry, x, prefix=None):
        # One step of `f`; `wrapped_carry` is `(i, rng_key, carry)`. At a concrete
        # step 0 it runs under an `_init` scope without a trace; otherwise it
        # collects an enumerated packed trace. `prefix` is not used here.
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store the leaf shapes of new_carry in the enclosing `carry_shape_at_t1`
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

    with markov():
        wrapped_carry = (0, rng_key, init)
        # eager first step (i == 0): no trace is collected for it
        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
        if length == 1:
            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
            return wrapped_carry, (PytreeTrace({}), ys)
        wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name
        # NOTE(review): `min` raises ValueError if a site has an empty
        # `dim_to_name`; also 1-size dims are not recorded there, so the
        # leftmost recorded dim may not be the true leftmost batch dim.
        leftmost_dim = min(site['infer']['dim_to_name'])
        site['infer']['dim_to_name'][leftmost_dim - 1] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
    # we also need to reshape `carry` to match sequential behavior
    # (only for even lengths; the final carry then takes the scanned body's leaf shapes)
    if length % 2 == 0:
        t, rng_key, carry = wrapped_carry
        flatten_carry, treedef = tree_flatten(carry)
        flatten_carry = [jnp.reshape(x, t1_shape)
                         for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)]
        carry = tree_unflatten(treedef, flatten_carry)
        wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
コード例 #18
0
 def model(data):
     """Enumerated chain ``x`` whose transitions depend on a global Bernoulli ``w``."""
     switch = numpyro.sample("w", dist.Bernoulli(0.6))
     state = 0
     for t, obs in markov(enumerate(data)):
         trans = probs[switch, state]
         state = numpyro.sample(f"x_{t}", dist.Categorical(trans))
         numpyro.sample(f"y_{t}", dist.Normal(locs[state], 1), obs=obs)