# Shared imports for the guide snippets below.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import param, sample


def guide(obs=None, num_obs_total=None, d=None):
    """Defines the probabilistic guide for z (the variational approximation
    to the posterior): q(z) ~ p(z|x).
    """
    # # very smart guide: starts from the analytical solution
    # assert obs is not None
    # mu_loc, mu_std = analytical_solution(obs)
    # mu_loc = param('mu_loc', mu_loc)
    # mu_std = jnp.exp(param('mu_std_log', jnp.log(mu_std)))

    # not so smart guide: starts from the prior for mu
    assert d is not None
    mu_loc = param('mu_loc', jnp.zeros(d))
    mu_std = jnp.exp(param('mu_std_log', jnp.zeros(d)))

    z_mu = sample('mu', dist.Normal(mu_loc, mu_std))
    return z_mu, mu_loc, mu_std
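# Usage sketch (an assumption, not taken from the original source): pairing
# this guide with a matching `model` via plain NumPyro SVI. The original may
# instead use a different inference loop (e.g. d3p's DPSVI); `run_inference`
# and its hyperparameters are hypothetical illustration only.
import jax.random as jrandom
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam


def run_inference(model, guide, obs, num_steps=1000):
    # Stochastic variational inference: maximizes the ELBO with respect to
    # the `param` sites declared in the guide.
    svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
    svi_result = svi.run(jrandom.PRNGKey(0), num_steps,
                         obs=obs, num_obs_total=obs.shape[0], d=obs.shape[1])
    return svi_result.params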
def guide(k, obs=None, num_obs_total=None, d=None):
    # guide for the latent parameters of a mixture-of-Gaussians model
    if obs is not None:
        assert jnp.ndim(obs) == 2
        _, d = jnp.shape(obs)
    else:
        assert num_obs_total is not None
        assert d is not None

    alpha_log = param('alpha_log', jnp.zeros(k))
    alpha = jnp.exp(alpha_log)
    pis = sample('pis', dist.Dirichlet(alpha))

    mus_loc = param('mus_loc', jnp.zeros((k, d)))
    mus = sample('mus', dist.Normal(mus_loc, 1.))
    # passing obs= clamps the site: sigs is fixed to ones, not learned here
    sigs = sample('sigs', dist.InverseGamma(1., 1.), obs=jnp.ones_like(mus))
    return pis, mus, sigs
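# Quick smoke-test sketch (assumes the imports above; k=3, d=2, and
# num_obs_total=100 are arbitrary values chosen for illustration): seed the
# guide and trace it to inspect the shapes of its sample and param sites.
from numpyro.handlers import seed, trace

tr = trace(seed(guide, jrandom.PRNGKey(0))).get_trace(k=3, num_obs_total=100, d=2)
print({name: jnp.shape(site['value']) for name, site in tr.items()})
# expected shapes: pis -> (3,), mus -> (3, 2), sigs -> (3, 2), plus param sites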
def guide(batch_X, batch_y=None, num_obs_total=None):
    """Defines the probabilistic guide for z (the variational approximation
    to the posterior): q(z) ~ p(z|x).
    """
    # we are interested in the posterior of w and the intercept;
    # since this is a fairly simple model, we just initialize them according
    # to our prior belief and let the optimization handle the rest
    assert jnp.ndim(batch_X) == 2
    d = jnp.shape(batch_X)[1]

    z_w_loc = param("w_loc", jnp.zeros((d,)))
    z_w_std = jnp.exp(param("w_std_log", jnp.zeros((d,))))
    z_w = sample('w', dist.Normal(z_w_loc, z_w_std))

    z_intercept_loc = param("intercept_loc", 0.)
    z_intercept_std = jnp.exp(param("intercept_std_log", 0.))
    z_intercept = sample('intercept', dist.Normal(z_intercept_loc, z_intercept_std))

    return z_w, z_intercept
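# Hypothetical matching model (a sketch; the original model is not shown in
# this snippet): Bayesian logistic regression with standard-normal priors on
# the weights and intercept. The names `model` and `obs` are assumptions.
def model(batch_X, batch_y=None, num_obs_total=None):
    d = jnp.shape(batch_X)[1]
    w = sample('w', dist.Normal(jnp.zeros((d,)), jnp.ones((d,))))
    intercept = sample('intercept', dist.Normal(0., 1.))
    logits = batch_X @ w + intercept
    # NOTE: a minibatch setting would rescale the likelihood by
    # num_obs_total / batch_size (e.g. via numpyro.plate with subsampling);
    # this sketch scores only the given batch.
    with numpyro.plate('batch', jnp.shape(batch_X)[0]):
        return sample('obs', dist.Bernoulli(logits=logits), obs=batch_y)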
def guide(z=None, num_obs_total=None, d=None) -> None:
    # NOTE: `d` was a free variable in the original snippet; it is taken as an
    # explicit argument here (and inferred from `z` when possible, assuming
    # `z` has shape (batch_size, d)) so the function is self-contained.
    batch_size = 1
    if z is not None:
        batch_size = z.shape[0]
        d = z.shape[1]
    if num_obs_total is None:
        num_obs_total = batch_size

    mu_param = param('mu_param', 0.)
    sample('mu', dist.Normal(mu_param, 1.).expand_by((d,)).to_event(1))
    sample('sigma', dist.InverseGamma(1.).expand_by((d,)).to_event(1))
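# Shape sanity-check sketch for the expand_by/to_event pattern above (d = 3
# is an arbitrary choice): expand_by adds a batch dimension, and to_event(1)
# reinterprets it as an event dimension.
base = dist.Normal(0., 1.).expand_by((3,)).to_event(1)
assert base.batch_shape == () and base.event_shape == (3,)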
def guide(d):
    # NOTE: this snippet comes from a class context: `DistWithIntermediate` is
    # a custom distribution defined on the enclosing (test) class, so the
    # function is not self-contained as written here.
    mu_loc = param('mu_loc', jnp.zeros(1))  # registered but unused below
    mu = sample('mu', self.DistWithIntermediate(), sample_shape=(1, d))
def guide(d):
    mu_loc = param('mu_loc', jnp.zeros(d))
    mu = sample('mu', dist.Normal(mu_loc))