def guide(difficulty=0.0):
    previous_sample = None
    for k in reversed(range(1, N + 1)):
        loc_q = numpyro.param(
            f"loc_q_{k}",
            lambda key: target_mus[k]
            + difficulty * (0.1 * random.normal(key) - 0.53),
        )
        log_sig_q = numpyro.param(
            f"log_sig_q_{k}",
            lambda key: -0.5 * jnp.log(lambda_posts[k])
            + difficulty * (0.1 * random.normal(key) - 0.53),
        )
        sig_q = jnp.exp(log_sig_q)
        kappa_q = None
        if k != N:
            kappa_q = numpyro.param(
                f"kappa_q_{k}",
                lambda key: target_kappas[k]
                + difficulty * (0.1 * random.normal(key) - 0.53),
            )
        mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
        node_flagged = which_nodes_reparam[k - 1] == 1.0
        Normal = dist.Normal if node_flagged else FakeNormal
        loc_latent = numpyro.sample(f"loc_latent_{k}", Normal(mean_function, sig_q))
        previous_sample = loc_latent
    return previous_sample
def _get_posterior(self):
    loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
    scale_tril = numpyro.param(
        '{}_scale_tril'.format(self.prefix),
        jnp.identity(self.latent_dim) * self._init_scale,
        constraint=constraints.lower_cholesky,
    )
    return dist.MultivariateNormal(loc, scale_tril=scale_tril)
def flax_module(name, nn_module, *, input_shape=None):
    """
    Declare a :mod:`~flax` style neural network inside a model so that its
    parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param flax.nn.Module nn_module: a `flax` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the neural network.
    :return: a callable with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use flax to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `flax` to be able to use this feature. "
            "It can be installed with `pip install flax`."
        ) from e
    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        _, nn_params = nn_module.init(rng_key, jnp.ones(input_shape))
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.call, nn_params)
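# A minimal usage sketch (not from the source): `Net` is a hypothetical
# flax-style module exposing the .init/.apply interface expected by
# `flax_module` above; `data` and `labels` are placeholder arrays.
def flax_regression_model(data, labels):
    # registers the network parameters under the site "net$params"
    net = flax_module("net", Net(), input_shape=(1, data.shape[-1]))
    mean = net(data)
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=labels)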
def _get_transform(self):
    loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
    scale_tril = numpyro.param(
        '{}_scale_tril'.format(self.prefix),
        np.identity(self.latent_size) * self._init_scale,
        constraint=constraints.lower_cholesky,
    )
    return MultivariateAffineTransform(loc, scale_tril)
def haiku_module(name, nn, input_shape=None):
    """
    Declare a :mod:`~haiku` style neural network inside a model so that its
    parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param haiku.Module nn: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the neural network.
    :return: a callable with bound parameters that takes an array as an input
        and returns the neural network transformed output array.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to be able to use this feature. "
            "It can be installed with `pip install git+https://github.com/deepmind/dm-haiku`."
        ) from e
    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.sample(name + '$rng_key', PRNGIdentity())
        nn_params = nn.init(rng_key, jnp.ones(input_shape))
        numpyro.param(module_key, nn_params)
    return partial(nn.apply, nn_params, None)
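# A minimal usage sketch (not from the source): `transformed_nn` is a
# hypothetical transformed haiku module exposing the .init/.apply interface
# expected by `haiku_module` above; `data` and `labels` are placeholder arrays.
def haiku_regression_model(data, labels):
    net = haiku_module("net", transformed_nn, input_shape=(1, data.shape[-1]))
    mean = net(data)
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=labels)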
def guide(): loc = numpyro.param("loc", np.zeros(())) scale = numpyro.param("scale", np.ones(()), constraint=constraints.positive) x = numpyro.sample("x", dist.Normal(loc, scale)) with numpyro.plate("plate", len(data)): with handlers.mask(mask=np.invert(mask)): numpyro.sample("y_unobserved", dist.Normal(x, 1.0))
def guide(): loc = numpyro.param("loc", np.zeros(3)) cov = numpyro.param("cov", np.eye(3), constraint=constraints.positive_definite) x = numpyro.sample("x", dist.MultivariateNormal(loc, cov)) with numpyro.plate("plate", len(data)): with handlers.mask(mask=np.invert(mask)): numpyro.sample("y_unobserved", dist.MultivariateNormal(x, np.eye(3)))
def guide(): alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17) beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143) alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log) numpyro.sample("lambda_latent", FakeGamma(alpha_q, beta_q)) with numpyro.plate("data", len(data)): pass
def __call__(self, *args, **kwargs): """ An automatic guide with the same ``*args, **kwargs`` as the base ``model``. :return: A dict mapping sample site name to sampled value. :rtype: dict """ if self.prototype_trace is None: # run model to inspect the model structure self._setup_prototype(*args, **kwargs) plates = self._create_plates(*args, **kwargs) result = {} for name, site in self.prototype_trace.items(): if site["type"] != "sample" or isinstance( site["fn"], dist.PRNGIdentity) or site["is_observed"]: continue event_dim = self._event_dims[name] init_loc = self._init_locs[name] with ExitStack() as stack: for frame in site["cond_indep_stack"]: stack.enter_context(plates[frame.name]) site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix), init_loc, event_dim=event_dim) site_scale = numpyro.param("{}_{}_scale".format( name, self.prefix), jnp.full(jnp.shape(init_loc), self._init_scale), constraint=constraints.positive, event_dim=event_dim) site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim) if site["fn"].support in [ constraints.real, constraints.real_vector ]: result[name] = numpyro.sample(name, site_fn) else: unconstrained_value = numpyro.sample( "{}_unconstrained".format(name), site_fn, infer={"is_auxiliary": True}) transform = biject_to(site['fn'].support) value = transform(unconstrained_value) log_density = -transform.log_abs_det_jacobian( unconstrained_value, value) log_density = sum_rightmost( log_density, jnp.ndim(log_density) - jnp.ndim(value) + site["fn"].event_dim) delta_dist = dist.Delta(value, log_density=log_density, event_dim=site["fn"].event_dim) result[name] = numpyro.sample(name, delta_dist) return result
def guide(data): alpha_q = numpyro.param("alpha_q", lambda key: random.normal(key), constraint=constraints.positive) beta_q = numpyro.param("beta_q", lambda key: random.exponential(key), constraint=constraints.positive) numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
def _get_posterior(self): loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) scale = numpyro.param( "{}_scale".format(self.prefix), jnp.full(self.latent_dim, self._init_scale), constraint=self.scale_constraint, ) return dist.Normal(loc, scale)
def _get_posterior(self): loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) scale_tril = numpyro.param( "{}_scale_tril".format(self.prefix), jnp.identity(self.latent_dim) * self._init_scale, constraint=self.scale_tril_constraint, ) return dist.MultivariateNormal(loc, scale_tril=scale_tril)
def guide(): alpha_q_log = numpyro.param("alpha_q_log", log_alpha_n + 0.17) beta_q_log = numpyro.param("beta_q_log", log_beta_n - 0.143) alpha_q, beta_q = jnp.exp(alpha_q_log), jnp.exp(beta_q_log) p_latent = numpyro.sample("p_latent", FakeBeta(alpha_q, beta_q)) with numpyro.plate("data", len(data)): pass return p_latent
def guide(): loc_q = numpyro.param("loc_q", analytic_loc_n + jnp.array([0.334, 0.334])) log_sig_q = numpyro.param( "log_sig_q", analytic_log_sig_n + jnp.array([-0.29, -0.29])) sig_q = jnp.exp(log_sig_q) with numpyro.plate("plate", 2): loc_latent = numpyro.sample("loc_latent", FakeNormal(loc_q, sig_q)) return loc_latent
def model(): loc1 = numpyro.param("loc1", 0.) scale1 = numpyro.param("scale1", 1., constraint=constraints.positive) numpyro.sample("latent1", dist.Normal(loc1, scale1)) loc2 = numpyro.param("loc2", 1.) scale2 = numpyro.param("scale2", 2., constraint=constraints.positive) latent2 = numpyro.sample("latent2", dist.Normal(loc2, scale2)) return latent2
def test_subsample_param():
    data = jnp.arange(100.0)
    subsample_size = 7
    with handlers.seed(rng_seed=0):
        with numpyro.plate("a", len(data), subsample_size=subsample_size):
            p0 = numpyro.param("p0", 0.0, event_dim=0)
            assert jnp.shape(p0) == ()
            p = numpyro.param("p", 0.5 * jnp.ones(len(data)), event_dim=0)
            assert len(p) == subsample_size
def model(z1=None, z2=None): p = numpyro.param("p", np.array([0.25, 0.75])) loc = numpyro.param("loc", jnp.array([-1.0, 1.0])) z1 = numpyro.sample("z1", dist.Categorical(p), obs=z1) with numpyro.plate("data[0]", 3): numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0]) with numpyro.plate("data[1]", 2): z2 = numpyro.sample("z2", dist.Categorical(p), obs=z2) numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
def gmm_guide(data, num_components=3):
    mus_val = numpyro.param(
        'mus_val',
        jnp.array(stats.norm.rvs(size=num_components) * 1000),
        constraint=dist.constraints.real,
    )
    sigmas_val = numpyro.param(
        'sigmas_val', jnp.ones(num_components), constraint=dist.constraints.positive
    )
    mus = numpyro.sample('mus', dist.Delta(mus_val))
    sigmas = numpyro.sample('sigmas', dist.Delta(sigmas_val))

    mixture_probs_val = numpyro.param(
        'mixture_probs_val',
        jax.nn.softmax(stats.norm.rvs(size=num_components)),
        constraint=dist.constraints.simplex,
    )
    mixture_probs = numpyro.sample('mixture_probs', dist.Delta(mixture_probs_val))
def _get_posterior(self, *args, **kwargs):
    rank = int(round(self.latent_dim ** 0.5)) if self.rank is None else self.rank
    loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
    cov_factor = numpyro.param(
        '{}_cov_factor'.format(self.prefix), jnp.zeros((self.latent_dim, rank))
    )
    scale = numpyro.param(
        '{}_scale'.format(self.prefix),
        jnp.full(self.latent_dim, self._init_scale),
        constraint=constraints.positive,
    )
    cov_diag = scale * scale
    cov_factor = cov_factor * scale[..., None]
    return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
def _sample_latent(self, base_dist, *args, **kwargs):
    sample_shape = kwargs.pop('sample_shape', ())
    rank = int(round(self.latent_size ** 0.5)) if self.rank is None else self.rank
    loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
    cov_factor = numpyro.param(
        '{}_cov_factor'.format(self.prefix), np.zeros((self.latent_size, rank))
    )
    scale = numpyro.param('{}_scale'.format(self.prefix), np.ones(self.latent_size))
    cov_diag = scale * scale
    cov_factor = cov_factor * scale[..., None]
    posterior = dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
    return numpyro.sample(
        "_{}_latent".format(self.prefix), posterior, sample_shape=sample_shape
    )
def model(z1=None, z2=None): p = numpyro.param("p", jnp.array([[0.25, 0.75], [0.1, 0.9]])) loc = numpyro.param("loc", jnp.array([-1.0, 1.0])) z1 = numpyro.sample("z1", dist.Categorical(p[0]), obs=z1) z2 = numpyro.sample("z2", dist.Categorical(p[z1]), obs=z2) logger.info("z1.shape = {}".format(z1.shape)) logger.info("z2.shape = {}".format(z2.shape)) with numpyro.plate("data", 3): numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0]) numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
def model(): p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2))) q = numpyro.param("q", 0.25 * jnp.ones(2)) z = numpyro.sample("z", dist.Bernoulli(0.5)) x_prev = 0 x_curr = 0 for t in markov(range(T), history=history): probs = p[x_prev, x_curr, z] x_prev, x_curr = x_curr, numpyro.sample("x_{}".format(t), dist.Bernoulli(probs)) numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0) return x_prev, x_curr
def guide(X: DeviceArray):
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)

    numpyro.sample(
        Site.disp_param_mu,
        dist.Normal(
            loc=model_params[Param.loc_disp_param_mu],
            scale=model_params[Param.scale_disp_param_mu],
        ),
    )
    numpyro.sample(
        Site.disp_param_sigma,
        dist.TransformedDistribution(
            dist.Normal(
                loc=model_params[Param.loc_disp_param_logsigma],
                scale=model_params[Param.scale_disp_param_logsigma],
            ),
            transforms=dist.transforms.ExpTransform(),
        ),
    )

    with plate_stores:
        numpyro.sample(
            Site.disp_param_offsets,
            dist.Normal(
                loc=numpyro.param(
                    Param.loc_disp_param_offsets, jnp.zeros((n_stores, 1))
                ),
                scale=numpyro.param(
                    Param.scale_disp_param_offsets,
                    0.1 * jnp.ones((n_stores, 1)),
                    constraint=dist.constraints.positive,
                ),
            ),
        )

    with plate_features:
        numpyro.sample(
            Site.coef_mus,
            dist.Normal(
                loc=model_params[Param.loc_coef_mus],
                scale=model_params[Param.scale_coef_mus],
            ),
        )
        numpyro.sample(
            Site.coef_sigmas,
            dist.TransformedDistribution(
                dist.Normal(
                    loc=model_params[Param.loc_coef_logsigmas],
                    scale=model_params[Param.scale_coef_logsigmas],
                ),
                transforms=dist.transforms.ExpTransform(),
            ),
        )

        with plate_stores:
            numpyro.sample(
                Site.coef_offsets,
                dist.Normal(
                    loc=numpyro.param(
                        Param.loc_coef_offsets, jnp.zeros((n_stores, n_features))
                    ),
                    scale=numpyro.param(
                        Param.scale_coef_offsets,
                        0.5 * jnp.ones((n_stores, n_features)),
                        constraint=dist.constraints.positive,
                    ),
                ),
            )
def guide(): m1 = numpyro.param("m1", 2.0) s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive) m2 = numpyro.param("m2", 2.0) s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive) def true_fun(_): numpyro.sample("x", dist.Normal(m1, s1)) def false_fun(_): numpyro.sample("x", dist.Normal(m2, s2)) cluster = numpyro.sample("cluster", dist.Normal()) cond(cluster > 0, true_fun, false_fun, None)
def guide(
    x: Optional[jnp.ndarray] = None,
    seq_len: int = 0,
    batch: int = 0,
    x_dim: int = 1,
    future_steps: int = 0,
) -> None:
    if x is not None:
        *_, x_dim = x.shape
    phi = numpyro.param("phi", jnp.ones(x_dim))
    sigma = numpyro.param(
        "sigma", jnp.ones(x_dim) * 0.05, constraint=constraints.positive
    )
    numpyro.sample("z", dist.Normal(x * phi, sigma))
def guide(self):
    if self.fit_rho:
        rho_loc = npy.param(
            Sites.RHO + Sites.LOC,
            jnp.tile(self.rho_loc, (self.num_ltla, 1)),
        )
        rho_scale = npy.param(
            Sites.RHO + Sites.SCALE,
            jnp.tile(self.init_scale * self.rho_scale, (self.num_ltla, 1)),
            constraint=dist.constraints.positive,
        )
        npy.sample(Sites.RHO, dist.Normal(rho_loc, rho_scale))

    # mean / sd for parameter beta
    beta_loc = npy.param(
        Sites.BETA + Sites.LOC,
        jnp.tile(self.beta_loc, (self.num_ltla_lin, self.num_basis)),
    )
    beta_scale = npy.param(
        Sites.BETA + Sites.SCALE,
        self.init_scale
        * self.beta_scale
        * jnp.stack(self.num_ltla_lin * [jnp.eye(self.num_basis)]),
        constraint=dist.constraints.lower_cholesky,
    )
    npy.sample(Sites.BETA, dist.MultivariateNormal(beta_loc, scale_tril=beta_scale))

    b0_loc = npy.param(
        Sites.BC0 + Sites.LOC,
        jnp.concatenate([
            jnp.repeat(self.b0_loc, self.num_lin),
        ]),
    )
    b0_scale = npy.param(
        Sites.BC0 + Sites.SCALE,
        jnp.diag(
            jnp.concatenate([
                jnp.repeat(
                    self.init_scale * self.b0_scale * self.time_scale,
                    self.num_lin,
                ),
            ])
        ),
        constraint=dist.constraints.lower_cholesky,
    )
    npy.sample(Sites.B0, dist.MultivariateNormal(b0_loc, scale_tril=b0_scale))

    c_loc = npy.param(
        Sites.C + Sites.LOC,
        jnp.tile(self.c_loc, (self.num_ltla_lin, self.num_lin)),
    )
    c_scale = npy.param(
        Sites.C + Sites.SCALE,
        jnp.tile(self.init_scale * self.c_scale, (self.num_ltla_lin, self.num_lin)),
    )
    npy.sample(Sites.C, dist.Normal(c_loc, c_scale))
def __call__(self, *args, **kwargs):
    if self.prototype_trace is None:
        # run model to inspect the model structure
        self._setup_prototype(*args, **kwargs)

    plates = self._create_plates(*args, **kwargs)
    result = {}
    for name, site in self.prototype_trace.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue

        event_dim = self._event_dims[name]
        init_loc = self._init_locs[name]
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                stack.enter_context(plates[frame.name])

            site_loc = numpyro.param(
                "{}_{}_loc".format(name, self.prefix),
                init_loc,
                constraint=site["fn"].support,
                event_dim=event_dim,
            )
            site_fn = dist.Delta(site_loc).to_event(event_dim)
            result[name] = numpyro.sample(name, site_fn)

    return result
def __call__(self, name, fn, obs):
    assert obs is None, "LocScaleReparam does not support observe statements"
    centered = self.centered
    if is_identically_one(centered):
        return name, fn, obs
    event_shape = fn.event_shape
    fn, batch_shape, event_dim = self._unwrap(fn)

    # Apply a partial decentering transform.
    params = {key: getattr(fn, key) for key in self.shape_params}
    if self.centered is None:
        centered = numpyro.param(
            "{}_centered".format(name),
            jnp.full(event_shape, 0.5),
            constraint=constraints.unit_interval,
        )
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale ** centered
    decentered_fn = self._wrap(type(fn)(**params), batch_shape, event_dim)

    # Draw decentered noise.
    decentered_value = numpyro.sample("{}_decentered".format(name), decentered_fn)

    # Differentiably transform.
    delta = decentered_value - centered * fn.loc
    value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

    # Simulate a pyro.deterministic() site.
    return None, value
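# A minimal usage sketch (not from the source): applying a location-scale
# reparameterizer like the one above to a site named "x" via the `reparam`
# handler; `model` is a placeholder model with a location-scale site "x".
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

# centered=None learns a per-site "x_centered" parameter, as in __call__ above
reparam_model = reparam(model, config={"x": LocScaleReparam(centered=None)})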
def fun_model():
    p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
    q = numpyro.param("q", 0.25 * jnp.ones(2))
    z = numpyro.sample("z", dist.Bernoulli(0.5))

    def transition_fn(carry, y):
        x_prev, x_curr = carry
        probs = p[x_prev, x_curr, z]
        x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(probs))
        numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
        return (x_prev, x_curr), None

    (x_prev, x_curr), _ = scan(transition_fn, (0, 0), jnp.zeros(T), history=history)
    return x_prev, x_curr
def model(z=None): p = numpyro.param("p", np.array([0.75, 0.25])) iz = numpyro.sample("z", dist.Categorical(p), obs=z) z = jnp.array([0.0, 1.0])[iz] logger.info("z.shape = {}".format(z.shape)) with numpyro.plate("data", 3): numpyro.sample("x", dist.Normal(z, 1.0), obs=data)