def fun_model(T, q=1, r=1, phi=0.0, beta=0.0):
    """Scan a two-site transition over ``jnp.arange(T)``.

    Each step draws site ``"x"`` from ``LogNormal(phi * x_prev, q)``, updates
    the running mean as ``beta * mu_prev + x``, and draws site ``"y"`` from
    ``Normal(mu, r)``. The carry starts at ``(0.0, 0.0)``.

    NOTE(review): ``scan`` is defined outside this block; presumably it is
    NumPyro's control-flow ``scan`` -- confirm against the file's imports.

    Args:
        T: Number of steps; passed to ``jnp.arange`` to build the scanned input.
        q: Second argument of the ``LogNormal`` used for ``"x"``.
        r: Second argument of the ``Normal`` used for ``"y"``.
        phi: Multiplier applied to the previous ``x`` in the ``"x"`` site.
        beta: Multiplier applied to the previous ``mu`` in the mean update.

    Returns:
        None. The result of ``scan`` is discarded.
    """

    def transition(carry, _step):
        # The step index supplied by the scanned input is not used.
        prev_x, prev_mu = carry
        latent = numpyro.sample("x", dist.LogNormal(phi * prev_x, q))
        mean = beta * prev_mu + latent
        numpyro.sample("y", dist.Normal(mean, r))
        # Nothing is collected per step, hence the ``None`` output.
        return (latent, mean), None

    initial_carry = (0.0, 0.0)
    scan(transition, initial_carry, jnp.arange(T))
def outer_fn(y, val):
    """Sample site ``"y"`` around ``y``, then scan four ``"z"`` draws from it.

    The freshly sampled ``y`` is used as the initial carry of an inner scan
    whose step samples site ``"z"`` from ``Normal(carry, 1)`` and both carries
    and emits that draw.

    NOTE(review): ``scan`` is defined outside this block; the positional call
    ``scan(step, y, None, 4)`` presumably means "no scanned input, length 4"
    -- confirm against the ``scan`` in use.

    Args:
        y: Location of the ``Normal`` used for site ``"y"``.
        val: Not used in the body.

    Returns:
        A pair ``(y, zs)``: the sampled ``y`` and the second element of the
        inner ``scan`` result.
    """

    def body_fn(carry, _unused):
        draw = numpyro.sample("z", dist.Normal(carry, 1))
        # The same draw is both the next carry and the per-step output.
        return draw, draw

    y = numpyro.sample("y", dist.Normal(y, 1))
    _last, zs = scan(body_fn, y, None, 4)
    return y, zs
def multiply_and_add_repeatedly(K, c_in):
    """Iterate ``c <- dot(c, K) + c_in[t]`` via ``scan`` and return every ``c``.

    The recursion starts from ``jnp.asarray([1.0, 0.0])``. Each step emits its
    updated value wrapped in a one-element tuple, which is unwrapped before
    returning.

    NOTE(review): ``scan`` is defined outside this block; it is called with
    the ``init=`` and ``xs=`` keywords -- confirm which ``scan`` is imported.

    Args:
        K: Right-hand operand of ``jnp.dot`` applied to the carry each step.
        c_in: Scanned input; one slice is added to the carry per step.

    Returns:
        The single element of the per-step output tuple produced by ``scan``.
    """

    def iteration(carry, addend):
        updated = jnp.dot(carry, K) + addend
        # Output is a 1-tuple so the scan result is a tuple of stacked values.
        return updated, (updated,)

    start = jnp.asarray([1.0, 0.0])
    _final, outputs = scan(iteration, init=start, xs=c_in)
    (ys,) = outputs
    return ys
def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
    """Sample an initial ``(x_0, y_0)`` pair, then scan a transition over ``T`` steps.

    Sites ``"x_0"`` and ``"y_0"`` are drawn first; the initial mean equals the
    initial ``x``. Each scanned step draws ``"x"`` from
    ``Normal(phi * x_prev, q)``, updates the mean to ``beta * mu_prev + x``,
    draws ``"y"`` from ``Normal(mu, r)``, and records ``"y2"`` as ``2 * y``
    through ``numpyro.deterministic``.

    NOTE(review): ``scan`` is defined outside this block; presumably it is
    NumPyro's control-flow ``scan`` -- confirm against the file's imports.

    Args:
        T: Number of scanned steps; passed to ``jnp.arange``.
        q: Second argument of the ``Normal`` used for ``"x_0"`` and ``"x"``.
        r: Second argument of the ``Normal`` used for ``"y_0"`` and ``"y"``.
        phi: Multiplier applied to the previous ``x`` in the ``"x"`` site.
        beta: Multiplier applied to the previous mean in the mean update.

    Returns:
        A pair ``(jnp.append(x0, xs), jnp.append(y0, ys))``: the initial draws
        prepended to the per-step ``x`` and ``y`` outputs of the scan.
    """

    def transition(carry, _t):
        # The step index supplied by the scanned input is not used.
        prev_x, prev_mu = carry
        next_x = numpyro.sample("x", dist.Normal(phi * prev_x, q))
        next_mu = beta * prev_mu + next_x
        next_y = numpyro.sample("y", dist.Normal(next_mu, r))
        numpyro.deterministic("y2", next_y * 2)
        return (next_x, next_mu), (next_x, next_y)

    x0 = numpyro.sample("x_0", dist.Normal(0, q))
    mu0 = x0  # the initial mean is the initial latent draw itself
    y0 = numpyro.sample("y_0", dist.Normal(mu0, r))

    _, (xs, ys) = scan(transition, (x0, mu0), jnp.arange(T))
    return jnp.append(x0, xs), jnp.append(y0, ys)