def semi_supervised_hmm(transition_prior, emission_prior, supervised_categories,
                        supervised_words, unsupervised_words):
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    transition_prob = numpyro.sample(
        'transition_prob',
        dist.Dirichlet(
            np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = numpyro.sample(
        'emission_prob',
        dist.Dirichlet(
            np.broadcast_to(emission_prior, (num_categories, num_words))))

    # Model the supervised data. We make no assumption about the first supervised
    # category, i.e. we place a flat/uniform prior on it.
    numpyro.sample('supervised_categories',
                   dist.Categorical(transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # Compute the log probability of the unsupervised data.
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # Inject log_prob into the potential function.
    numpyro.factor('forward_log_prob', log_prob)
def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones(num_classes)))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(Vindex(beta)[positions, c, :]),
                           obs=annotations)
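# A minimal usage sketch (an assumption, not part of the original source): fitting
# `dawid_skene` with NUTS. The discrete site `c` is marked for parallel enumeration,
# so NUTS can marginalize it out automatically (this requires funsor to be
# installed). The toy `positions`/`annotations` arrays below are purely illustrative.
import numpy as onp
from jax import random
from numpyro.infer import MCMC, NUTS

positions = onp.array([0, 1, 2])  # annotator id occupying each position
annotations = onp.random.RandomState(0).randint(0, 3, size=(30, 3))
mcmc = MCMC(NUTS(dawid_skene), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), positions, annotations)
mcmc.print_summary()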
def nondynamic_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape
    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
    assert weights.shape == (M,)

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 5.))
    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))
    U = jnp.log(p0)

    def transition_fn(carry, t):
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))
        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
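# A hedged inference sketch (an assumption, not part of the original source) for the
# mixture models in this file. It assumes `beliefs`, `y`, and `mask` have been
# prepared as elsewhere in this codebase, and that the `logits` helper used inside
# the model is importable; all latent sites are continuous, so plain NUTS applies.
from jax import random
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(nondynamic_mixture_model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), beliefs, y, mask)
samples = mcmc.get_samples()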
def model_3(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
def semi_supervised_hmm(transition_prior, emission_prior, supervised_categories,
                        supervised_words, unsupervised_words):
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    transition_prob = sample('transition_prob', dist.Dirichlet(
        np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = sample('emission_prob', dist.Dirichlet(
        np.broadcast_to(emission_prior, (num_categories, num_words))))

    # Model the supervised data. We make no assumption about the first supervised
    # category, i.e. we place a flat/uniform prior on it.
    sample('supervised_categories',
           dist.Categorical(transition_prob[supervised_categories[:-1]]),
           obs=supervised_categories[1:])
    sample('supervised_words',
           dist.Categorical(emission_prob[supervised_categories]),
           obs=supervised_words)

    # Compute the log probability of the unsupervised data.
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # Inject log_prob into the potential function.
    # NB: This is a trick (it uses the invalid value `0` of a Multinomial
    # distribution) to add an extra term to the potential energy.
    return sample('forward_log_prob', dist.Multinomial(logits=-log_prob), obs=0)
def simulate_data(rng_key, num_categories, num_words, num_supervised_data,
                  num_unsupervised_data):
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)
    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=rng_key_transition, sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=rng_key_emission, sample_shape=(num_categories,))
    start_prob = np.repeat(1. / num_categories, num_categories)

    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)

    # Split into supervised data and unsupervised data.
    categories, words = np.stack(categories), np.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words = words[:num_supervised_data]
    unsupervised_words = words[num_supervised_data:]
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            supervised_categories, supervised_words, unsupervised_words)
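# An end-to-end sketch (an assumption, not part of the original source): simulate a
# dataset and fit the semi-supervised HMM above with NUTS. The `forward_log_prob`
# helper used by the model is assumed to be defined elsewhere in this codebase.
from jax import random
from numpyro.infer import MCMC, NUTS

rng_key = random.PRNGKey(0)
(transition_prior, emission_prior, _, _, supervised_categories, supervised_words,
 unsupervised_words) = simulate_data(rng_key, num_categories=3, num_words=10,
                                     num_supervised_data=100,
                                     num_unsupervised_data=500)
mcmc = MCMC(NUTS(semi_supervised_hmm), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(1), transition_prior, emission_prior,
         supervised_categories, supervised_words, unsupervised_words)
mcmc.print_summary()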
def model_4(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with mask(mask=include_prior):
        probs_w = numpyro.sample(
            "probs_w", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )

    def transition_fn(carry, y):
        w_prev, x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                w = numpyro.sample("w", dist.Categorical(probs_w[w_prev]))
                x = numpyro.sample("x", dist.Categorical(Vindex(probs_x)[w, x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1) as tones:
                    numpyro.sample("y", dist.Bernoulli(probs_y[w, x, tones]), obs=y)
        return (w, x, t + 1), None

    w_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))
def model(data):
    with numpyro.plate("states", dim):
        transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
        emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
        emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

    trans_prob = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    for t, y in markov(enumerate(data)):
        x = numpyro.sample("x_{}".format(t), dist.Categorical(trans_prob))
        numpyro.sample("y_{}".format(t),
                       dist.Normal(emission_loc[x], emission_scale[x]), obs=y)
        trans_prob = transition[x]
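# A hedged usage sketch (an assumption, not part of the original source). `markov`
# here is assumed to come from numpyro.contrib.funsor, so the discrete chain
# x_0, x_1, ... can be enumerated out with bounded memory; `dim` is assumed to be a
# module-level constant giving the number of hidden states. With funsor installed,
# NUTS marginalizes the discrete sites automatically.
from jax import random
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

data = jnp.sin(jnp.arange(50.0) / 5.0)  # toy 1-D observations
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), data)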
def simulate_data(
    rng_key: np.ndarray,
    num_categories: int,
    num_words: int,
    num_supervised: int,
    num_unsupervised: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)
    transition_prob = dist.Dirichlet(transition_prior).sample(
        rng_key_transition, sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(
        rng_key_emission, sample_shape=(num_categories,))
    start_prob = jnp.repeat(1.0 / num_categories, num_categories)

    category = 0
    categories = []
    words = []
    for t in range(num_supervised + num_unsupervised):
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        if t == 0 or t == num_supervised:
            category = dist.Categorical(start_prob).sample(rng_key_transition)
        else:
            category = dist.Categorical(
                transition_prob[category]).sample(rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(rng_key_emission)
        categories.append(category)
        words.append(word)

    # Split data into supervised and unsupervised parts.
    categories = jnp.stack(categories)
    words = jnp.stack(words)
    supervised_categories = categories[:num_supervised]
    supervised_words = words[:num_supervised]
    unsupervised_words = words[num_supervised:]
    return (
        transition_prob,
        emission_prob,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
def unsupervised_hmm(words):
    with numpyro.plate("prob_plate", num_categories):
        transition_prob = numpyro.sample("transition_prob",
                                         dist.Dirichlet(transition_prior))
        emission_prob = numpyro.sample("emission_prob",
                                       dist.Dirichlet(emission_prior))

    transition_log_prob = jnp.log(transition_prob)
    emission_log_prob = jnp.log(emission_prob)
    log_prob = emission_log_prob[:, words[0]]
    for t in range(1, len(words)):
        log_prob = forward_log_prob(log_prob, words[t], transition_log_prob,
                                    emission_log_prob)
    prob = jnp.exp(logsumexp(log_prob, 0))
    # A trick to inject an additional log_prob into the model's log_prob.
    numpyro.sample("forward_prob", dist.Bernoulli(prob), obs=jnp.array(1.))
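# A hedged side note (not part of the original source): the Bernoulli trick above
# relies on `prob` lying in [0, 1]. In current NumPyro the same effect can be had
# more directly with a factor site, e.g.:
#
#     numpyro.factor("forward_prob", logsumexp(log_prob, 0))
#
# which adds the marginal log-likelihood of `words` straight to the model density.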
def gammadyn_mixture_model(beliefs, y, mask):
    M, T, _ = beliefs[0].shape
    tau = .5
    weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
    assert weights.shape == (M,)

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 2.))
    mu = jnp.log(jnp.exp(gamma) - 1)

    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))

    scale = npyro.sample('scale', dist.Gamma(1., 1.))
    rho = npyro.sample('rho', dist.Beta(1., 2.))
    sigma = jnp.sqrt(-(1 - rho**2) / (2 * jnp.log(rho))) * scale
    U = jnp.log(p0)

    def transition_fn(carry, t):
        x_prev = carry
        gamma_dyn = npyro.deterministic('gamma_dyn', nn.softplus(mu + x_prev))
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma_dyn, -1), jnp.expand_dims(U, -2))
        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs).mask(mask[t])
        npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
        with npyro.handlers.reparam(
                config={"x_next": npyro.infer.reparam.TransformReparam()}):
            affine = dist.transforms.AffineTransform(rho * x_prev, sigma)
            x_next = npyro.sample(
                'x_next', dist.TransformedDistribution(dist.Normal(0., 1.), affine))
        return (x_next), None

    x0 = jnp.zeros(1)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (x0), jnp.arange(T))
def item_difficulty(annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        eta = numpyro.sample(
            "eta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi),
                           infer={"enumerate": "parallel"})

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(eta[c], chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(logits=theta), obs=annotations)
def logistic_random_effects(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 5 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
        chi = numpyro.sample(
            "Chi", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
                beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with handlers.reparam(config={"theta": LocScaleReparam(0)}):
            theta = numpyro.sample("theta", dist.Normal(0, chi[c]).to_event(1))
            theta = jnp.pad(theta, [(0, 0)] * (jnp.ndim(theta) - 1) + [(0, 1)])

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :] - theta
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def hierarchical_dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of the `y` likelihood; but `logits` is
        # invariant up to a constant, so we follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = numpyro.sample(
            "zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = numpyro.sample(
            "Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with numpyro.plate("annotator", num_annotators, dim=-2):
        with numpyro.plate("class", num_classes):
            # non-centered parameterization
            with handlers.reparam(config={"beta": LocScaleReparam(0)}):
                beta = numpyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last item
            beta = jnp.pad(beta, [(0, 0)] * (jnp.ndim(beta) - 1) + [(0, 1)])

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            numpyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
def model_1(sequences, lengths, args, include_prior=True):
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        return (x, t + 1), None

    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move the time dimension of `sequences` to the front to scan over it.
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
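# A minimal usage sketch (an assumption, not part of the original source) for the
# enumerated HMM variants in this file. `args` is assumed to be a namespace with a
# `hidden_dim` attribute; with funsor installed, NUTS enumerates the discrete sites
# out automatically. The all-zeros toy batch below is purely illustrative.
from argparse import Namespace
from jax import random
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

args = Namespace(hidden_dim=16)
sequences = jnp.zeros((10, 20, 51))   # (num_sequences, max_length, data_dim)
lengths = jnp.full((10,), 20)
mcmc = MCMC(NUTS(model_1), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), sequences, lengths, args)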
def fulldyn_mixture_model(beliefs, y, mask):
    M, T, N, _ = beliefs[0].shape
    c0 = beliefs[-1]
    tau = .5

    with npyro.plate('N', N):
        weights = npyro.sample('weights', dist.Dirichlet(tau * jnp.ones(M)))
        assert weights.shape == (N, M)

        mu = npyro.sample('mu', dist.Normal(5., 5.))
        lam12 = npyro.sample('lam12', dist.HalfCauchy(1.).expand([2]).to_event(1))
        lam34 = npyro.sample('lam34', dist.HalfCauchy(1.))

        _lam34 = jnp.expand_dims(lam34, -1)
        lam0 = npyro.deterministic(
            'lam0', jnp.concatenate([lam12.cumsum(-1), _lam34, _lam34], -1))

        eta = npyro.sample('eta', dist.Beta(1, 10))
        scale = npyro.sample('scale', dist.HalfNormal(1.))
        theta = npyro.sample('theta', dist.HalfCauchy(5.))

    rho = jnp.exp(-theta)
    sigma = jnp.sqrt((1 - rho**2) / (2 * theta)) * scale
    x0 = jnp.zeros(N)

    def transition_fn(carry, t):
        lam_prev, x_prev = carry
        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))
        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))
        lam_next = npyro.deterministic(
            'lams',
            lam_prev + nn.one_hot(beliefs[2][t], 4) *
            jnp.expand_dims(mask[t] * eta, -1))
        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample('y',
                             dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))
        x_next = rho * x_prev + sigma * noise
        return (lam_next, x_next), None

    lam_start = npyro.deterministic('lam_start',
                                    lam0 + jnp.expand_dims(eta, -1) * c0)
    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, (lam_start, x0), jnp.arange(T))
def nondyn_single_model(beliefs, y, mask):
    T, _ = beliefs[0].shape

    gamma = npyro.sample('gamma', dist.InverseGamma(2., 3.))
    p = npyro.sample('p', dist.Dirichlet(jnp.ones(3)))
    p0 = npyro.deterministic(
        'p0',
        jnp.concatenate([
            p[..., :1] / 2, p[..., :1] / 2 + p[..., 1:2], p[..., 2:] / 2,
            p[..., 2:] / 2
        ], -1))
    U = jnp.log(p0)

    def transition_fn(carry, t):
        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))
        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        return None, None

    with npyro.handlers.condition(data={"y": y}):
        scan(transition_fn, None, jnp.arange(T))
def model_6(sequences, lengths, args, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape

    with mask(mask=include_prior):
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1).expand(
                [args.hidden_dim, args.hidden_dim]).to_event(2),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        x_prev, x_curr, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, numpyro.sample(
                    "x", dist.Categorical(probs_x_t))
                with numpyro.plate("tones", data_dim, dim=-1):
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    numpyro.sample("y", dist.Bernoulli(probs_y_t), obs=y)
        return (x_prev, x_curr, t + 1), None

    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (x_prev, x_curr, 0), jnp.swapaxes(sequences, 0, 1),
         history=2)
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of reference [1].
    """
    num_annotators = int(np.max(positions)) + 1
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("annotator", num_annotators):
        epsilon = numpyro.sample("epsilon",
                                 dist.Dirichlet(jnp.full(num_classes, 10)))
        theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample(
            "c",
            dist.DiscreteUniform(0, num_classes - 1),
            infer={"enumerate": "parallel"},
        )

        with numpyro.plate("position", num_positions):
            s = numpyro.sample(
                "s",
                dist.Bernoulli(1 - theta[positions]),
                infer={"enumerate": "parallel"},
            )
            probs = jnp.where(s[..., None] == 0, nn.one_hot(c, num_classes),
                              epsilon[positions])
            numpyro.sample("y", dist.Categorical(probs), obs=annotations)
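# A hedged sketch (an assumption, not part of the original source): after fitting
# `mace` with NUTS, posterior samples of the enumerated discrete sites (the true
# classes `c` and the spamming indicators `s`) can be drawn with
# Predictive(..., infer_discrete=True). `positions` and `annotations` are assumed
# to be prepared as for the other annotation models above.
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

mcmc = MCMC(NUTS(mace), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), positions, annotations)
predictive = Predictive(mace, mcmc.get_samples(), infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), positions, annotations)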
def gmm(data):
    mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("num_clusters", K, dim=-1):
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.))
    with numpyro.plate("data", data.shape[0], dim=-1):
        assignments = numpyro.sample("assignments",
                                     dist.Categorical(mix_proportions))
        numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.), obs=data)
def multinomial(annotations):
    """
    This model corresponds to the plate diagram in Figure 1 of reference [1].
    """
    num_classes = int(np.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with numpyro.plate("class", num_classes):
        zeta = numpyro.sample("zeta", dist.Dirichlet(jnp.ones(num_classes)))

    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("item", num_items, dim=-2):
        c = numpyro.sample("c", dist.Categorical(pi))

        with numpyro.plate("position", num_positions):
            numpyro.sample("y", dist.Categorical(zeta[c]), obs=annotations)
def create_model(yT, yC, num_components):
    # Constants
    nC = yC.shape[0]
    nT = yT.shape[0]
    zC = jnp.isinf(yC).sum().item()
    zT = jnp.isinf(yT).sum().item()
    yT_finite = yT[~jnp.isinf(yT)]
    yC_finite = yC[~jnp.isinf(yC)]
    K = num_components

    p = numpyro.sample('p', dist.Beta(.5, .5))
    gammaC = numpyro.sample('gammaC', dist.Beta(1, 1))
    gammaT = numpyro.sample('gammaT', dist.Beta(1, 1))
    etaC = numpyro.sample('etaC', dist.Dirichlet(jnp.ones(K) / K))
    etaT = numpyro.sample('etaT', dist.Dirichlet(jnp.ones(K) / K))

    with numpyro.plate('mixture_components', K):
        nu = numpyro.sample('nu', dist.LogNormal(3.5, 0.5))
        mu = numpyro.sample('mu', dist.Normal(0, 3))
        sigma = numpyro.sample('sigma', dist.LogNormal(0, .5))
        phi = numpyro.sample('phi', dist.Normal(0, 3))

    gammaT_star = simulate_data.compute_gamma_T_star(gammaC, gammaT, p)
    etaT_star = simulate_data.compute_eta_T_star(etaC, etaT, p, gammaC, gammaT,
                                                 gammaT_star)

    with numpyro.plate('y_C', nC - zC):
        numpyro.sample('finite_obs_C',
                       Mix(nu[None, :], mu[None, :], sigma[None, :],
                           phi[None, :], etaC[None, :]),
                       obs=yC_finite[:, None])

    with numpyro.plate('y_T', nT - zT):
        numpyro.sample('finite_obs_T',
                       Mix(nu[None, :], mu[None, :], sigma[None, :],
                           phi[None, :], etaT_star[None, :]),
                       obs=yT_finite[:, None])

    numpyro.sample('N_C', dist.Binomial(nC, gammaC), obs=zC)
    numpyro.sample('N_T', dist.Binomial(nT, gammaT_star), obs=zT)
def gmm(data, num_components=3):
    mus = numpyro.sample(
        'mus',
        dist.Normal(jnp.zeros(num_components),
                    jnp.ones(num_components) * 100.).to_event(1))
    sigmas = numpyro.sample(
        'sigmas', dist.HalfNormal(jnp.ones(num_components) * 100.).to_event(1))
    mixture_probs = numpyro.sample(
        'mixture_probs',
        dist.Dirichlet(jnp.ones(num_components) / num_components))
    with numpyro.plate('data', len(data), dim=-1):
        z = numpyro.sample('z', dist.Categorical(mixture_probs))
        numpyro.sample('ll', dist.Normal(mus[z], sigmas[z]), obs=data)
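# A minimal inference sketch (an assumption, not part of the original source): the
# assignment site `z` is discrete and not marked for enumeration, so one option is
# DiscreteHMCGibbs, which alternates Gibbs updates of `z` with NUTS updates of the
# continuous parameters. The toy two-cluster data below is purely illustrative.
from jax import random
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

data = jnp.concatenate([
    dist.Normal(-2., 1.).sample(random.PRNGKey(0), (50,)),
    dist.Normal(3., 1.).sample(random.PRNGKey(1), (50,)),
])
mcmc = MCMC(DiscreteHMCGibbs(NUTS(gmm)), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(2), data)
mcmc.print_summary()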
def simplex_proposal(self, rng_key, x, grad_, hess_):
    max_non_diag_hess = np.max(
        hess_[np.logical_not(np.eye(hess_.shape[0], dtype=bool))].reshape(
            hess_.shape[0], -1),
        axis=1)
    concentration = 1 - x**2 * (np.diag(hess_) - max_non_diag_hess)
    dist_ = dist.Dirichlet(concentration=concentration + W_CORRECTION)
    return dist_.sample(rng_key).reshape(x.shape), dist_
def model():
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(a[..., None], jnp.ones(3)).to_event(1))
    c = numpyro.sample(
        "c", dist.MultivariateNormal(jnp.zeros(3) + a[..., None], jnp.eye(3)))
    with numpyro.plate("i", 2):
        d = numpyro.sample("d", dist.Dirichlet(jnp.exp(b + c)))
        numpyro.sample("e", dist.Categorical(logits=d), obs=jnp.array([0, 0]))
    return a, b, c, d
def guide(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        encoder = flax_module(
            "encoder",
            FlaxEncoder(
                hyperparams["vocab_size"],
                hyperparams["num_topics"],
                hyperparams["hidden"],
                hyperparams["dropout_rate"],
            ),
            input_shape=(1, hyperparams["vocab_size"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        encoder = haiku_module(
            "encoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuEncoder(
                    hyperparams["vocab_size"],
                    hyperparams["num_topics"],
                    hyperparams["hidden"],
                    hyperparams["dropout_rate"],
                )),
            input_shape=(1, hyperparams["vocab_size"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)

        if nn_framework == "flax":
            concentration = encoder(batch_docs, is_training,
                                    rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            concentration = encoder(numpyro.prng_key(), batch_docs, is_training)

        numpyro.sample("theta", dist.Dirichlet(concentration))
def semi_supervised_hmm(
    num_categories: int,
    num_words: int,
    supervised_categories: jnp.ndarray,
    supervised_words: jnp.ndarray,
    unsupervised_words: jnp.ndarray,
) -> None:
    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)
    transition_prob = numpyro.sample(
        "transition_prob",
        dist.Dirichlet(
            jnp.broadcast_to(transition_prior, (num_categories, num_categories))),
    )
    emission_prob = numpyro.sample(
        "emission_prob",
        dist.Dirichlet(
            jnp.broadcast_to(emission_prior, (num_categories, num_words))),
    )

    numpyro.sample(
        "supervised_categories",
        dist.Categorical(transition_prob[supervised_categories[:-1]]),
        obs=supervised_categories[1:],
    )
    numpyro.sample(
        "supervised_words",
        dist.Categorical(emission_prob[supervised_categories]),
        obs=supervised_words,
    )

    transition_log_prob = jnp.log(transition_prob)
    emission_log_prob = jnp.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    numpyro.factor("forward_log_prob", log_prob)
def model(docs, hyperparams, is_training=False, nn_framework="flax"):
    if nn_framework == "flax":
        decoder = flax_module(
            "decoder",
            FlaxDecoder(hyperparams["vocab_size"], hyperparams["dropout_rate"]),
            input_shape=(1, hyperparams["num_topics"]),
            # ensure PRNGKey is made available to dropout layers
            apply_rng=["dropout"],
            # indicate mutable state due to BatchNorm layers
            mutable=["batch_stats"],
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    elif nn_framework == "haiku":
        decoder = haiku_module(
            "decoder",
            # use `transform_with_state` for BatchNorm
            hk.transform_with_state(
                HaikuDecoder(hyperparams["vocab_size"],
                             hyperparams["dropout_rate"])),
            input_shape=(1, hyperparams["num_topics"]),
            apply_rng=True,
            # to ensure proper initialisation of BatchNorm we must
            # initialise with is_training=True
            is_training=True,
        )
    else:
        raise ValueError(f"Invalid choice {nn_framework} for argument nn_framework")

    with numpyro.plate("documents", docs.shape[0],
                       subsample_size=hyperparams["batch_size"]):
        batch_docs = numpyro.subsample(docs, event_dim=1)
        theta = numpyro.sample(
            "theta", dist.Dirichlet(jnp.ones(hyperparams["num_topics"])))

        if nn_framework == "flax":
            logits = decoder(theta, is_training,
                             rngs={"dropout": numpyro.prng_key()})
        elif nn_framework == "haiku":
            logits = decoder(numpyro.prng_key(), theta, is_training)

        total_count = batch_docs.sum(-1)
        numpyro.sample("obs", dist.Multinomial(total_count, logits=logits),
                       obs=batch_docs)
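# A minimal training sketch (an assumption, not part of the original source) pairing
# the `model` and `guide` above with SVI. `docs` is assumed to be a dense
# document-term count matrix; the hyperparameter values are illustrative only.
from jax import random
import numpyro
from numpyro.infer import SVI, TraceMeanField_ELBO

hyperparams = {
    "vocab_size": docs.shape[1],
    "num_topics": 20,
    "hidden": 100,
    "dropout_rate": 0.2,
    "batch_size": 32,
}
svi = SVI(model, guide, numpyro.optim.Adam(step_size=1e-3), TraceMeanField_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10_000, docs, hyperparams)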
def model_1(
    sequences: Optional[np.ndarray] = None,
    lengths: Optional[np.ndarray] = None,
    hidden_dim: int = 16,
    batch: int = 100,
    seq_len: int = 0,
    data_dim: int = 10,
    future_steps: int = 0,
) -> None:
    if sequences is not None:
        assert lengths is not None
        batch, seq_len, data_dim = sequences.shape
        future = np.zeros((batch, future_steps, data_dim))
        sequences = np.concatenate([sequences, future], axis=1)
    else:
        lengths = np.zeros((batch,))

    probs_x = numpyro.sample(
        "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1))
    probs_y = numpyro.sample(
        "probs_y", dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2))

    def transition_fn(
        carry: Tuple[jnp.ndarray, jnp.ndarray], y: jnp.ndarray
    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """One time-step function."""
        x_prev, t = carry
        with numpyro.plate("sequence", batch, dim=-2):
            with mask(mask=(t < lengths)[..., None]):
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                with numpyro.plate("tones", data_dim, dim=-1):
                    numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                   obs=y)
        return (x, t + 1), None

    x_init = jnp.zeros((batch, 1), dtype=jnp.int32)
    if sequences is not None:
        # Scan over the time dimension: data shape = (seq, batch, data_dim).
        scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
    else:
        scan(transition_fn, (x_init, 0), None, length=seq_len + future_steps)
def guide(k, obs=None, num_obs_total=None, d=None):
    # Guide for the latent mixture-of-Gaussians; its parameters are learned below.
    if obs is not None:
        assert jnp.ndim(obs) == 2
        _, d = jnp.shape(obs)
    else:
        assert num_obs_total is not None
        assert d is not None

    alpha_log = param('alpha_log', jnp.zeros(k))
    alpha = jnp.exp(alpha_log)
    pis = sample('pis', dist.Dirichlet(alpha))

    mus_loc = param('mus_loc', jnp.zeros((k, d)))
    mus = sample('mus', dist.Normal(mus_loc, 1.))
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs
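# A hedged usage sketch (an assumption, not part of the original source): pairing
# this guide with a matching `model` (assumed to define the same 'pis', 'mus', and
# 'sigs' sites and to accept the same arguments) in SVI; `data` is an assumed
# 2-D observation array.
from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO

svi = SVI(model, guide, numpyro.optim.Adam(step_size=1e-3), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, k=3, obs=data)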