def model(batch, subsample, full_size):
    """Latent Gaussian random walk with Bernoulli observations on a subsampled plate.

    A single ``drift`` scale (LogNormal prior) is shared by every plate member
    and every time step. The ``"data"`` plate has ``full_size`` members; its
    subsample indices are substituted with ``subsample`` rather than drawn.

    Args:
        batch: observations scanned over their leading axis; each slice is
            passed as ``obs`` to the ``"obs"`` site.
        subsample: indices of the plate members in this mini-batch; its
            length is the subsample size and the size of the initial state.
        full_size: total plate size.

    Returns:
        The per-step ``"obs"`` values stacked by ``scan``.

    NOTE(review): ``num_time_steps`` is not defined in this block; it is
    presumably supplied by an enclosing / module scope — confirm.
    """
    drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
    # Force the plate's subsample indices to `subsample`.
    with handlers.substitute(data={"data": subsample}):
        plate = numpyro.plate("data", full_size, subsample_size=len(subsample))
    # The plate must report the full population size, not the subsample size.
    # Compare against the argument instead of a hard-coded constant.
    assert plate.size == full_size

    def transition_fn(z_prev, y_curr):
        with plate:
            z_curr = numpyro.sample("state", dist.Normal(z_prev, drift))
            y_curr = numpyro.sample("obs", dist.Bernoulli(logits=z_curr), obs=y_curr)
        return z_curr, y_curr

    _, result = scan(
        transition_fn, jnp.zeros(len(subsample)), batch, length=num_time_steps
    )
    return result
def model_1(sequences, lengths, args, include_prior=True):
    """First-order discrete HMM with independent Bernoulli emissions per tone.

    Args:
        sequences: binary observations, unpacked as
            ``(num_sequences, max_length, data_dim)``; scanned over time.
        lengths: per-sequence lengths; steps with ``t >= lengths`` are masked.
        args: config object; only ``args.hidden_dim`` is read.
        include_prior: when False, the log-density of the ``probs_x`` /
            ``probs_y`` priors is masked out.

    Returns:
        None. The model only registers sample sites via ``scan``.
    """
    n_seq, _, n_tones = sequences.shape
    n_hidden = args.hidden_dim

    with mask(mask=include_prior):
        # Dirichlet concentration is 1.0 on the diagonal and 0.1 elsewhere,
        # so the transition prior favours staying in the same state.
        transition_prior = dist.Dirichlet(0.9 * jnp.eye(n_hidden) + 0.1).to_event(1)
        probs_x = numpyro.sample("probs_x", transition_prior)
        emission_prior = dist.Beta(0.1, 0.9).expand([n_hidden, n_tones]).to_event(2)
        probs_y = numpyro.sample("probs_y", emission_prior)

    def step(carry, y_t):
        x_last, t = carry
        with numpyro.plate("sequences", n_seq, dim=-2):
            # Padding steps beyond each sequence's length contribute nothing.
            with mask(mask=(t < lengths)[..., None]):
                x_t = numpyro.sample("x", dist.Categorical(probs_x[x_last]))
                with numpyro.plate("tones", n_tones, dim=-1):
                    numpyro.sample(
                        "y", dist.Bernoulli(probs_y[x_t.squeeze(-1)]), obs=y_t
                    )
        return (x_t, t + 1), None

    x_init = jnp.zeros((n_seq, 1), dtype=jnp.int32)
    # Put the time axis first so scan iterates over time steps.
    time_major = jnp.swapaxes(sequences, 0, 1)
    scan(step, (x_init, 0), time_major)
def model_6(sequences, lengths, args, include_prior=False):
    """Second-order discrete HMM: each hidden state depends on the previous two.

    The transition tensor ``probs_x`` is indexed by the two preceding states.
    Emissions are independent Bernoulli per tone.

    Args:
        sequences: binary observations, unpacked as
            ``(num_sequences, max_length, data_dim)``; scanned over time.
        lengths: per-sequence lengths; steps with ``t >= lengths`` are masked.
        args: config object; only ``args.hidden_dim`` is read.
        include_prior: when False (the default for this model), the
            log-density of the ``probs_x`` / ``probs_y`` priors is masked out.

    Returns:
        None. The model only registers sample sites via ``scan``.
    """
    n_seq, _, n_tones = sequences.shape
    n_hidden = args.hidden_dim

    with mask(mask=include_prior):
        # Explicitly parameterize the full tensor of transition probabilities,
        # which has hidden_dim cubed entries.
        transition_prior = (
            dist.Dirichlet(0.9 * jnp.eye(n_hidden) + 0.1)
            .expand([n_hidden, n_hidden])
            .to_event(2)
        )
        probs_x = numpyro.sample("probs_x", transition_prior)
        emission_prior = dist.Beta(0.1, 0.9).expand([n_hidden, n_tones]).to_event(2)
        probs_y = numpyro.sample("probs_y", emission_prior)

    def step(carry, y_t):
        x_older, x_old, t = carry
        with numpyro.plate("sequences", n_seq, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                # Vindex broadcasts the two index arrays against each other.
                probs_x_t = Vindex(probs_x)[x_older, x_old]
                x_new = numpyro.sample(
                    "x", dist.Categorical(probs_x_t), infer={"enumerate": "parallel"}
                )
                # Shift the two-state window forward by one step.
                x_older, x_old = x_old, x_new
                with numpyro.plate("tones", n_tones, dim=-1):
                    numpyro.sample(
                        "y", dist.Bernoulli(probs_y[x_old.squeeze(-1)]), obs=y_t
                    )
        return (x_older, x_old, t + 1), None

    # JAX arrays are immutable, so both window slots can share one zero array.
    x_init = jnp.zeros((n_seq, 1), dtype=jnp.int32)
    time_major = jnp.swapaxes(sequences, 0, 1)
    # history=2: the transition depends on the two previous carries.
    scan(step, (x_init, x_init, 0), time_major, history=2)
def model_2(capture_history, sex):
    """Capture-recapture model with a time-varying survival probability ``phi``.

    Args:
        capture_history: 0/1 matrix unpacked as ``(N, T)``; column 0 seeds
            the first-capture mask and columns ``1:`` are scanned over time.
        sex: unused here; presumably kept for a signature shared with the
            other capture-recapture models — confirm.

    Returns:
        None. The model only registers sample sites via ``scan``.
    """
    n_animals, _ = capture_history.shape
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def step(carry, y_t):
        seen, z_prev = carry
        # phi_t is drawn outside the plate: it is shared by all N individuals.
        phi_t = numpyro.sample("phi", dist.Uniform(0.0, 1.0))

        with numpyro.plate("animals", n_animals, dim=-1):
            with handlers.mask(mask=seen):
                # Individuals not yet seen get probability 1 of being "alive".
                mu_z = seen * phi_t * z_prev + (1 - seen)
                # NumPyro exactly sums out the discrete states z_t.
                z_t = numpyro.sample(
                    "z",
                    dist.Bernoulli(dist.util.clamp_probs(mu_z)),
                    infer={"enumerate": "parallel"},
                )
                mu_y = rho * z_t
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y)), obs=y_t)

        return (seen | y_t.astype(bool), z_t), None

    z_init = jnp.ones(n_animals, dtype=jnp.int32)
    # Masks out log-probabilities for each individual before its first capture.
    seen_init = capture_history[:, 0].astype(bool)
    # Put the time axis first so scan iterates over time steps.
    time_major = jnp.swapaxes(capture_history[:, 1:], 0, 1)
    scan(step, (seen_init, z_init), time_major)
def model_5(capture_history, sex):
    """Capture-recapture model with sex and time effects on survival (logit scale).

    Survival at step t is ``expit(sex * phi_beta + phi_gamma_t)``. ``phi_beta``
    is a single coefficient scaled by ``sex``. ``phi_gamma_t`` is drawn per
    time step and shared by all individuals.

    Args:
        capture_history: 0/1 matrix unpacked as ``(N, T)``; column 0 seeds
            the first-capture mask and columns ``1:`` are scanned over time.
        sex: multiplies ``phi_beta``; presumably a per-individual 0/1
            indicator of length N — confirm against caller.

    Returns:
        None. The model only registers sample sites via ``scan``.
    """
    N, T = capture_history.shape
    # phi_beta controls the survival probability differential
    # for males versus females (in logit space).
    phi_beta = numpyro.sample("phi_beta", dist.Normal(0.0, 10.0))
    phi_beta = sex * phi_beta
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        # Drawn outside the plate: the time effect is shared by all individuals.
        phi_gamma_t = numpyro.sample("phi_gamma", dist.Normal(0.0, 10.0))
        phi_t = expit(phi_beta + phi_gamma_t)

        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t. Mark the site
                # for parallel enumeration explicitly (as model_2 does) rather than
                # relying on the inference backend to configure it.
                z = numpyro.sample(
                    "z",
                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                    infer={"enumerate": "parallel"},
                )
                mu_y_t = rho * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # We use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: move the time dimension of `capture_history` to the front to scan over it.
    scan(
        transition_fn,
        (first_capture_mask, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )
def prefdyn_single_model(beliefs, y, mask):
    """Preference-dynamics model: per-step categorical choices driven by evolving weights.

    A 4-vector of positive weights ``lam`` is carried through time. At each
    step the log-normalized weights feed ``logits`` (defined elsewhere) to
    produce choice logits for ``"y"``. Then ``lam`` is incremented by
    ``mask[t] * eta`` at the one-hot position ``beliefs[2][t]``.

    Args:
        beliefs: indexable collection. ``beliefs[0]`` unpacks as a 2-D
            ``(T, _)`` array. ``beliefs[0][t]`` and ``beliefs[1][t]`` are
            passed to ``logits``. ``beliefs[2][t]`` is one-hot encoded into
            4 classes. ``beliefs[-1]`` scales ``eta`` in the initial weights.
            Semantics of each entry: TODO confirm.
        y: observed choices; the ``"y"`` site is conditioned on it.
        mask: per-step mask; ``mask[t]`` masks the ``"y"`` likelihood and
            gates the weight increment.

    Returns:
        None. The model only registers sites via ``scan``.
    """
    n_steps, _ = beliefs[0].shape
    initial_counts = beliefs[-1]

    lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
    lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))
    lam34_col = jnp.expand_dims(lam34, -1)
    # First two weights are ordered via a cumulative sum; last two share lam34.
    base_weights = jnp.concatenate([lam12.cumsum(-1), lam34_col, lam34_col], -1)
    lam0 = npyro.deterministic('lam0', base_weights)

    eta = npyro.sample('eta', dist.Beta(1, 10))
    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))

    def step(lam_prev, t):
        # Log of the weights normalized to sum to one along the last axis.
        log_pref = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        choice_logits = logits(
            (beliefs[0][t], beliefs[1][t]),
            jnp.expand_dims(gamma, -1),
            jnp.expand_dims(log_pref, -2),
        )
        increment = nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1)
        lam_next = npyro.deterministic('lams', lam_prev + increment)
        npyro.sample('y', dist.CategoricalLogits(choice_logits).mask(mask[t]))
        return lam_next, None

    lam_start = npyro.deterministic(
        'lam_start', lam0 + jnp.expand_dims(eta, -1) * initial_counts
    )
    with npyro.handlers.condition(data={"y": y}):
        scan(step, lam_start, jnp.arange(n_steps))
def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
    """AR(1) latent ``x`` driving a cumulative mean ``mu``, observed with noise as ``y``.

    The recurrences are ``x_t ~ Normal(phi * x_{t-1}, q)``,
    ``mu_t = beta * mu_{t-1} + x_t`` and ``y_t ~ Normal(mu_t, r)``. Each step
    also records the deterministic site ``"y2" = 2 * y_t``.

    Returns:
        ``(x, y)``: the initial draws ``x_0`` / ``y_0`` followed by the ``T``
        values produced by ``scan``.
    """

    def step(carry, _):
        x_prev, mu_prev = carry
        x_t = numpyro.sample("x", dist.Normal(phi * x_prev, q))
        mu_t = beta * mu_prev + x_t
        y_t = numpyro.sample("y", dist.Normal(mu_t, r))
        numpyro.deterministic("y2", y_t * 2)
        return (x_t, mu_t), (x_t, y_t)

    x_init = numpyro.sample("x_0", dist.Normal(0, q))
    mu_init = x_init
    y_init = numpyro.sample("y_0", dist.Normal(mu_init, r))

    _, (xs, ys) = scan(step, (x_init, mu_init), jnp.arange(T))
    return jnp.append(x_init, xs), jnp.append(y_init, ys)
def fun_model():
    """Binary Markov chain whose transitions depend on the last two states and a global ``z``.

    ``p[x_{t-2}, x_{t-1}, z]`` gives the probability that ``x_t`` is 1, and
    ``q[x_t]`` gives the emission probability for the observed ``y`` (here
    all zeros).

    NOTE(review): ``T`` and ``history`` are not defined in this block; they
    are presumably supplied by an enclosing scope — confirm.

    Returns:
        The final ``(x_prev, x_curr)`` carry from ``scan``.
    """
    p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
    q = numpyro.param("q", 0.25 * jnp.ones(2))
    z = numpyro.sample("z", dist.Bernoulli(0.5))

    def step(state, y_t):
        older, newer = state
        x_t = numpyro.sample("x", dist.Bernoulli(p[older, newer, z]))
        numpyro.sample("y", dist.Bernoulli(q[x_t]), obs=y_t)
        return (newer, x_t), None

    (last_prev, last_curr), _ = scan(step, (0, 0), jnp.zeros(T), history=history)
    return last_prev, last_curr