def model_4(sequences, lengths, args, include_prior=True):
    """HMM with two coupled hidden chains ``w`` and ``x``.

    ``w`` evolves on its own; at each step ``x`` is drawn conditioned on both
    the current ``w`` and the previous ``x``. Each tone is a Bernoulli
    emission whose probability is looked up as ``probs_y[w, x, tone]``.

    ``args.hidden_dim`` is split between the chains: each one has
    ``int(args.hidden_dim ** 0.5)`` states.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; steps with ``t >= lengths`` are masked.
        args: namespace providing ``hidden_dim``.
        include_prior: when False, the global ``probs_*`` sites are masked out.
    """
    n_seqs, _, n_tones = sequences.shape
    n_states = int(args.hidden_dim**0.5)  # states per chain (w and x)

    # Dirichlet concentration: 1.0 on the diagonal, 0.1 elsewhere.
    sticky = 0.9 * jnp.eye(n_states) + 0.1
    with mask(mask=include_prior):
        probs_w = numpyro.sample("probs_w", dist.Dirichlet(sticky).to_event(1))
        # One x-transition matrix per value of w: event shape (S, S, S).
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(sticky).expand_by([n_states]).to_event(2),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9)
            .expand([n_states, n_states, n_tones])
            .to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", n_seqs, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x_probs = Vindex(probs_x)[w, x_prev]
                x = numpyro.sample("x", dist.Categorical(x_probs))
                # The tones plate yields the index array for the last axis.
                with numpyro.plate("tones", n_tones, dim=-1) as tones:
                    numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
        return (w, x, t + 1), None

    start = jnp.zeros((n_seqs, 1), dtype=jnp.int32)
    # Put the time axis first so scan iterates over time steps.
    scan(transition_fn, (start, start, 0), jnp.swapaxes(sequences, 0, 1))
def guide():
    """Guide with a learnable full-covariance Gaussian over the 3-vector ``x``.

    Every entry of the module-level ``data`` whose ``mask`` is False gets a
    ``y_unobserved`` site drawn from ``MultivariateNormal(x, I_3)``.
    """
    mean = numpyro.param("loc", np.zeros(3))
    covariance = numpyro.param(
        "cov", np.eye(3), constraint=constraints.positive_definite
    )
    x = numpyro.sample("x", dist.MultivariateNormal(mean, covariance))

    num_points = len(data)
    missing = np.invert(mask)  # inverted mask: entries where ``mask`` is False
    with numpyro.plate("plate", num_points), handlers.mask(mask=missing):
        numpyro.sample("y_unobserved", dist.MultivariateNormal(x, np.eye(3)))
def model(data, mask):
    """Plate of ``N`` standard-normal latents with masked extra terms.

    Under ``mask`` it adds a Delta site carrying ``log_density=1.`` per element
    and a ``Normal(x, 1)`` likelihood for ``data`` whose log-density is scaled
    by 2. ``N`` is read from the enclosing scope.
    """
    with numpyro.plate('N', N):
        latent = numpyro.sample('x', dist.Normal(0, 1))
        with handlers.mask(mask=mask):
            numpyro.sample('y', dist.Delta(latent, log_density=1.))
            likelihood = dist.Normal(latent, 1)
            with handlers.scale(scale=2):
                numpyro.sample('obs', likelihood, obs=data)
def model(data, mask):
    """Masked Delta factor plus a 2x-scaled Normal likelihood over ``N`` latents.

    ``N`` is taken from the enclosing scope; ``mask`` gates both the ``y`` and
    ``obs`` sites.
    """
    with numpyro.plate("N", N):
        x = numpyro.sample("x", dist.Normal(0, 1))
        with handlers.mask(mask=mask):
            delta_factor = dist.Delta(x, log_density=1.0)
            numpyro.sample("y", delta_factor)
            with handlers.scale(scale=2):
                obs_dist = dist.Normal(x, 1)
                numpyro.sample("obs", obs_dist, obs=data)
def model_6(sequences, lengths, args, include_prior=False):
    """Second-order HMM: each hidden state depends on the two previous states.

    The full transition tensor (``hidden_dim`` cubed entries) is parameterized
    explicitly, and ``scan(..., history=2)`` carries the two latest states.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; steps with ``t >= lengths`` are masked.
        args: namespace providing ``hidden_dim``.
        include_prior: when False (the default) the global sites are masked.
    """
    n_seqs, _, n_tones = sequences.shape
    k = args.hidden_dim

    with mask(mask=include_prior):
        # probs_x[a, b] is the next-state distribution given the state two
        # steps back (a) and one step back (b).
        transition_prior = dist.Dirichlet(0.9 * jnp.eye(k) + 0.1).expand([k, k])
        probs_x = numpyro.sample("probs_x", transition_prior.to_event(2))
        probs_y = numpyro.sample(
            "probs_y", dist.Beta(0.1, 0.9).expand([k, n_tones]).to_event(2)
        )

    def step(carry, y):
        two_back, one_back, t = carry
        with numpyro.plate("sequences", n_seqs, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                probs_x_t = Vindex(probs_x)[two_back, one_back]
                current = numpyro.sample("x", dist.Categorical(probs_x_t))
                with numpyro.plate("tones", n_tones, dim=-1):
                    probs_y_t = probs_y[current.squeeze(-1)]
                    numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
        return (one_back, current, t + 1), None

    zeros = jnp.zeros((n_seqs, 1), dtype=jnp.int32)
    scan(step, (zeros, zeros, 0), jnp.swapaxes(sequences, 0, 1), history=2)
def model_1(sequences, lengths, args, include_prior=True):
    """First-order HMM with ``args.hidden_dim`` states and Bernoulli tone emissions.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; steps with ``t >= lengths`` are masked.
        args: namespace providing ``hidden_dim``.
        include_prior: when False, the ``probs_x``/``probs_y`` sites are masked.
    """
    num_sequences, _, data_dim = sequences.shape
    n_states = args.hidden_dim

    with mask(mask=include_prior):
        # Concentration 1.0 on the diagonal, 0.1 elsewhere.
        transition_conc = 0.9 * jnp.eye(n_states) + 0.1
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(transition_conc).to_event(1)
        )
        emission_prior = dist.Beta(0.1, 0.9).expand([n_states, data_dim]).to_event(2)
        probs_y = numpyro.sample("probs_y", emission_prior)

    def transition_fn(carry, y):
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        return (x, t + 1), None

    # scan walks the leading axis, so move time to the front.
    time_major = jnp.swapaxes(sequences, 0, 1)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (x_init, 0), time_major)
def model_3(sequences, lengths, args, include_prior=True):
    """Factorial HMM with two independently evolving hidden chains ``w`` and ``x``.

    Each chain has ``int(args.hidden_dim ** 0.5)`` states and its own
    transition matrix. Each tone's Bernoulli emission probability is looked up
    as ``probs_y[w, x, tone]``.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; steps with ``t >= lengths`` are masked.
        args: namespace providing ``hidden_dim``.
        include_prior: when False, the global ``probs_*`` sites are masked out.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        # probs_y is indexed as probs_y[w, x, tone] with w, x in [0, hidden_dim),
        # so both leading axes must have size hidden_dim. A smaller x axis
        # would make JAX silently clamp out-of-range indices (sharing emission
        # probabilities across states); a larger axis only adds unused params.
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                # Note the broadcasting tricks here: to index probs_y on tensors
                # w and x, we also need a final tensor for the tones dimension.
                # This is conveniently provided by the plate for that dimension.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
def guide():
    """Guide with a learnable scalar Normal over ``x``.

    Every entry of the module-level ``data`` whose ``mask`` is False gets a
    ``y_unobserved`` site drawn from ``Normal(x, 1)``.
    """
    x_loc = numpyro.param("loc", np.zeros(()))
    x_scale = numpyro.param("scale", np.ones(()), constraint=constraints.positive)
    x = numpyro.sample("x", dist.Normal(x_loc, x_scale))
    with numpyro.plate("plate", len(data)), handlers.mask(mask=np.invert(mask)):
        numpyro.sample("y_unobserved", dist.Normal(x, 1.0))
def transition_fn(carry, y):
    """Single scan step of a first-order HMM with Bernoulli tone emissions.

    Closes over ``num_sequences``, ``lengths``, ``data_dim``, ``probs_x`` and
    ``probs_y`` from the enclosing scope.

    Args:
        carry: ``(x_prev, t)``: previous hidden state and time index.
        y: observations for this time step.

    Returns:
        ``((x, t + 1), None)``.
    """
    state_prev, step = carry
    in_range = (step < lengths)[..., None]  # mask out steps past each length
    with numpyro.plate("sequences", num_sequences, dim=-2):
        with mask(mask=in_range):
            state = numpyro.sample("x", dist.Categorical(probs_x[state_prev]))
            emission_probs = probs_y[state.squeeze(-1)]
            with numpyro.plate("tones", data_dim, dim=-1):
                numpyro.sample("y", dist.Bernoulli(emission_probs), obs=y)
    return (state, step + 1), None
def transition_fn(carry, y):
    """Scan step where ``x`` is conditioned on the current ``w`` and previous ``x``.

    Closes over ``num_sequences``, ``lengths``, ``data_dim``, ``probs_w``,
    ``probs_x`` and ``probs_y``.

    Returns:
        ``((w, x, t + 1), None)``.
    """
    w_prev, x_prev, t = carry
    active = (t < lengths)[..., None]
    with numpyro.plate("sequences", num_sequences, dim=-2):
        with mask(mask=active):
            w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
            x_given_w = Vindex(probs_x)[w, x_prev]
            x = numpyro.sample("x", dist.Categorical(x_given_w))
            # ``tones`` is the index array for the last axis of probs_y.
            with numpyro.plate("tones", data_dim, dim=-1) as tones:
                numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
    return (w, x, t + 1), None
def model_bpinn(p, t, Y, F, data_type, D_H, u_sigma=None, f_sigma=None,
                sigma_w=1, *, m=0.15, d=0.15, B=0.2):
    """Bayesian physics-informed neural network model.

    A three-layer network with independent ``Normal(0, sigma_w)`` priors on
    every weight and bias maps the two inputs ``(p, t)`` to a scalar ``u``.
    ``mu_grad`` and ``second_grad`` (defined elsewhere) return the network mean
    and its first / second derivatives (presumably with respect to ``t``),
    which enter the residual ``f = m*u'' + d*u' + B*sin(u) - p``.

    Args:
        p, t: the two network inputs; ``p.shape[0]`` sets the plate size and
            ``p`` also enters the residual.
        Y: observed ``u`` values, scored only where ``data_type`` is True.
        F: observed residual values (the forcing term, expected to be 0).
        data_type: boolean mask selecting which points contribute to ``Y``.
        D_H: hidden-layer width.
        u_sigma, f_sigma: fixed observation noise scales; when None, a
            ``Gamma(3, 1)`` prior is placed on the corresponding precision.
        sigma_w: prior scale for every weight and bias.
        m, d, B: residual coefficients (defaults 0.15, 0.15, 0.2).

    Returns:
        ``(u_mu, f_mu)``: network mean and residual mean.
    """
    D_X, D_Y = 2, 1  # two inputs (p, t), one output u

    def normal_prior(name, shape):
        # Zero-mean Normal prior with scale sigma_w on every entry.
        return numpyro.sample(
            name, dist.Normal(jnp.zeros(shape), sigma_w * jnp.ones(shape)))

    w1 = normal_prior("w1", (D_X, D_H))
    b1 = normal_prior("b1", (D_H, 1))
    w2 = normal_prior("w2", (D_H, D_H))
    b2 = normal_prior("b2", (D_H, 1))
    w3 = normal_prior("w3", (D_H, D_Y))
    b3 = normal_prior("b3", (D_Y, 1))

    u_mu, dudt = mu_grad(p, t, w1, b1, w2, b2, w3, b3)
    dudtt = second_grad(p, t, w1, b1, w2, b2, w3, b3)

    # Observation noise: fixed if given, otherwise Gamma prior on precision.
    if u_sigma is None:
        prec_u = numpyro.sample("prec_u", dist.Gamma(3.0, 1.0))
        u_sigma = 1.0 / jnp.sqrt(prec_u)
    if f_sigma is None:
        prec_f = numpyro.sample("prec_f", dist.Gamma(3.0, 1.0))
        f_sigma = 1.0 / jnp.sqrt(prec_f)

    with numpyro.plate('observations', p.shape[0]):
        # Only points flagged by data_type contribute a likelihood for Y.
        with handlers.mask(mask=data_type):
            numpyro.sample("Y", dist.Normal(u_mu, u_sigma), obs=Y)
        # Physics residual, scored at every point (forcing term, always 0).
        f_mu = m * dudtt + d * dudt + B * jnp.sin(u_mu) - p
        numpyro.sample("F", dist.Normal(f_mu, f_sigma), obs=F)
    return u_mu, f_mu
def transition_fn(carry, y):
    """Scan step of a second-order HMM (state depends on the last two states).

    Closes over ``num_sequences``, ``lengths``, ``data_dim``, ``probs_x`` and
    ``probs_y``.

    Args:
        carry: ``(x_prev, x_curr, t)``: the two most recent states and time.
        y: observations for this time step.

    Returns:
        ``((x_curr, x_new, t + 1), None)``: the state window shifted by one.
    """
    x_prev, x_curr, t = carry
    with numpyro.plate("sequences", num_sequences, dim=-2):
        with mask(mask=(t < lengths)[..., None]):
            probs_x_t = Vindex(probs_x)[x_prev, x_curr]
            x_new = numpyro.sample("x", dist.Categorical(probs_x_t))
            with numpyro.plate("tones", data_dim, dim=-1):
                probs_y_t = probs_y[x_new.squeeze(-1)]
                numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
    return (x_curr, x_new, t + 1), None
def transition_fn(carry, y):
    """Scan step of a factorial HMM with independently evolving chains ``w`` and ``x``.

    Emission probabilities are looked up as ``probs_y[w, x, tones]``. Indexing
    on the tensors ``w`` and ``x`` also needs an index tensor for the tones
    dimension, which the ``tones`` plate conveniently provides.

    Closes over ``num_sequences``, ``lengths``, ``data_dim``, ``probs_w``,
    ``probs_x`` and ``probs_y``.
    """
    w_prev, x_prev, t = carry
    with numpyro.plate("sequences", num_sequences, dim=-2):
        with mask(mask=(t < lengths)[..., None]):
            w_dist = dist.Categorical(probs_w[w_prev])
            x_dist = dist.Categorical(probs_x[x_prev])
            w = numpyro.sample("w", w_dist)
            x = numpyro.sample("x", x_dist)
            with numpyro.plate("tones", data_dim, dim=-1) as tones:
                numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
    return (w, x, t + 1), None
def transition_fn(carry, y):
    """Capture-recapture scan step with a fixed survival probability ``phi``.

    Closes over ``N``, ``phi`` and ``rho``. Before an animal's first capture
    ``mu_z_t`` is 1 and its sites are masked out. After that it stays alive
    with probability ``phi`` if it was alive, and is observed with probability
    ``rho`` while alive.

    Args:
        carry: ``(first_capture_mask, z)``: seen-so-far mask and alive state.
        y: capture indicators for this time step.

    Returns:
        ``((first_capture_mask | y, z), None)``.
    """
    seen, z = carry
    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=seen):
            mu_z_t = seen * phi * z + (1 - seen)
            # NumPyro exactly sums out the discrete states z_t.
            z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
            mu_y_t = rho * z
            numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)
    return (seen | y.astype(bool), z), None
def transition_fn(
    carry: Tuple[jnp.ndarray, jnp.ndarray], y: jnp.ndarray
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], None]:
    """One time step function of a first-order HMM, for use with ``scan``.

    Closes over ``batch``, ``lengths``, ``data_dim``, ``probs_x`` and
    ``probs_y`` from the enclosing scope.

    Args:
        carry: ``(x_prev, t)``: previous hidden states and the time index.
        y: observations at this time step.

    Returns:
        ``((x, t + 1), None)``: the updated carry and no per-step output.
    """
    x_prev, t = carry
    with numpyro.plate("sequence", batch, dim=-2):
        # Steps at or beyond a sequence's length contribute nothing.
        with mask(mask=(t < lengths)[..., None]):
            x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
            with numpyro.plate("tones", data_dim, dim=-1):
                numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
    return (x, t + 1), None
def transition_fn(carry, y):
    """Capture-recapture scan step with a per-step random survival logit.

    ``phi_logit`` is sampled under ``LocScaleReparam(0)`` and mapped through
    ``expit``. Closes over ``N``, ``rho``, ``phi_logit_mean``, ``phi_sigma``.

    Returns:
        ``((first_capture_mask | y, z), None)``.
    """
    seen, z = carry
    reparam_config = {"phi_logit": LocScaleReparam(0)}
    with handlers.reparam(config=reparam_config):
        phi_logit_t = numpyro.sample(
            "phi_logit", dist.Normal(phi_logit_mean, phi_sigma)
        )
    phi_t = expit(phi_logit_t)

    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=seen):
            mu_z_t = seen * phi_t * z + (1 - seen)
            # NumPyro exactly sums out the discrete states z_t.
            z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
            mu_y_t = rho * z
            numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)
    return (seen | y.astype(bool), z), None
def transition_fn(carry, y):
    """Capture-recapture scan step with a per-step ``Uniform(0, 1)`` survival prob.

    Closes over ``N`` and ``rho``.

    Returns:
        ``((first_capture_mask | y, z), None)``.
    """
    seen, z = carry
    # phi_t is sampled outside the plate because it is shared by all N animals.
    phi_t = numpyro.sample("phi", dist.Uniform(0.0, 1.0))
    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=seen):
            survive_prob = seen * phi_t * z + (1 - seen)
            # NumPyro exactly sums out the discrete states z_t.
            z = numpyro.sample(
                "z", dist.Bernoulli(dist.util.clamp_probs(survive_prob))
            )
            detect_prob = rho * z
            numpyro.sample(
                "y", dist.Bernoulli(dist.util.clamp_probs(detect_prob)), obs=y
            )
    seen = seen | y.astype(bool)
    return (seen, z), None
def test_get_mask_optimization():
    """``numpyro.get_mask()`` is False under ``mask(mask=False)`` and in ``Predictive``.

    The model and guide skip their optional ``factor`` sites when the mask is
    False. The ``called`` set records which branches actually executed.
    """
    called = set()

    def model():
        with numpyro.handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            numpyro.sample("y", dist.Normal(x, 1), obs=0.)
            called.add("model-always")
            if numpyro.get_mask() is not False:
                called.add("model-sometimes")
                numpyro.factor("f", x + 1)

    def guide():
        with numpyro.handlers.seed(rng_seed=1):
            x = numpyro.sample("x", dist.Normal(0, 1))
            called.add("guide-always")
            if numpyro.get_mask() is not False:
                called.add("guide-sometimes")
                numpyro.factor("g", 2 - x)

    def check_and_reset(expect_optional):
        assert "model-always" in called
        assert "guide-always" in called
        assert ("model-sometimes" in called) is expect_optional
        assert ("guide-sometimes" in called) is expect_optional
        called.clear()

    # Unmasked: the optional branches run.
    guide_trace = handlers.trace(guide).get_trace()
    handlers.replay(model, guide_trace)()
    check_and_reset(True)

    # Fully masked: get_mask() is False, so the optional branches are skipped.
    with handlers.mask(mask=False):
        guide_trace = handlers.trace(guide).get_trace()
        handlers.replay(model, guide_trace)()
    check_and_reset(False)

    # Predictive also skips them.
    Predictive(model, guide=guide, num_samples=2, parallel=True)(random.PRNGKey(2))
    check_and_reset(False)
def transition_fn(carry, y):
    """Capture-recapture scan step with a per-step random effect on survival.

    The survival probability is ``expit(phi_beta + phi_gamma_t)``. The ``z``
    site is marked for parallel enumeration. Closes over ``N``, ``rho`` and
    ``phi_beta``.

    Returns:
        ``((first_capture_mask | y, z), None)``.
    """
    seen, z = carry
    phi_gamma_t = numpyro.sample("phi_gamma", dist.Normal(0.0, 10.0))
    phi_t = expit(phi_beta + phi_gamma_t)
    enum_parallel = {"enumerate": "parallel"}
    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=seen):
            mu_z_t = seen * phi_t * z + (1 - seen)
            # NumPyro exactly sums out the discrete states z_t.
            z = numpyro.sample(
                "z",
                dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                infer=enum_parallel,
            )
            mu_y_t = rho * z
            numpyro.sample(
                "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
            )
    return (seen | y.astype(bool), z), None
def _plate_size(x, y, data_type):
    """Return the trailing-dimension size shared by ``x``, ``y`` and ``data_type``.

    ``y`` may be None (e.g. when sampling rather than conditioning) and is then
    ignored. If everything is scalar the plate has size 1. Assumes the data
    are 1-D (or batched along the last axis); TODO confirm for other shapes.
    """
    import numpy as np

    shapes = [np.shape(x), np.shape(data_type)]
    if y is not None:
        shapes.append(np.shape(y))
    broadcast = np.broadcast_shapes(*shapes)
    return broadcast[-1] if broadcast else 1


def masked_model(x, y, data_type):
    """Normal(x, 1) likelihood for ``y`` that is scored only where ``data_type`` is True.

    ``numpyro.plate`` requires an explicit size. The original call
    ``numpyro.plate('data')`` raised TypeError, so the size is now inferred
    from the inputs.

    Returns:
        The ``Y`` site value (``y`` when observed, otherwise a sample).
    """
    with numpyro.plate('data', _plate_size(x, y, data_type)):
        with handlers.mask(mask=data_type):
            Y = numpyro.sample("Y", dist.Normal(x, 1.), obs=y)
    return Y
def model():
    """Model whose only term is a ``-inf`` factor, hidden by an all-False mask of length 10."""
    keep = jnp.zeros(10, dtype=bool)  # every element masked out
    with handlers.mask(mask=keep):
        numpyro.factor('inf', -jnp.inf)
def model(z=None):
    """Categorical ``z`` (learnable probabilities ``p``) with Normal(z, 1) observations.

    ``logger``, ``mask`` and ``data`` come from the enclosing scope; the data
    plate has a fixed size of 3.

    Args:
        z: optional observed value for the ``z`` site.
    """
    p = numpyro.param("p", np.array([0.75, 0.25]))
    z = numpyro.sample("z", dist.Categorical(p), obs=z)
    # Lazy %-style argument: the message is only formatted if the record is emitted.
    logger.info("z.shape = %s", z.shape)
    with numpyro.plate("data", 3), handlers.mask(mask=mask):
        numpyro.sample("x", dist.Normal(z, 1.0), obs=data)