def model(z=None) -> None:
    """Two-dimensional Gaussian model with unknown mean and scale vectors.

    ``mu`` is drawn from ``dists.Normal()`` expanded to a length-2 event and
    ``sigma`` from ``dists.InverseGamma(1.)`` expanded the same way; site
    ``'x'`` is then ``Normal(mu, sigma)`` (event dim 1) inside a ``'batch'``
    plate, observed as ``z`` when data is supplied.

    :param z: optional data array; its leading dimension is the batch size.
        When None, the plate has size 1.
    """
    n = z.shape[0] if z is not None else 1

    event_shape = (2, )
    mu = sample('mu', dists.Normal().expand_by(event_shape).to_event(1))
    sigma = sample(
        'sigma', dists.InverseGamma(1.).expand_by(event_shape).to_event(1)
    )

    # Total size and subsample size are both n: the plate covers exactly
    # the rows that are present.
    with plate('batch', n, n):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def model(z=None, num_obs_total=None) -> None:
    """Gaussian model of length-``d`` vectors with minibatch support.

    Depends on the module-level names ``d`` (length of the ``mu``/``sigma``
    vectors) and ``args.prior_mu`` (first argument of the ``mu`` prior);
    both must exist when the model is called.

    :param z: optional observed minibatch; its leading dimension is the
        batch size. When None, the batch size is 1.
    :param num_obs_total: total number of observations in the full data
        set, used as the plate size; defaults to the batch size.
    """
    batch_size = 1 if z is None else z.shape[0]
    num_obs_total = batch_size if num_obs_total is None else num_obs_total

    mu_prior = dists.Normal(args.prior_mu).expand_by((d,)).to_event(1)
    mu = sample('mu', mu_prior)
    sigma_prior = dists.InverseGamma(1.).expand_by((d,)).to_event(1)
    sigma = sample('sigma', sigma_prior)

    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def model(x_first=None, x_second=None, num_obs_total=None) -> None:
    """Scalar Gaussian model observed through two sites sharing mu/sigma.

    ``mu`` is drawn from ``dists.Normal()`` and ``sigma`` from
    ``dists.InverseGamma(1.)``; sites ``'x_first'`` and ``'x_second'`` are
    both ``Normal(mu, sigma)`` inside the same ``'batch'`` plate.

    :param x_first: optional observations for site ``'x_first'``.
    :param x_second: optional observations for site ``'x_second'``.
    :param num_obs_total: total number of observations in the full data
        set, used as the plate size; defaults to the batch size.

    The batch size is the leading dimension of ``x_first`` if given,
    otherwise of ``x_second`` if given, otherwise 1.
    """
    if x_first is not None:
        batch_size = x_first.shape[0]
    elif x_second is not None:
        # Without this fallback, supplying only x_second would leave the
        # plate at size 1 regardless of how many rows x_second has.
        batch_size = x_second.shape[0]
    else:
        batch_size = 1
    if num_obs_total is None:
        num_obs_total = batch_size

    mu = sample('mu', dists.Normal())
    sigma = sample('sigma', dists.InverseGamma(1.))

    with plate('batch', num_obs_total, batch_size):
        sample('x_first', dists.Normal(mu, sigma), obs=x_first)
        sample('x_second', dists.Normal(mu, sigma), obs=x_second)
def model(z=None, z2=None, num_obs_total=None) -> None:
    """Two-dimensional Gaussian model taking an extra, unobserved ``z2``.

    ``mu`` and ``sigma`` are length-2 vectors with ``dists.Normal()`` and
    ``dists.InverseGamma(1.)`` priors; site ``'x'`` is
    ``Normal(mu, sigma)`` (event dim 1) observed as ``z``.

    :param z: optional observed minibatch; its leading dimension is the
        batch size. When None, the batch size is 1.
    :param z2: only checked to have the same leading dimension as ``z``
        when ``z`` is given (it must then not be None); never observed.
    :param num_obs_total: total number of observations in the full data
        set, used as the plate size; defaults to the batch size.
    """
    if z is None:
        batch_size = 1
    else:
        batch_size = z.shape[0]
        assert z.shape is not None
        assert z2.shape[0] == batch_size
    if num_obs_total is None:
        num_obs_total = batch_size

    two_vector = (2, )
    mu = sample('mu', dists.Normal().expand_by(two_vector).to_event(1))
    sigma = sample(
        'sigma', dists.InverseGamma(1.).expand_by(two_vector).to_event(1)
    )

    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
def model(k, obs=None, num_obs_total=None, d=None):
    """Gaussian mixture model with ``k`` components and priors on its parameters.

    Samples mixture weights ``pis`` (``Dirichlet(ones(k))``), component
    means ``mus`` (``Normal(0, 10)`` with shape ``(k, d)``) and one
    ``InverseGamma(1., 1.)`` scale per entry of ``mus``, then draws
    ``'obs'`` from ``GaussianMixture(mus, sigs, pis)``.

    :param k: number of mixture components.
    :param obs: optional 2-D observations of shape ``(batch_size, d)``;
        when given, ``batch_size`` and ``d`` are read from its shape.
    :param num_obs_total: total number of observations (plate size).
        Required when ``obs`` is None, in which case it is also the number
        of instances generated.
    :param d: dimensionality; required, and only used, when ``obs`` is None.
    :returns: the value of the ``'obs'`` sample site.
    """
    if obs is None:
        assert num_obs_total is not None
        batch_size = num_obs_total
        assert d is not None
    else:
        assert jnp.ndim(obs) == 2
        batch_size, d = jnp.shape(obs)
    if num_obs_total is None:
        num_obs_total = batch_size

    pis = sample('pis', dist.Dirichlet(jnp.ones(k)))
    mus = sample('mus', dist.Normal(jnp.zeros((k, d)), 10.))
    # One scale per entry of mus.
    sigs = sample('sigs', dist.InverseGamma(1., 1.),
                  sample_shape=jnp.shape(mus))

    with plate('batch', num_obs_total, batch_size):
        return sample('obs', GaussianMixture(mus, sigs, pis), obs=obs,
                      sample_shape=(batch_size,))
def model(batch_X, batch_y=None, num_obs_total=None):
    """Logistic regression model: p(y | z, X) p(z).

    Weights ``w`` have an independent ``Normal(0, 1)`` prior per coordinate
    and ``intercept`` a ``Normal(0, 1)`` prior; labels are Bernoulli with
    logits ``batch_X.dot(w) + intercept``.

    :param batch_X: 2-D array of predictors, one row per instance.
    :param batch_y: optional observed labels for the rows of ``batch_X``.
    :param num_obs_total: total number of instances in the full data set
        (plate size); defaults to the number of rows in ``batch_X``.
    :returns: the value of the ``'obs'`` sample site.
    """
    assert jnp.ndim(batch_X) == 2
    batch_size, d = jnp.shape(batch_X)
    if num_obs_total is None:
        num_obs_total = batch_size
    assert batch_y is None or example_count(batch_y) == batch_size

    weights = sample('w', dist.Normal(jnp.zeros((d, )), jnp.ones((d, ))))  # N(0, I)
    intercept = sample('intercept', dist.Normal(0, 1))  # N(0, 1)
    logits = batch_X.dot(weights) + intercept

    with plate("batch", num_obs_total, batch_size):
        return sample('obs', dist.Bernoulli(logits=logits), obs=batch_y)
def model(x=None, num_obs_total=None):
    """Model of two features: Leukocytes and a Rhinovirus/Enterovirus test.

    Leukocytes is ``Normal(mu, sig)`` (priors ``Normal(0, 1)`` and
    ``Gamma(2, 2)``); the virus test is ``Bernoulli(logits=logit)`` (prior
    ``Normal(0, 1)``). Each is wrapped in ``NAModel`` with its own
    ``Beta(1, 1)``-distributed ``*_na_prob``.

    Args:
        x (jax.numpy.array): Optional 2-D array of data instances, one row
            per instance. The two features are extracted with
            ``get_feature(x, 0)`` and ``get_feature(x, 1)``. When None, a
            single instance is generated.
        num_obs_total (int): Number of total instances in the data set
            (plate size); defaults to the number of rows in ``x``.

    Returns:
        The result of ``sample_combined`` applied to the values of the
        ``'Leukocytes'`` and ``'Rhinovirus/Enterovirus'`` sites.
    """
    assert x is None or len(np.shape(x)) == 2
    N = 1 if x is None else np.shape(x)[0]
    if num_obs_total is None:
        num_obs_total = N
    assert isinstance(num_obs_total, int) and num_obs_total > 0
    assert N <= num_obs_total

    leuko_mus = sample('Leukocytes_mus', dist.Normal(0., 1.))
    leuko_sig = sample('Leukocytes_sig', dist.Gamma(2., 2.))
    leuko_dist = dist.Normal(leuko_mus, leuko_sig)
    leuko_na_prob = sample('Leukocytes_na_prob', dist.Beta(1., 1.))
    leuko_na_dist = NAModel(leuko_dist, leuko_na_prob)

    rhino_test_logit = sample('Rhinovirus/Enterovirus_logit', dist.Normal(0., 1.))
    rhino_test_dist = dist.Bernoulli(logits=rhino_test_logit)
    rhino_test_na_prob = sample('Rhinovirus/Enterovirus_na_prob', dist.Beta(1., 1.))
    rhino_test_na_dist = NAModel(rhino_test_dist, rhino_test_na_prob)

    with plate("batch", num_obs_total, N):
        x_leuko = get_feature(x, 0)
        x_rhino = get_feature(x, 1)
        y_leuko = sample('Leukocytes', leuko_na_dist, obs=x_leuko)
        y_rhino = sample('Rhinovirus/Enterovirus', rhino_test_na_dist, obs=x_rhino)
        y = sample_combined(y_leuko, y_rhino)
    # Hand the combined sample back to the caller instead of computing it
    # and dropping it on the floor.
    return y
def model(obs=None, num_obs_total=None, d=None):
    """Defines the generative probabilistic model: p(x|z)p(z).

    A latent mean ``mu`` of length ``d`` gets a ``Normal(0, 1)`` prior per
    coordinate; site ``'obs'`` is ``Normal(mu, x_var)`` with event dim 1.

    :param obs: optional 2-D observations of shape ``(batch_size, d)``;
        when given, ``batch_size`` and ``d`` are read from its shape.
    :param num_obs_total: total number of observations (plate size).
        Required when ``obs`` is None, in which case it is also the number
        of instances generated.
    :param d: dimensionality; required, and only used, when ``obs`` is None.
    :returns: the value of the ``'obs'`` sample site.
    """
    if obs is not None:
        assert jnp.ndim(obs) == 2
        batch_size, d = jnp.shape(obs)
    else:
        assert num_obs_total is not None
        batch_size = num_obs_total
        # `is not None` is the reliable None check; `!= None` dispatches to
        # __ne__, which array-like values may override.
        assert d is not None
    num_obs_total = batch_size if num_obs_total is None else num_obs_total

    z_mu = sample('mu', dist.Normal(jnp.zeros((d, )), 1.))
    # NOTE(review): x_var is passed as Normal's second argument, which is
    # the scale (std. dev.) in numpyro; if a variance of .1 was intended the
    # scale should be sqrt(.1) — confirm.
    x_var = .1
    with plate('batch', num_obs_total, batch_size):
        x = sample('obs', dist.Normal(z_mu, x_var).to_event(1), obs=obs,
                   sample_shape=(batch_size, ))
    return x
def model(x=None, num_obs_total=None):
    """Mixture model over the module-level ``features`` with ``k`` components.

    For every feature, the priors returned by
    ``create_feature_prior_dists(feature, k)`` are sampled as sites named
    ``"<feature.name>_<param>"`` and passed to ``feature.instantiate``; the
    result is wrapped in ``TypedDistribution`` with the feature's
    ``distribution.support_dtype``. Features whose ``_missing_values`` is
    truthy also get a ``"<feature.name>_na_prob"`` site (``Beta(2, 2)`` per
    component) and are wrapped in ``NAModel``. The per-feature
    distributions and Dirichlet weights ``pis`` form a ``MixtureModel``.

    Relies on module-level ``features`` and ``k`` being defined.

    :param x: optional 2-D data array, one row per instance; when None a
        single instance is generated.
    :param num_obs_total: total number of instances in the full data set
        (plate size); defaults to the number of rows in ``x``.
    :returns: the value of the ``'x'`` sample site.
    """
    assert x is None or len(jnp.shape(x)) == 2
    N = 1 if x is None else jnp.shape(x)[0]
    if num_obs_total is None:
        num_obs_total = N
    assert isinstance(num_obs_total, int) and num_obs_total > 0
    assert N <= num_obs_total

    mixture_dists = []
    for feature in features:
        feature_prior_dists = create_feature_prior_dists(feature, k)
        prior_values = {}
        for feature_prior_param, feature_prior_dist in feature_prior_dists.items():
            prior_values[feature_prior_param] = sample(
                "{}_{}".format(feature.name, feature_prior_param),
                feature_prior_dist
            )

        # Only the current feature's dtype is ever needed, so keep it in a
        # local rather than accumulating a list of all of them.
        dtype = feature.distribution.support_dtype
        feature_dist = TypedDistribution(feature.instantiate(**prior_values), dtype)

        if feature._missing_values:
            feature_na_prob = sample(
                "{}_na_prob".format(feature.name),
                dists.Beta(2.*jnp.ones(k), 2.*jnp.ones(k))
            )
            feature_dist = NAModel(feature_dist, feature_na_prob)
        mixture_dists.append(feature_dist)

    pis = sample('pis', dists.Dirichlet(jnp.ones(k)))
    with plate('batch', num_obs_total, N):
        mixture_model_dist = MixtureModel(mixture_dists, pis)
        x = sample('x', mixture_model_dist, obs=x)
    return x