def model(z=None) -> None:
    """Two-dimensional Gaussian model with latent location and scale.

    Sites ``mu`` and ``sigma`` are each drawn as a single 2-element event;
    site ``x`` is drawn (or observed) inside the ``batch`` plate.

    Args:
        z: Optional observations for site ``x``. When given, the length of
            its leading axis is used as the plate size; otherwise the plate
            size is 1.
    """
    n_rows = 1 if z is None else z.shape[0]

    # NOTE: keep the order of the `sample` calls; reordering sites may change
    # which random stream each one receives under a seeding handler.
    loc_prior = dists.Normal().expand_by((2, )).to_event(1)
    mu = sample('mu', loc_prior)

    scale_prior = dists.InverseGamma(1.).expand_by((2, )).to_event(1)
    sigma = sample('sigma', scale_prior)

    # No separate data set size exists here: the same value is passed for
    # both plate arguments.
    with plate('batch', n_rows, n_rows):
        likelihood = dists.Normal(mu, sigma).to_event(1)
        sample('x', likelihood, obs=z)
示例#2
0
    def model(z=None, num_obs_total=None) -> None:
        """Gaussian model over ``d`` dimensions with latent location and scale.

        ``d`` and ``args.prior_mu`` are read from the enclosing scope.

        Args:
            z: Optional observations for site ``x``. When given, the length
                of its leading axis is the batch size; otherwise it is 1.
            num_obs_total: Optional first size argument of the ``batch``
                plate (presumably the full data set size - confirm against
                the caller). Defaults to the batch size.
        """
        batch_size = 1 if z is None else z.shape[0]
        total = batch_size if num_obs_total is None else num_obs_total

        # NOTE: keep the order of the `sample` calls unchanged.
        loc_prior = dists.Normal(args.prior_mu).expand_by((d,)).to_event(1)
        mu = sample('mu', loc_prior)

        scale_prior = dists.InverseGamma(1.).expand_by((d,)).to_event(1)
        sigma = sample('sigma', scale_prior)

        with plate('batch', total, batch_size):
            likelihood = dists.Normal(mu, sigma).to_event(1)
            sample('x', likelihood, obs=z)
def model(x_first=None, x_second=None, num_obs_total=None) -> None:
    """Scalar Gaussian model with two observed sites sharing ``mu``/``sigma``.

    Args:
        x_first: Optional observations for site ``x_first``. Its leading
            axis length, when given, determines the batch size (1 otherwise).
        x_second: Optional observations for site ``x_second``. It does not
            influence the batch size.
        num_obs_total: Optional first size argument of the ``batch`` plate;
            defaults to the batch size.
    """
    batch_size = x_first.shape[0] if x_first is not None else 1
    total = num_obs_total if num_obs_total is not None else batch_size

    mu = sample('mu', dists.Normal())
    sigma = sample('sigma', dists.InverseGamma(1.))

    # Both sites use an identically parameterised Normal; the tuple order
    # keeps 'x_first' sampled before 'x_second'.
    observed_sites = (('x_first', x_first), ('x_second', x_second))
    with plate('batch', total, batch_size):
        for site_name, site_obs in observed_sites:
            sample(site_name, dists.Normal(mu, sigma), obs=site_obs)
示例#4
0
def model(z=None, z2=None, num_obs_total=None) -> None:
    """Two-dimensional Gaussian model with latent location and scale.

    Args:
        z: Optional observations for site ``x``. When given, the length of
            its leading axis is the batch size; otherwise it is 1.
        z2: Optional second array. It is not attached to any sample site in
            this function; when both ``z`` and ``z2`` are given, only their
            leading axis lengths are checked for equality.
        num_obs_total: Optional first size argument of the ``batch`` plate;
            defaults to the batch size.

    Raises:
        AssertionError: If ``z`` and ``z2`` are both given and their leading
            axis lengths differ.
    """
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
        # z2 defaults to None, so the consistency check only runs when it was
        # actually supplied (previously model(z) crashed on `None.shape`).
        if z2 is not None:
            assert z.shape[0] == z2.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size

    mu = sample('mu', dists.Normal().expand_by((2, )).to_event(1))
    sigma = sample('sigma',
                   dists.InverseGamma(1.).expand_by((2, )).to_event(1))
    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
示例#5
0
def model(k, obs=None, num_obs_total=None, d=None):
    """Gaussian mixture model with ``k`` components and priors on all parameters.

    Args:
        k: Number of mixture components.
        obs: Optional 2-D array of observations; its shape provides the batch
            size and the dimensionality ``d`` (overriding the ``d`` argument).
        num_obs_total: First size argument of the ``batch`` plate. Required
            when ``obs`` is None (it then also serves as the batch size);
            otherwise defaults to the batch size.
        d: Data dimensionality; required when ``obs`` is None.

    Returns:
        The value of the ``obs`` sample site.
    """
    if obs is None:
        # Without data, both sizes must be supplied by the caller.
        assert(num_obs_total is not None)
        batch_size = num_obs_total
        assert(d is not None)
    else:
        assert(jnp.ndim(obs) == 2)
        batch_size, d = jnp.shape(obs)

    if num_obs_total is None:
        num_obs_total = batch_size

    # NOTE: keep the order of the `sample` calls unchanged.
    weights_prior = dist.Dirichlet(jnp.ones(k))
    pis = sample('pis', weights_prior)

    locs_prior = dist.Normal(jnp.zeros((k, d)), 10.)
    mus = sample('mus', locs_prior)

    # One scale per entry of `mus`.
    sigs = sample('sigs', dist.InverseGamma(1., 1.),
                  sample_shape=jnp.shape(mus))

    with plate('batch', num_obs_total, batch_size):
        drawn = sample('obs', GaussianMixture(mus, sigs, pis),
                       obs=obs, sample_shape=(batch_size,))
    return drawn
示例#6
0
def model(batch_X, batch_y=None, num_obs_total=None):
    """Logistic regression model: p(y|z,X)p(z).

    The weights ``w`` get an N(0, I) prior and the ``intercept`` an N(0, 1)
    prior; the ``obs`` site is Bernoulli with logits ``X.w + intercept``.

    :param batch_X: a 2-D batch of predictors (rows are examples)
    :param batch_y: an optional batch of observations; when given,
        ``example_count(batch_y)`` must equal the number of rows in batch_X
    :param num_obs_total: optional first size argument of the ``batch``
        plate; defaults to the number of rows in batch_X
    :return: the value of the ``obs`` sample site
    """
    assert (jnp.ndim(batch_X) == 2)
    batch_size, d = jnp.shape(batch_X)
    assert (batch_y is None or example_count(batch_y) == batch_size)

    if num_obs_total is None:
        num_obs_total = batch_size

    # prior on the weights is N(0, I)
    weight_prior = dist.Normal(jnp.zeros((d, )), jnp.ones((d, )))
    z_w = sample('w', weight_prior)

    # prior on the intercept is N(0, 1)
    z_intercept = sample('intercept', dist.Normal(0, 1))

    logits = batch_X.dot(z_w) + z_intercept

    with plate("batch", num_obs_total, batch_size):
        outcome = sample('obs', dist.Bernoulli(logits=logits), obs=batch_y)
    return outcome
示例#7
0
def model(x=None, num_obs_total=None):
    """Model for two features, each wrapped in an ``NAModel``.

    Args:
        x: Optional 2-D array of observations; the leading axis length is
            the batch size N (N is 1 when ``x`` is None). Feature indices 0
            and 1 are extracted with ``get_feature`` (presumably columns -
            confirm against ``get_feature``).
        num_obs_total (int): Optional first size argument of the ``batch``
            plate; defaults to N. Must be a positive int not smaller than N.

    Samples:
        Prior sites for the Leukocytes and Rhinovirus/Enterovirus features
        and their ``*_na_prob`` parameters, then the sites ``Leukocytes`` and
        ``Rhinovirus/Enterovirus`` inside the plate; the two results are
        handed to ``sample_combined``.
    """
    if x is None:
        N = 1
    else:
        x_shape = np.shape(x)
        assert len(x_shape) == 2
        N = x_shape[0]

    total = N if num_obs_total is None else num_obs_total

    assert isinstance(total, int) and total > 0
    assert N <= total

    # NOTE: keep the order of the `sample` calls unchanged.
    # --- Leukocytes: Normal value distribution ---
    leuko_loc = sample('Leukocytes_mus', dist.Normal(0., 1.))
    leuko_scale = sample('Leukocytes_sig', dist.Gamma(2., 2.))
    leuko_na_prob = sample('Leukocytes_na_prob', dist.Beta(1., 1.))
    leuko_na_dist = NAModel(dist.Normal(leuko_loc, leuko_scale), leuko_na_prob)

    # --- Rhinovirus/Enterovirus: Bernoulli value distribution ---
    rhino_logit = sample('Rhinovirus/Enterovirus_logit', dist.Normal(0., 1.))
    rhino_na_prob = sample('Rhinovirus/Enterovirus_na_prob',
                           dist.Beta(1., 1.))
    rhino_na_dist = NAModel(dist.Bernoulli(logits=rhino_logit), rhino_na_prob)

    with plate("batch", total, N):
        leuko_obs = get_feature(x, 0)
        rhino_obs = get_feature(x, 1)

        y_leuko = sample('Leukocytes', leuko_na_dist, obs=leuko_obs)
        y_rhino = sample('Rhinovirus/Enterovirus', rhino_na_dist,
                         obs=rhino_obs)
        # The combined result is not returned; the call is kept for whatever
        # effect `sample_combined` has.
        sample_combined(y_leuko, y_rhino)
示例#8
0
def model(obs=None, num_obs_total=None, d=None):
    """Defines the generative probabilistic model: p(x|z)p(z)

    The latent ``mu`` has a Normal prior with zero location and scale 1 in
    each of the ``d`` dimensions; the ``obs`` site is Normal around ``mu``.

    :param obs: optional 2-D array of observations; its shape provides the
        batch size and ``d`` (overriding the ``d`` argument)
    :param num_obs_total: first size argument of the ``batch`` plate;
        required when ``obs`` is None (then also used as the batch size),
        otherwise defaults to the batch size
    :param d: data dimensionality; required when ``obs`` is None
    :return: the value of the ``obs`` sample site
    """
    if obs is not None:
        assert (jnp.ndim(obs) == 2)
        batch_size, d = jnp.shape(obs)
    else:
        assert (num_obs_total is not None)
        batch_size = num_obs_total
        # Identity comparison is the correct None check.
        assert (d is not None)
    num_obs_total = batch_size if num_obs_total is None else num_obs_total

    z_mu = sample('mu', dist.Normal(jnp.zeros((d, )), 1.))
    # NOTE(review): despite its name, this value is passed as the second
    # positional argument of dist.Normal (presumably the scale, not the
    # variance) - confirm the intent.
    x_var = .1
    with plate('batch', num_obs_total, batch_size):
        x = sample('obs',
                   dist.Normal(z_mu, x_var).to_event(1),
                   obs=obs,
                   sample_shape=(batch_size, ))
    return x
示例#9
0
    def model(x=None, num_obs_total=None):
        """Mixture model assembled from the enclosing scope's ``features``.

        ``features`` and ``k`` are read from the enclosing scope. For every
        feature, its prior parameters are sampled, the feature distribution
        is instantiated and wrapped in a ``TypedDistribution``, and - when
        the feature is flagged as having missing values - additionally in an
        ``NAModel``. All feature distributions are then combined in a
        ``MixtureModel`` with Dirichlet-distributed weights ``pis``.

        Args:
            x: Optional 2-D array of observations; the leading axis length
                is the batch size (1 when ``x`` is None).
            num_obs_total: Optional first size argument of the ``batch``
                plate; defaults to the batch size. Must be a positive int
                not smaller than the batch size.

        Returns:
            The value of the ``x`` sample site.
        """
        if x is None:
            batch_size = 1
        else:
            assert len(jnp.shape(x)) == 2
            batch_size = jnp.shape(x)[0]

        if num_obs_total is None:
            num_obs_total = batch_size

        assert isinstance(num_obs_total, int) and num_obs_total > 0
        assert batch_size <= num_obs_total

        def _build_feature_dist(feature):
            # Sample every prior parameter of this feature, in the order
            # provided by create_feature_prior_dists.
            prior_values = {
                param_name: sample(
                    "{}_{}".format(feature.name, param_name),
                    param_prior
                )
                for param_name, param_prior
                in create_feature_prior_dists(feature, k).items()
            }

            support_dtype = feature.distribution.support_dtype
            feature_dist = TypedDistribution(
                feature.instantiate(**prior_values), support_dtype
            )

            if not feature._missing_values:
                return feature_dist

            # One missing-value probability per mixture component.
            na_concentration = 2. * jnp.ones(k)
            feature_na_prob = sample(
                "{}_na_prob".format(feature.name),
                dists.Beta(na_concentration, na_concentration)
            )
            return NAModel(feature_dist, feature_na_prob)

        # Features are processed strictly in order so the sample sites are
        # visited in the same sequence as the features are listed.
        mixture_dists = [_build_feature_dist(feature) for feature in features]

        pis = sample('pis', dists.Dirichlet(jnp.ones(k)))
        with plate('batch', num_obs_total, batch_size):
            mixture_model_dist = MixtureModel(mixture_dists, pis)
            drawn = sample('x', mixture_model_dist, obs=x)
        return drawn