def testing():
    """Check that data shapes alternate across four nested ``markov`` contexts.

    At every nesting level a funsor ``Tensor`` of ones with one size-2
    ``Bint`` input is converted with ``to_data`` and its shape is asserted:
    ``(2,)`` at depths 1 and 3, ``(2, 1)`` at depths 2 and 4.
    """

    def ones_as_data(input_name):
        # Build the one-input funsor and convert it inside the current context.
        funsor_ones = Tensor(
            jnp.ones(2), OrderedDict([(input_name, Bint[2])]), "real"
        )
        return to_data(funsor_ones)

    with markov():
        v1 = ones_as_data("1")
        print(1, v1.shape)  # shapes should alternate
        assert v1.shape == (2,)

        with markov():
            v2 = ones_as_data("2")
            print(2, v2.shape)  # shapes should alternate
            assert v2.shape == (2, 1)

            with markov():
                v3 = ones_as_data("3")
                print(3, v3.shape)  # shapes should alternate
                assert v3.shape == (2,)

                with markov():
                    v4 = ones_as_data("4")
                    print(4, v4.shape)  # shapes should alternate
                    assert v4.shape == (2, 1)
def model(data_x, data_w):
    """Two independent discrete chains, each observed through a unit-scale Normal.

    The first loop walks ``data_x`` and samples ``x_{i}`` / observes
    ``y_x_{i}``; the second walks ``data_w`` and samples ``w{i}`` / observes
    ``y_w_{i}``. ``probs_x``, ``probs_w``, ``locs_x`` and ``locs_w`` are read
    from the enclosing scope.
    """
    x = 0
    w = 0

    # First chain: the previous x indexes the transition table.
    for i, y in markov(enumerate(data_x)):
        x_transition = dist.Categorical(probs_x[x])
        x = numpyro.sample(f"x_{i}", x_transition)
        x_emission = dist.Normal(locs_x[x], 1)
        numpyro.sample(f"y_x_{i}", x_emission, obs=y)

    # Second chain: same pattern for w (note the site name has no underscore).
    for i, y in markov(enumerate(data_w)):
        w_transition = dist.Categorical(probs_w[w])
        w = numpyro.sample(f"w{i}", w_transition)
        w_emission = dist.Normal(locs_w[w], 1)
        numpyro.sample(f"y_w_{i}", w_emission, obs=y)
def model(data):
    """Discrete-state chain with Normal observations; returns the last state.

    The first step draws from ``init_probs``; later steps draw from the row of
    ``transition_probs`` selected by the previous state. ``init_probs``,
    ``transition_probs`` and ``locs`` are read from the enclosing scope.

    :param data: iterable of observations, one per step.
    :return: the state sampled at the final step (``None`` if ``data`` is empty).
    """
    state = None
    for i, y in markov(enumerate(data)):
        if state is None:
            probs = init_probs
        else:
            probs = transition_probs[state]
        state = numpyro.sample(f"x_{i}", dist.Categorical(probs))
        emission = dist.Normal(locs[state], 1)
        numpyro.sample(f"y_{i}", emission, obs=y)
    return state
def model(data):
    """Discrete-state chain whose every step runs inside the ``"D"`` plate.

    A single plate object (size ``D``, ``dim=-1``) is created once and
    re-entered at each step. ``D``, ``init_probs``, ``transition_probs`` and
    ``locs`` are read from the enclosing scope.

    :param data: iterable of observations, one per step.
    """
    state = None
    batch_plate = numpyro.plate("D", D, dim=-1)
    for i, y in markov(enumerate(data)):
        with batch_plate:
            if state is None:
                probs = init_probs
            else:
                probs = transition_probs[state]
            state = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            emission = dist.Normal(locs[state], 1)
            numpyro.sample(f"y_{i}", emission, obs=y)
def testing():
    """Convert a funsor to data and back on every fourth of 12 ``markov`` steps.

    On steps where ``i % 4 == 0`` the data shape is asserted to be ``(2,)``
    and both the shape and the round-tripped funsor's inputs are printed.
    """
    for i in markov(range(12)):
        if i % 4 != 0:
            # Nothing is created on the intermediate steps.
            continue
        funsor_zeros = Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real')
        v2 = to_data(funsor_zeros)
        fv2 = to_funsor(v2, reals())
        assert v2.shape == (2,)
        print('a', v2.shape)
        print('a', fv2.inputs)
def model(data):
    """Chain over ``x`` with a per-step Bernoulli ``w`` inside the ``"D"`` plate.

    ``x_{i}`` is sampled outside the plate; ``w_{i}`` and the observation
    ``y_{i}`` are sampled inside it, with the observation mean looked up via
    ``Vindex(locs)[x, w]``. ``D``, ``transition_probs`` and ``locs`` are read
    from the enclosing scope.

    :param data: iterable of observations, one per step.
    """
    state = 0
    batch_plate = numpyro.plate("D", D, dim=-1)
    for i, y in markov(enumerate(data)):
        transition = dist.Categorical(transition_probs[state])
        state = numpyro.sample(f"x_{i}", transition)
        with batch_plate:
            w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
            loc = Vindex(locs)[state, w]
            numpyro.sample(f"y_{i}", dist.Normal(loc, 1), obs=y)
def testing():
    """Convert a funsor to data and back on every fourth of 12 ``markov`` steps.

    On steps where ``i % 4 == 0`` the data shape is asserted to be ``(2,)``
    and both the shape and the round-tripped funsor's inputs are printed.
    """
    for i in markov(range(12)):
        if i % 4 != 0:
            # Nothing is created on the intermediate steps.
            continue
        funsor_zeros = Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real")
        v2 = to_data(funsor_zeros)
        fv2 = to_funsor(v2, Real)
        assert v2.shape == (2,)
        print("a", v2.shape)
        print("a", fv2.inputs)
def model(data1, data2):
    """One discrete chain observed through two separately plated data streams.

    ``data1`` and ``data2`` are zipped, so iteration stops at the shorter one.
    At each step ``x_{i}`` is sampled once, then ``y1_{i}`` is observed inside
    plate ``"D1"`` and ``y2_{i}`` inside plate ``"D2"``. ``D1``, ``D2``,
    ``init_probs``, ``transition_probs`` and ``locs`` are read from the
    enclosing scope.
    """
    state = None
    first_plate = numpyro.plate("D1", D1, dim=-1)
    second_plate = numpyro.plate("D2", D2, dim=-1)
    for i, (y1, y2) in markov(enumerate(zip(data1, data2))):
        if state is None:
            probs = init_probs
        else:
            probs = transition_probs[state]
        state = numpyro.sample(f"x_{i}", dist.Categorical(probs))

        with first_plate:
            numpyro.sample(f"y1_{i}", dist.Normal(locs[state], 1), obs=y1)
        with second_plate:
            numpyro.sample(f"y2_{i}", dist.Normal(locs[state], 1), obs=y2)
def model(data):
    """Chain with sampled transition rows and per-state emission parameters.

    Inside the ``"states"`` plate (size ``dim``, read from the enclosing
    scope) the model samples ``transition``, ``emission_loc`` and
    ``emission_scale``. An initial probability vector is sampled at site
    ``"initialize"``; after each step the probabilities for the next step
    become ``transition[x]``.

    :param data: iterable of observations, one per step.
    """
    concentration = jnp.ones(dim)
    with numpyro.plate("states", dim):
        transition = numpyro.sample("transition", dist.Dirichlet(concentration))
        emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
        emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

    trans_prob = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    for t, y in markov(enumerate(data)):
        x = numpyro.sample("x_{}".format(t), dist.Categorical(trans_prob))
        emission = dist.Normal(emission_loc[x], emission_scale[x])
        numpyro.sample("y_{}".format(t), emission, obs=y)
        # The row selected by the current state drives the next step.
        trans_prob = transition[x]
def hmm(data, hidden_dim=10):
    """Hidden Markov model with a fixed transition matrix and unit-scale emissions.

    The transition matrix is ``0.3 / hidden_dim + 0.7 * I`` and the emission
    mean of state ``k`` is ``float(k)``. The chain starts from state ``0``.

    :param data: indexable, mutable sequence of observations; each entry is
        overwritten in place with the value returned by its ``obs_{t}`` site.
    :param int hidden_dim: number of hidden states.
    :return: ``(states, data)`` where ``states`` is the list of states
        beginning with the initial ``0``.
    """
    transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
    means = jnp.arange(float(hidden_dim))

    states = [0]
    for t in markov(range(len(data))):
        previous = states[-1]
        current = numpyro.sample(
            "states_{}".format(t), dist.Categorical(transition[previous])
        )
        states.append(current)
        data[t] = numpyro.sample(
            "obs_{}".format(t), dist.Normal(means[current], 1.0), obs=data[t]
        )
    return states, data
def model():
    """Binary chain whose transition depends on the two previous states and ``z``.

    ``p`` (shape ``(2, 2, 2)``) and ``q`` (shape ``(2,)``) are parameters
    initialised to ``0.25``. ``T`` and ``history`` are read from the enclosing
    scope; ``history`` is forwarded to ``markov``. Every ``y_{t}`` site is
    observed as ``0``.

    :return: ``(x_prev, x_curr)`` after the final step.
    """
    p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
    q = numpyro.param("q", 0.25 * jnp.ones(2))
    z = numpyro.sample("z", dist.Bernoulli(0.5))

    x_prev = 0
    x_curr = 0
    for t in markov(range(T), history=history):
        probs = p[x_prev, x_curr, z]
        x_next = numpyro.sample("x_{}".format(t), dist.Bernoulli(probs))
        # Shift the two-step window forward.
        x_prev, x_curr = x_curr, x_next
        numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0)
    return x_prev, x_curr
def testing():
    """Check data shapes of fresh and reused funsor inputs over 5 ``markov`` steps.

    ``v1`` uses a new input name (``str(i)``) each step and its shape is
    asserted to alternate between ``(2,)`` and ``(2, 1, 1)``; ``v2`` always
    uses input ``'a'`` and its shape is asserted to stay ``(2, 1)``.
    """
    for i in markov(range(5)):
        fresh_input = OrderedDict([(str(i), bint(2))])
        v1 = to_data(Tensor(jnp.ones(2), fresh_input, 'real'))
        shared_input = OrderedDict([('a', bint(2))])
        v2 = to_data(Tensor(jnp.zeros(2), shared_input, 'real'))
        fv1 = to_funsor(v1, reals())
        fv2 = to_funsor(v2, reals())

        print(i, v1.shape)  # shapes should alternate
        expected_v1_shape = (2,) if i % 2 == 0 else (2, 1, 1)
        assert v1.shape == expected_v1_shape
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)
def testing():
    """Check data shapes of fresh and reused funsor inputs over 5 ``markov`` steps.

    ``v1`` uses a new input name (``str(i)``) each step and its shape is
    asserted to alternate between ``(2,)`` and ``(2, 1, 1)``; ``v2`` always
    uses input ``"a"`` and its shape is asserted to stay ``(2, 1)``.
    """
    for i in markov(range(5)):
        fresh_input = OrderedDict([(str(i), Bint[2])])
        v1 = to_data(Tensor(jnp.ones(2), fresh_input, "real"))
        shared_input = OrderedDict([("a", Bint[2])])
        v2 = to_data(Tensor(jnp.zeros(2), shared_input, "real"))
        fv1 = to_funsor(v1, Real)
        fv2 = to_funsor(v2, Real)

        print(i, v1.shape)  # shapes should alternate
        expected_v1_shape = (2,) if i % 2 == 0 else (2, 1, 1)
        assert v1.shape == expected_v1_shape
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print("a", v2.shape)  # shapes should stay the same
        print("a", fv2.inputs)
def body_fn(wrapped_carry, x, prefix=None): i, rng_key, carry = wrapped_carry init = True if (not_jax_tracer(i) and i == 0) else False rng_key, subkey = random.split(rng_key) if rng_key is not None else ( None, None) seeded_fn = handlers.seed(f, subkey) if subkey is not None else f for subs_type, subs_map in substitute_stack: subs_fn = partial(_subs_wrapper, subs_map, i, length) if subs_type == 'condition': seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn) elif subs_type == 'substitute': seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) if init: with handlers.scope(prefix="_init"): new_carry, y = seeded_fn(carry, x) trace = {} else: with handlers.block(), packed_trace() as trace, promote_shapes( ), enum(), markov(): # Like scan_wrapper, we collect the trace of scan's transition function # `seeded_fn` here. To put time dimension to the correct position, we need to # promote shapes to make `fn` and `value` # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`, # and value's batch_shape is (3,), then we promote shape of # value so that its batch shape is (1, 3)). new_carry, y = config_enumerate(seeded_fn)(carry, x) # store shape of new_carry at a global variable nonlocal carry_shape_at_t1 carry_shape_at_t1 = [ jnp.shape(x) for x in tree_flatten(new_carry)[0] ] # make new_carry have the same shape as carry # FIXME: is this rigorous? new_carry = tree_multimap( lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry) return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    """Scan ``f`` over ``xs`` with enumeration, unrolling the first few steps.

    The first ``unroll_steps = min(2 * history - 1, length)`` steps are run
    eagerly in Python inside a ``markov`` loop, with their site names prefixed
    by ``"_PREV_"``. The remaining ``length - unroll_steps`` steps are run by a
    single ``lax.scan`` call whose trace is collected and given a time
    dimension named ``"_time_{first sample site}"``.

    :param f: transition function, called as ``f(carry, x)`` and expected to
        return ``(new_carry, y)``.
    :param init: initial carry.
    :param xs: pytree of inputs; each leaf is sliced along its leading axis.
    :param int length: total number of steps.
    :param bool reverse: if true, the unrolled inputs are taken from the end of
        ``xs`` (in reversed order) and ``reverse`` is forwarded to ``lax.scan``.
    :param rng_key: optional key; when not ``None`` it is split at every step
        and one half seeds ``f``.
    :param substitute_stack: iterable of ``(subs_type, subs_map)`` pairs, where
        ``subs_type`` is ``"condition"`` or ``"substitute"``. ``None`` is
        treated as an empty stack.
    :param int history: forwarded to ``markov``; clipped to ``length``.
    :param first_available_dim: forwarded to ``enum``.
    :return: ``(wrapped_carry, (pytree_trace, ys))`` where ``wrapped_carry`` is
        ``(step_index, rng_key, carry)``. When ``length == unroll_steps`` the
        trace is an empty ``PytreeTrace``.
    """
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # The signature default is None; iterating None below would raise TypeError.
    if substitute_stack is None:
        substitute_stack = []

    # number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    # One entry (a list of leaf shapes) per recorded step; indexed at the end
    # to restore the carry shape that sequential execution would have produced.
    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        # Unrolled steps have a concrete index; inside `lax.scan` it is traced.
        is_unrolled_step = bool(not_jax_tracer(i) and i in range(unroll_steps))
        rng_key, subkey = (
            random.split(rng_key) if rng_key is not None else (None, None)
        )

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": i}
        )

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if is_unrolled_step:
            # prefix the site names to match the pattern expected by the
            # sarkka_bilmes product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry in the enclosing `carry_shapes` list
            if len(carry_shapes) < (history + 1):
                carry_shapes.append(
                    [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
                )
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
            )
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
        hide_fn=lambda site: not site["name"].startswith("_PREV_")
    ), enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(
                    wrapped_carry, tree_map(lambda z: z[i], x0)
                )
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
                    )
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret the last carry
                # shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    # Materialise as a list (not a one-shot generator) so the
                    # entry matches what `body_fn` stores and can be re-read.
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]]
                    )
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(
                    body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
                )

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site["type"] not in ("sample",):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values yet during `lax.scan`, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(
            len(site["fn"].batch_shape),
            jnp.ndim(site["value"]) - site["fn"].event_dim,
        )
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
    )
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def model(data):
    """Two discrete chains advanced together, jointly driving one observation.

    At each step ``x_{i}`` and ``w_{i}`` are sampled from their own transition
    tables and ``y_{i}`` is observed from ``Normal(locs[w, x], 1)``.
    ``probs_x``, ``probs_w`` and ``locs`` are read from the enclosing scope.

    :param data: iterable of observations, one per step.
    """
    x = 0
    w = 0
    for i, y in markov(enumerate(data)):
        x_transition = dist.Categorical(probs_x[x])
        x = numpyro.sample(f"x_{i}", x_transition)
        w_transition = dist.Categorical(probs_w[w])
        w = numpyro.sample(f"w_{i}", w_transition)
        # Note the index order: w first, then x.
        emission = dist.Normal(locs[w, x], 1)
        numpyro.sample(f"y_{i}", emission, obs=y)
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """Scan ``f`` over ``xs`` with enumeration, for history size 1.

    The first step is run eagerly with site names scoped under ``"_init"``;
    the remaining ``length - 1`` steps are run by a single ``lax.scan`` call
    whose trace is collected and given a time dimension named
    ``"_time_{first site in the trace}"``.

    :param f: transition function, called as ``f(carry, x)`` and expected to
        return ``(new_carry, y)``.
    :param init: initial carry.
    :param xs: pytree of inputs; each leaf is indexed along its leading axis.
    :param int length: total number of steps.
    :param bool reverse: if true, the eager step uses the last slice of ``xs``
        and ``reverse`` is forwarded to ``lax.scan``.
    :param rng_key: optional key; when not ``None`` it is split at every step
        and one half seeds ``f``.
    :param substitute_stack: iterable of ``(subs_type, subs_map)`` pairs, where
        ``subs_type`` is ``'condition'`` or ``'substitute'``. ``None`` is
        treated as an empty stack.
    :return: ``(wrapped_carry, (pytree_trace, ys))`` where ``wrapped_carry`` is
        ``(step_index, rng_key, carry)``. When ``length == 1`` the trace is an
        empty ``PytreeTrace``.
    """
    from numpyro.contrib.funsor import config_enumerate, enum, markov
    from numpyro.contrib.funsor import trace as packed_trace

    # The signature default is None; iterating None below would raise TypeError.
    if substitute_stack is None:
        substitute_stack = []

    # XXX: This implementation only works for history size=1 but can be
    # extended to history size > 1 by running `f` `history_size` times
    # for initialization. However, `sequential_sum_product` does not
    # support history size > 1, so we skip supporting it here.
    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
    if reverse:
        x0 = tree_map(lambda x: x[-1], xs)
        xs_ = tree_map(lambda x: x[:-1], xs)
    else:
        x0 = tree_map(lambda x: x[0], xs)
        xs_ = tree_map(lambda x: x[1:], xs)

    # Leaf shapes of the carry produced by the scanned (non-initial) step.
    carry_shape_at_t1 = None

    def body_fn(wrapped_carry, x, prefix=None):
        nonlocal carry_shape_at_t1

        i, rng_key, carry = wrapped_carry
        # Only a concrete (non-traced) index equal to 0 counts as the first step.
        is_first_step = bool(not_jax_tracer(i) and i == 0)
        rng_key, subkey = (
            random.split(rng_key) if rng_key is not None else (None, None)
        )

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if is_first_step:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry in the enclosing scope
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
            )
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

    with markov():
        wrapped_carry = (0, rng_key, init)
        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
        if length == 1:
            # Nothing left to scan: add the leading time axis and return.
            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
            return wrapped_carry, (PytreeTrace({}), ys)
        wrapped_carry, (pytree_trace, ys) = lax.scan(
            body_fn, wrapped_carry, xs_, length - 1, reverse
        )

    first_var = None
    for name, site in pytree_trace.trace.items():
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name
        leftmost_dim = min(site['infer']['dim_to_name'])
        site['infer']['dim_to_name'][leftmost_dim - 1] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys
    )
    # we also need to reshape `carry` to match sequential behavior
    if length % 2 == 0:
        t, rng_key, carry = wrapped_carry
        flatten_carry, treedef = tree_flatten(carry)
        flatten_carry = [
            jnp.reshape(x, t1_shape)
            for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)
        ]
        carry = tree_unflatten(treedef, flatten_carry)
        wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def model(data):
    """Discrete chain whose transitions are selected by a global Bernoulli ``w``.

    ``w`` is sampled once before the loop; at each step ``x_{i}`` is sampled
    from ``Categorical(probs[w, x])`` and ``y_{i}`` is observed from
    ``Normal(locs[x], 1)``. ``probs`` and ``locs`` are read from the
    enclosing scope.

    :param data: iterable of observations, one per step.
    """
    w = numpyro.sample("w", dist.Bernoulli(0.6))
    state = 0
    for i, y in markov(enumerate(data)):
        transition = dist.Categorical(probs[w, state])
        state = numpyro.sample(f"x_{i}", transition)
        emission = dist.Normal(locs[state], 1)
        numpyro.sample(f"y_{i}", emission, obs=y)