Exemplo n.º 1
0
        def model_fn(X, N=None, num_obs_total=None):
            """Toy model: ``Normal(1.)`` prior on ``theta`` and ``N`` Normal draws ``X``.

            :param X: observed data, or ``None`` to sample ``N`` new points.
            :param N: number of data points in this (mini)batch; inferred from
                the first axis of ``X`` when not given.
            :param num_obs_total: total number of observations, passed on to
                ``minibatch``; defaults to ``N``.
            :return: tuple ``(X, mu)`` of the observed/sampled data and the
                sampled ``theta`` value.
            :raises ValueError: if neither ``X`` nor ``N`` is given.
            """
            if N is None:
                if X is None:
                    # jnp.shape(None) is (), so indexing it would only give an
                    # unhelpful IndexError; report the real problem instead.
                    raise ValueError("either X or N must be given")
                N = jnp.shape(X)[0]
            if num_obs_total is None:
                num_obs_total = N

            mu = sample("theta", dist.Normal(1.))
            with minibatch(N, num_obs_total=num_obs_total):
                X = sample("X", dist.Normal(mu), obs=X, sample_shape=(N, ))
            return X, mu
Exemplo n.º 2
0
    def guide(z=None, num_obs_total=None) -> None:
        """Variational guide for ``mu`` and ``sigma`` (``d`` taken from the enclosing scope).

        ``z`` and ``num_obs_total`` are accepted only so the guide shares the
        model's call signature; the guide itself does not depend on them.
        """
        mu_param = param('mu_param', 0.)
        sample('mu', dists.Normal(mu_param, 1.).expand_by((d, )).to_event(1))
        # sigma is drawn from a fixed InverseGamma(1.) with no learnable parameter.
        sample('sigma', dists.InverseGamma(1.).expand_by((d, )).to_event(1))
def model(z=None) -> None:
    """Model with 2-d ``mu``/``sigma`` latents and Normal observations at site ``x``.

    :param z: optional observed batch; its first axis is the batch size.
        When omitted, a batch of size 1 is used.
    """
    n = 1 if z is None else z.shape[0]

    mu = sample('mu', dists.Normal().expand_by((2, )).to_event(1))
    sigma_prior = dists.InverseGamma(1.).expand_by((2, )).to_event(1)
    sigma = sample('sigma', sigma_prior)
    with plate('batch', n, n):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
Exemplo n.º 4
0
    def model(z=None, num_obs_total=None) -> None:
        """Model with ``d``-dimensional latents ``mu``/``sigma``, observed at site ``x``.

        ``d`` and ``args.prior_mu`` are taken from the enclosing scope.

        :param z: optional observed batch; its first axis is the batch size
            (1 when omitted).
        :param num_obs_total: full data-set size used as the ``plate`` total
            size; defaults to the batch size.
        """
        batch_size = z.shape[0] if z is not None else 1
        total = batch_size if num_obs_total is None else num_obs_total

        mu_prior = dists.Normal(args.prior_mu).expand_by((d,)).to_event(1)
        mu = sample('mu', mu_prior)
        sigma_prior = dists.InverseGamma(1.).expand_by((d,)).to_event(1)
        sigma = sample('sigma', sigma_prior)
        with plate('batch', total, batch_size):
            sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
Exemplo n.º 5
0
def model(z=None, z2=None, num_obs_total=None) -> None:
    """Model with 2-d ``mu``/``sigma`` latents; ``z`` is observed at site ``x``.

    :param z: optional observed batch; its first axis is the batch size
        (a batch of size 1 is used when omitted).
    :param z2: optional second batch; only checked to have as many rows as
        ``z``. It is not observed by any sample site here.
    :param num_obs_total: full data-set size used as the ``plate`` total size;
        defaults to the batch size.
    """
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
        # z2 is optional, so only compare batch sizes when it was given
        # (``z2.shape`` on None would raise AttributeError).
        assert z2 is None or z.shape[0] == z2.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size

    mu = sample('mu', dists.Normal().expand_by((2, )).to_event(1))
    sigma = sample('sigma',
                   dists.InverseGamma(1.).expand_by((2, )).to_event(1))
    with plate('batch', num_obs_total, batch_size):
        sample('x', dists.Normal(mu, sigma).to_event(1), obs=z)
Exemplo n.º 6
0
def model(k, obs=None, num_obs_total=None, d=None):
    """Gaussian mixture model with ``k`` components and priors on its parameters.

    :param k: number of mixture components.
    :param obs: optional observed data, a 2-d array unpacked as (batch_size, d).
    :param num_obs_total: size of the full data set; required when ``obs`` is
        None, in which case it is also used as the number of samples drawn.
    :param d: data dimensionality; required when ``obs`` is None, otherwise
        taken from the shape of ``obs``.
    :return: the value of the ``obs`` sample site.
    """
    if obs is None:
        assert num_obs_total is not None
        batch_size = num_obs_total
        assert d is not None
    else:
        assert jnp.ndim(obs) == 2
        batch_size, d = jnp.shape(obs)
    if num_obs_total is None:
        num_obs_total = batch_size

    # Priors: Dirichlet mixture weights, Normal(0, 10) component means, and
    # InverseGamma(1, 1) ``sigs`` with the same shape as the means.
    pis = sample('pis', dist.Dirichlet(jnp.ones(k)))
    mus = sample('mus', dist.Normal(jnp.zeros((k, d)), 10.))
    sigs = sample('sigs', dist.InverseGamma(1., 1.), sample_shape=jnp.shape(mus))
    with plate('batch', num_obs_total, batch_size):
        likelihood = GaussianMixture(mus, sigs, pis)
        return sample('obs', likelihood, obs=obs, sample_shape=(batch_size,))
Exemplo n.º 7
0
def guide(k, obs=None, num_obs_total=None, d=None):
    """Variational guide for the ``k``-component Gaussian mixture model.

    Learns the Dirichlet concentration for ``pis`` (parameterised in log space)
    and the Normal locations for ``mus``.

    :param k: number of mixture components.
    :param obs: optional observed 2-d data; its second axis gives ``d``.
    :param num_obs_total: must be given when ``obs`` is None; otherwise unused.
    :param d: data dimensionality; required when ``obs`` is None.
    :return: tuple ``(pis, mus, sigs)`` of the guide's site values.
    """
    if obs is None:
        assert num_obs_total is not None
        assert d is not None
    else:
        assert jnp.ndim(obs) == 2
        _, d = jnp.shape(obs)

    concentration = jnp.exp(param('alpha_log', jnp.zeros(k)))
    pis = sample('pis', dist.Dirichlet(concentration))

    mus_loc = param('mus_loc', jnp.zeros((k, d)))
    mus = sample('mus', dist.Normal(mus_loc, 1.))
    # NOTE(review): ``sigs`` is passed ``obs=`` (fixed to ones) in the guide
    # rather than learned — confirm this is intended.
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs
Exemplo n.º 8
0
def model(batch_X, batch_y=None, num_obs_total=None):
    """Defines the generative probabilistic model: p(y|z,X)p(z)

    Logistic regression: weights ``w`` have an N(0, I) prior, ``intercept`` an
    N(0, 1) prior, and labels are Bernoulli with logits ``X w + intercept``.
    The model is conditioned on the observed data.

    :param batch_X: a batch of predictors, a 2-d array (batch_size, d)
    :param batch_y: a batch of observations (labels), or None to sample them
    :param num_obs_total: total number of observations in the full data set,
        passed on to ``minibatch``
    :return: the value of the ``obs`` sample site
    """
    assert(jnp.ndim(batch_X) == 2)
    batch_size, d = jnp.shape(batch_X)
    assert(batch_y is None or example_count(batch_y) == batch_size)

    z_w = sample('w', dist.Normal(jnp.zeros((d,)), jnp.ones((d,)))) # prior is N(0,I)
    z_intercept = sample('intercept', dist.Normal(0,1)) # prior is N(0,1)
    # Use jnp.dot rather than the ``batch_X.dot`` method: if batch_X is a plain
    # NumPy array, ndarray.dot would try to convert z_w to NumPy, which fails
    # when z_w is a JAX tracer (under jit/grad). jnp.dot accepts both.
    logits = jnp.dot(batch_X, z_w) + z_intercept

    with minibatch(batch_size, num_obs_total=num_obs_total):
        return sample('obs', dist.Bernoulli(logits=logits), obs=batch_y)
Exemplo n.º 9
0
def guide(batch_X, batch_y=None, num_obs_total=None):
    """Defines the probabilistic guide for z (variational approximation to posterior): q(z) ~ p(z|x)

    Mean-field Normal guide over the weights ``w`` and the ``intercept``, with
    scales learned in log space. Locations and log-scales start at zero (scale 1),
    i.e. the guide is initialised to a standard Normal and optimisation does
    the rest.

    :param batch_X: a batch of predictors, a 2-d array (batch_size, d)
    :param batch_y: unused; accepted to match the model's signature
    :param num_obs_total: unused; accepted to match the model's signature
    :return: tuple ``(w, intercept)`` of sampled values
    """
    assert jnp.ndim(batch_X) == 2
    num_features = jnp.shape(batch_X)[1]

    w_loc = param("w_loc", jnp.zeros((num_features,)))
    w_std = jnp.exp(param("w_std_log", jnp.zeros((num_features,))))
    w = sample('w', dist.Normal(w_loc, w_std))

    intercept_loc = param("intercept_loc", 0.)
    intercept_std = jnp.exp(param("intercept_std_log", 0.))
    intercept = sample('intercept', dist.Normal(intercept_loc, intercept_std))

    return (w, intercept)
Exemplo n.º 10
0
def model(obs=None, num_obs_total=None, d=None):
    """Defines the generative probabilistic model: p(x|z)p(z)

    ``mu`` has a ``Normal(0, 1)`` prior of dimension ``d``; observations are
    drawn from ``Normal(mu, x_var)`` with the fixed value ``x_var = .1``.

    :param obs: optional observed data, a 2-d array unpacked as (batch_size, d)
    :param num_obs_total: size of the full data set; required when ``obs`` is
        None, in which case it is also the number of samples drawn
    :param d: data dimensionality; required when ``obs`` is None, otherwise
        taken from the shape of ``obs``
    :return: the value of the ``obs`` sample site
    """
    if obs is not None:
        assert (jnp.ndim(obs) == 2)
        batch_size, d = jnp.shape(obs)
    else:
        assert (num_obs_total is not None)
        batch_size = num_obs_total
        # Identity check: ``!=`` dispatches to __ne__, which array-like types
        # may override; ``is not None`` is the reliable None test.
        assert (d is not None)

    z_mu = sample('mu', dist.Normal(jnp.zeros((d, )), 1.))
    x_var = .1
    with minibatch(batch_size, num_obs_total):
        x = sample('obs',
                   dist.Normal(z_mu, x_var),
                   obs=obs,
                   sample_shape=(batch_size, ))
    return x
Exemplo n.º 11
0
Arquivo: vae.py Projeto: byzhang/d3p
def model(batch_or_batchsize,
          z_dim,
          hidden_dim,
          out_dim=None,
          num_obs_total=None):
    """Defines the generative probabilistic model: p(x|z)p(z)

    The model is conditioned on the observed data

    :param batch_or_batchsize: either a batch of observations (a 3-d array whose
        first axis is the batch) or an integer number of samples to generate
    :param z_dim: dimensions of the latent variable / code
    :param hidden_dim: dimensions of the hidden layers in the VAE
    :param out_dim: number of dimensions in a single output sample (flattened);
        required when only a batch size is given, otherwise derived from the batch
    :param num_obs_total: total number of observations in the full data set,
        passed on to ``minibatch``

    :return: (named) sample x from the model observation distribution p(x|z)p(z)
    :raises ValueError: if only a batch size is given and ``out_dim`` is None
    """
    if is_int_scalar(batch_or_batchsize):
        batch = None
        batch_size = batch_or_batchsize
        if out_dim is None:
            raise ValueError("if no batch is provided, out_dim must be given")
    else:
        batch = batch_or_batchsize
        assert (jnp.ndim(batch) == 3)
        batch_size = jnp.shape(batch)[0]
        batch = jnp.reshape(
            batch, (batch_size, -1)
        )  # squash each data item into a one-dimensional array (preserving only the batch size on the first axis)
        out_dim = jnp.shape(batch)[1]

    decode = numpyro.module('decoder', decoder(hidden_dim, out_dim),
                            (batch_size, z_dim))
    with minibatch(batch_size, num_obs_total=num_obs_total):
        # Draw one latent code per data item. Without sample_shape the prior
        # yields a single z of shape (z_dim,) shared by the whole batch, which
        # disagrees with the (batch_size, z_dim) input shape the decoder was
        # initialised with above.
        z = sample('z',
                   dist.Normal(jnp.zeros((z_dim, )), jnp.ones((z_dim, ))),
                   sample_shape=(batch_size, ))  # prior on z is N(0,I)
        img_loc = decode(
            z
        )  # evaluate decoder (p(x|z)) on sampled z to get means for output bernoulli distribution
        x = sample(
            'obs', dist.Bernoulli(img_loc), obs=batch
        )  # outputs x are sampled from bernoulli distribution depending on z and conditioned on the observed data
        return x
Exemplo n.º 12
0
    def model(x=None, num_obs_total=None):
        """Mixture model over heterogeneous features from the enclosing ``features``.

        For every feature this samples its prior parameters (one site per
        parameter, named ``"<feature name>_<param>"``) and instantiates the
        feature distribution, typed with the feature's ``support_dtype``. If the
        feature has missing values, the distribution is wrapped in ``NAModel``
        together with a sampled ``Beta(2, 2)`` probability of shape ``(k,)``.
        The per-feature distributions are combined with Dirichlet weights
        ``pis`` into a ``MixtureModel``.

        :param x: optional observed data, a 2-d array with one row per instance.
        :param num_obs_total: positive int, size of the full data set; defaults
            to the number of rows in ``x`` (1 when ``x`` is None).
        :return: the value of the ``x`` sample site.
        """
        assert x is None or len(jnp.shape(x)) == 2
        N = 1 if x is None else jnp.shape(x)[0]
        if num_obs_total is None:
            num_obs_total = N

        assert isinstance(num_obs_total, int) and num_obs_total > 0
        assert N <= num_obs_total

        mixture_dists = []
        for feature in features:
            prior_values = {
                prior_param: sample(
                    "{}_{}".format(feature.name, prior_param), prior_dist
                )
                for prior_param, prior_dist
                in create_feature_prior_dists(feature, k).items()
            }

            dtype = feature.distribution.support_dtype
            feature_dist = TypedDistribution(
                feature.instantiate(**prior_values), dtype
            )
            if feature._missing_values:
                na_prob = sample(
                    "{}_na_prob".format(feature.name),
                    dists.Beta(2.*jnp.ones(k), 2.*jnp.ones(k))
                )
                feature_dist = NAModel(feature_dist, na_prob)

            mixture_dists.append(feature_dist)

        pis = sample('pis', dists.Dirichlet(jnp.ones(k)))
        with plate('batch', num_obs_total, N):
            return sample('x', MixtureModel(mixture_dists, pis), obs=x)
Exemplo n.º 13
0
def guide(obs=None, num_obs_total=None, d=None):
    """Defines the probabilistic guide for z (variational approximation to posterior): q(z) ~ p(z|x)

    Mean-field Normal guide for ``mu``; location and log-scale both start at
    zero, i.e. at a standard Normal. ``obs`` and ``num_obs_total`` are accepted
    to match the model's signature but are not used.

    :param d: dimensionality of ``mu``; required
    :return: tuple ``(z_mu, mu_loc, mu_std)`` of the sampled value and the
        current location/scale
    """
    # NOTE: an earlier variant initialised mu_loc/mu_std from
    # ``analytical_solution(obs)`` instead of starting from the prior.
    assert d is not None
    mu_loc = param('mu_loc', jnp.zeros(d))
    mu_std = jnp.exp(param('mu_std_log', jnp.zeros(d)))

    z_mu = sample('mu', dist.Normal(mu_loc, mu_std))
    return z_mu, mu_loc, mu_std
Exemplo n.º 14
0
def model(x_first=None, x_second=None, num_obs_total=None) -> None:
    """Scalar ``mu``/``sigma`` model with two Normal-observed sites per instance.

    :param x_first: optional observed values for site ``x_first``
    :param x_second: optional observed values for site ``x_second``
    :param num_obs_total: full data-set size used as the ``plate`` total size;
        defaults to the batch size
    """
    # Take the batch size from whichever observation is present, so passing
    # only x_second does not silently fall back to a plate of size 1.
    batch_size = 1
    if x_first is not None:
        batch_size = x_first.shape[0]
    elif x_second is not None:
        batch_size = x_second.shape[0]
    if num_obs_total is None:
        num_obs_total = batch_size

    mu = sample('mu', dists.Normal())
    sigma = sample('sigma', dists.InverseGamma(1.))
    with plate('batch', num_obs_total, batch_size):
        sample('x_first', dists.Normal(mu, sigma), obs=x_first)
        sample('x_second', dists.Normal(mu, sigma), obs=x_second)
Exemplo n.º 15
0
Arquivo: vae.py Projeto: byzhang/d3p
def guide(batch, z_dim, hidden_dim, out_dim=None, num_obs_total=None):
    """Defines the probabilistic guide for z (variational approximation to posterior): q(z) ~ p(z|q)

    Each item of ``batch`` is flattened and passed through the encoder, whose
    outputs are used as the location and scale of a Normal over ``z``.

    :param batch: a batch of observations, a 3-d array whose first axis is the batch
    :param z_dim: dimensions of the latent variable / code
    :param hidden_dim: dimensions of the encoder's hidden layers
    :param out_dim: ignored; the flattened item size is derived from ``batch``
    :param num_obs_total: total number of observations, passed on to ``minibatch``
    :return: (named) sampled z from the variational (guide) distribution q(z)
    """
    assert jnp.ndim(batch) == 3
    n_items = jnp.shape(batch)[0]
    # Keep the batch axis and collapse every remaining axis into one.
    flat_batch = jnp.reshape(batch, (n_items, -1))
    flat_dim = jnp.shape(flat_batch)[1]

    encode = numpyro.module('encoder', encoder(hidden_dim, z_dim),
                            (n_items, flat_dim))
    with minibatch(n_items, num_obs_total=num_obs_total):
        z_loc, z_std = encode(flat_batch)
        return sample('z', dist.Normal(z_loc, z_std))
Exemplo n.º 16
0
def model(x=None, num_obs_total=None):
    """Two-feature model (Leukocytes, Rhinovirus/Enterovirus) with missing-value support.

    Args:
        x (jax.numpy.array): 2-d array with one row per data instance; feature
            0 is Leukocytes, feature 1 is Rhinovirus/Enterovirus (as read via
            ``get_feature``). ``None`` means a single instance is sampled.
        num_obs_total (int): Number of total instances in the data set;
            defaults to the number of rows in ``x``.
    Samples:
        site `x` similar to input x; array holding all features of the data instances.
    Returns:
        The value produced by ``sample_combined`` from the ``Leukocytes`` and
        ``Rhinovirus/Enterovirus`` sites.
    """
    assert x is None or len(np.shape(x)) == 2
    if x is None:
        N = 1
    else:
        N = np.shape(x)[0]
    if num_obs_total is None:
        num_obs_total = N

    assert isinstance(num_obs_total, int) and num_obs_total > 0
    assert N <= num_obs_total

    # Leukocytes: Normal with sampled mean/scale, wrapped to allow missing values.
    leuko_mus = sample('Leukocytes_mus', dist.Normal(0., 1.))
    leuko_sig = sample('Leukocytes_sig', dist.Gamma(2., 2.))
    leuko_dist = dist.Normal(leuko_mus, leuko_sig)

    leuko_na_prob = sample('Leukocytes_na_prob', dist.Beta(1., 1.))
    leuko_na_dist = NAModel(leuko_dist, leuko_na_prob)

    # Rhinovirus/Enterovirus test: Bernoulli with sampled logit, also NA-wrapped.
    rhino_test_logit = sample('Rhinovirus/Enterovirus_logit',
                              dist.Normal(0., 1.))
    rhino_test_dist = dist.Bernoulli(logits=rhino_test_logit)

    rhino_test_na_prob = sample('Rhinovirus/Enterovirus_na_prob',
                                dist.Beta(1., 1.))
    rhino_test_na_dist = NAModel(rhino_test_dist, rhino_test_na_prob)

    with plate("batch", num_obs_total, N):
        x_leuko = get_feature(x, 0)
        x_rhino = get_feature(x, 1)

        y_leuko = sample('Leukocytes', leuko_na_dist, obs=x_leuko)
        y_rhino = sample('Rhinovirus/Enterovirus',
                         rhino_test_na_dist,
                         obs=x_rhino)
        y = sample_combined(y_leuko, y_rhino)
    # Previously computed but discarded; returned so callers can use it,
    # like the other model functions do with their sampled data.
    return y
Exemplo n.º 17
0
 def model(N, d):
     """Sample ``mu`` (Normal around zeros of length ``d``) and ``N`` draws ``x`` around it."""
     loc = sample("mu", dist.Normal(jnp.zeros(d)))
     sample("x", dist.Normal(loc), sample_shape=(N,))
Exemplo n.º 18
0
 def model(N, d):
     """Draw site ``x`` from ``self.DistWithIntermediate()`` with ``sample_shape=(N, d)``."""
     intermediate_dist = self.DistWithIntermediate()
     sample("x", intermediate_dist, sample_shape=(N, d))
Exemplo n.º 19
0
 def guide(d):
     """Register param ``mu_loc`` and draw site ``mu`` from ``self.DistWithIntermediate()``."""
     # NOTE(review): ``mu_loc`` is registered but not fed into the distribution.
     param('mu_loc', jnp.zeros(1))
     sample('mu', self.DistWithIntermediate(), sample_shape=(1, d))
Exemplo n.º 20
0
 def guide(d):
     """Normal guide for ``mu`` with a learnable location of length ``d``."""
     loc = param('mu_loc', jnp.zeros(d))
     sample('mu', dist.Normal(loc))
Exemplo n.º 21
0
 def test_model(batch_size, num_obs_total):
     """Draw ``batch_size`` ``DummyDist`` samples inside a ``minibatch`` context."""
     shape = (batch_size, )
     with minibatch(batch_size, num_obs_total):
         sample('test', MinibatchTests.DummyDist(), sample_shape=shape)
Exemplo n.º 22
0
 def test_model(X, num_obs_total):
     """Pass the batch ``X`` itself to ``minibatch``; samples take ``X.shape`` as sample shape."""
     with minibatch(X, num_obs_total):
         dummy = MinibatchTests.DummyDist()
         sample('test', dummy, sample_shape=X.shape)