Example #1
0
def guide(obs=None, num_obs_total=None, d=None):
    """Define the probabilistic guide for z (variational approximation to posterior): q(z) ~ p(z|x).

    The guide is an independent Normal per component of ``mu``. Both its
    location and its log-scale are learnable parameters, initialised to zeros
    (i.e. loc 0, scale exp(0) = 1). The comment below describes this as
    starting from the prior for ``mu``.

    Args:
        obs: Observed data. Not used by this guide; presumably accepted so
            the guide can be called with the model's arguments (confirm
            against the model's signature).
        num_obs_total: Not used; accepted for the same reason as ``obs``.
        d: Dimensionality of ``mu``. Required.

    Returns:
        Tuple ``(z_mu, mu_loc, mu_std)``: the sampled ``mu`` and the current
        location and scale of its variational distribution.

    Raises:
        AssertionError: If ``d`` is None.
    """
    # Alternative (disabled) "very smart" guide that starts from the
    # analytical solution computed from the observations:
    #     assert obs is not None
    #     mu_loc, mu_std = analytical_solution(obs)
    #     mu_loc = param('mu_loc', mu_loc)
    #     mu_std = jnp.exp(param('mu_std_log', jnp.log(mu_std)))

    # Not-so-smart guide: starts from the prior for mu.
    # Identity check: `!=` can be overloaded by array-like arguments.
    assert d is not None, "guide requires the dimensionality `d`"
    mu_loc = param('mu_loc', jnp.zeros(d))
    # The scale is optimised in log space so that it stays positive.
    mu_std = jnp.exp(param('mu_std_log', jnp.zeros(d)))

    z_mu = sample('mu', dist.Normal(mu_loc, mu_std))
    return z_mu, mu_loc, mu_std
Example #2
0
def guide(k, obs=None, num_obs_total=None, d=None):
    """Variational guide for a ``k``-component mixture (the latent MixGaus distribution).

    Args:
        k: Number of mixture components.
        obs: Optional observed data; must be 2-D. When given, ``d`` is taken
            from its second axis and the ``d`` argument is overwritten.
        num_obs_total: Must be provided when ``obs`` is None; otherwise unused.
        d: Dimensionality of each component mean; must be provided when
            ``obs`` is None.

    Returns:
        Tuple ``(pis, mus, sigs)`` from the 'pis', 'mus' and 'sigs' sample
        sites.
    """
    if obs is None:
        # Without data, the caller must supply the shape information.
        assert num_obs_total is not None
        assert d is not None
    else:
        assert jnp.ndim(obs) == 2
        n_rows, d = jnp.shape(obs)

    # Mixture weights: Dirichlet whose concentration is learned in log space.
    concentration = jnp.exp(param('alpha_log', jnp.zeros(k)))
    pis = sample('pis', dist.Dirichlet(concentration))

    # Component means: one learnable location per (component, dimension).
    means_loc = param('mus_loc', jnp.zeros((k, d)))
    mus = sample('mus', dist.Normal(means_loc, 1.))

    # NOTE(review): 'sigs' is passed obs=ones, so this guide does not learn
    # it; confirm that pinning it is intended.
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs
Example #3
0
def guide(batch_X, batch_y=None, num_obs_total=None):
    """Define the probabilistic guide for z (variational approximation to posterior): q(z) ~ p(z|x).

    Learns independent Normal approximations for the weight vector ``w`` and
    the scalar ``intercept``. Every location and log-scale starts at zero
    (loc 0, scale exp(0) = 1). Since the model is fairly simple, they are
    initialised according to our prior belief and optimisation handles the
    rest.

    Args:
        batch_X: 2-D batch of inputs; only its second dimension (the number
            of features ``d``) is used.
        batch_y: Not used; presumably accepted so the guide shares the
            model's signature (confirm).
        num_obs_total: Not used; same reason as ``batch_y``.

    Returns:
        Tuple ``(z_w, z_intercept)`` of the sampled weights and intercept.

    Raises:
        AssertionError: If ``batch_X`` is not 2-D.
    """
    assert jnp.ndim(batch_X) == 2, "batch_X must be a 2-D array"
    d = jnp.shape(batch_X)[1]

    # Scales are optimised in log space so that they stay positive.
    z_w_loc = param("w_loc", jnp.zeros((d,)))
    z_w_std = jnp.exp(param("w_std_log", jnp.zeros((d,))))
    z_w = sample('w', dist.Normal(z_w_loc, z_w_std))

    z_intercept_loc = param("intercept_loc", 0.)
    z_intercept_std = jnp.exp(param("intercept_std_log", 0.))
    z_intercept = sample('intercept', dist.Normal(z_intercept_loc, z_intercept_std))

    return (z_w, z_intercept)
    # NOTE(review): as placed, this definition follows a `return` and is never
    # executed. It reads `d` from an enclosing scope that is not visible here;
    # it is presumably a separate example whose header was lost.
    def guide(z=None, num_obs_total=None) -> None:
        """Guide with a learnable location for ``mu`` and a parameter-free ``sigma``.

        Samples ``mu`` from ``Normal(mu_param, 1.)`` and ``sigma`` from
        ``InverseGamma(1.)``. Each is expanded by ``(d,)`` and reinterpreted
        with ``to_event(1)``. ``d`` is a free variable from the enclosing scope.

        Args:
            z: Not used; presumably accepted so the guide matches the
                model's signature (confirm).
            num_obs_total: Not used; same reason as ``z``.
        """
        mu_param = param('mu_param', 0.)
        sample('mu', dists.Normal(mu_param, 1.).expand_by((d, )).to_event(1))
        sample('sigma', dists.InverseGamma(1.).expand_by((d, )).to_event(1))
Example #5
0
 def guide(d):
     """Register a ``mu_loc`` parameter and sample ``mu`` with sample shape ``(1, d)``.

     ``self.DistWithIntermediate`` is a free reference to an enclosing
     instance that is not visible here. Nothing is returned.
     """
     # The parameter's value is not used by the sample below; only the call
     # itself is kept.
     param('mu_loc', jnp.zeros(1))
     sample('mu', self.DistWithIntermediate(), sample_shape=(1, d))
Example #6
0
 def guide(d):
     """Sample ``mu`` from a Normal whose location is a learnable parameter.

     The location ``mu_loc`` has length ``d`` and is initialised to zeros.
     The Normal is constructed with only its location argument. The sampled
     value is not returned.
     """
     sample('mu', dist.Normal(param('mu_loc', jnp.zeros(d))))