def model(data):
    """Toy model whose ``loc`` site is an affine push-forward reparameterized
    with ``TransformReparam``.

    ``loc`` is ``alpha * u`` where ``u`` comes from a ``Uniform(0, 1)`` base
    wrapped in ``.mask(False)``, so the base site's log-density is masked out.
    ``data`` is observed under ``Normal(loc, 0.1)``.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    # Building the distribution is effect-free; only the sample statement
    # has to run under the reparam handler.
    scaled_uniform = dist.TransformedDistribution(
        dist.Uniform(0, 1).mask(False),
        AffineTransform(0, alpha),
    )
    with handlers.reparam(config={'loc': TransformReparam()}):
        loc = numpyro.sample('loc', scaled_uniform)
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
def test_model_with_transformed_distribution():
    """Reparameterizing a transformed-distribution site must not change the
    constrained samples or the potential energy.

    ``x`` (HalfNormal) is mapped from unconstrained space through its support
    bijection. ``y`` (LogNormal, a transformed distribution) is handled by
    ``TransformReparam``, whose unconstrained input is the ``y_base`` site.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    unconstrained = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    seeded_model = handlers.seed(model, random.PRNGKey(0))
    bijections = {
        name: biject_to(prior.support)
        for name, prior in (('x', x_prior), ('y', y_prior))
    }
    expected_samples = transform_fn(bijections, unconstrained)
    # Potential energy = -(log prior densities + log|det J| of the bijections).
    expected_potential_energy = -(
        x_prior.log_prob(expected_samples['x'])
        + y_prior.log_prob(expected_samples['y'])
        + bijections['x'].log_abs_det_jacobian(
            unconstrained['x'], expected_samples['x'])
        + bijections['y'].log_abs_det_jacobian(
            unconstrained['y'], expected_samples['y'])
    )

    reparam_model = handlers.reparam(seeded_model, config={'y': TransformReparam()})
    base_params = {'x': unconstrained['x'], 'y_base': unconstrained['y']}
    actual_samples = constrain_fn(
        handlers.seed(reparam_model, random.PRNGKey(0)),
        (),
        {},
        base_params,
        return_deterministic=True,
    )
    actual_potential_energy = potential_energy(reparam_model, (), {}, base_params)

    for site in ('x', 'y'):
        assert_allclose(expected_samples[site], actual_samples[site])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def test_log_normal(batch_shape, event_shape):
    """A ``TransformReparam``-ed log-normal site must reproduce the moments of
    the plain site.

    The site ``x`` is ``exp(loc + scale * z)`` with ``z`` standard normal,
    drawn for 100000 particles. Moments of ``log(x)`` from a plain run are
    compared against a run under ``TransformReparam``, where ``x`` is recorded
    as a deterministic site.
    """
    num_particles = 100000
    shape = batch_shape + event_shape
    loc = np.random.rand(*shape) * 2 - 1
    scale = np.random.rand(*shape) + 0.5

    def model():
        standard_normal = dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale))
        fn = dist.TransformedDistribution(
            standard_normal, [AffineTransform(loc, scale), ExpTransform()])
        if event_shape:
            # Reinterpret the trailing batch dims as event dims, then
            # pre-expand by the particle count.
            fn = fn.to_event(len(event_shape)).expand_by([num_particles])
        with numpyro.plate_stack("plates", batch_shape):
            with numpyro.plate("particles", num_particles):
                return numpyro.sample("x", fn)

    # A fresh seed handler per run: both runs start from the same seed 0.
    with handlers.trace() as plain_trace:
        plain_value = handlers.seed(model, 0)()
    expected_moments = get_moments(jnp.log(plain_value))

    with numpyro.handlers.reparam(config={"x": TransformReparam()}), \
            handlers.trace() as tr:
        reparam_value = handlers.seed(model, 0)()
    assert tr["x"]["type"] == "deterministic"
    actual_moments = get_moments(jnp.log(reparam_value))
    assert_allclose(actual_moments, expected_moments, atol=0.05, rtol=0.01)
def reparam_model(dim=10):
    """Funnel-shaped model with ``x`` sampled through ``TransformReparam``.

    ``y ~ Normal(0, 3)`` and ``x = exp(y / 2) * z`` with ``z ~ Normal(0, 1)``
    of length ``dim - 1``. Under the reparam handler the latent site is the
    standard-normal base and ``x`` is derived from it.
    """
    y = numpyro.sample('y', dist.Normal(0, 3))
    scale = jnp.exp(y / 2)
    standard_normal = dist.Normal(jnp.zeros(dim - 1), 1)
    scaled_normal = dist.TransformedDistribution(
        standard_normal, AffineTransform(0, scale))
    with numpyro.handlers.reparam(config={'x': TransformReparam()}):
        numpyro.sample('x', scaled_normal)
def model(data):
    """Observe ``data`` around a ``loc`` drawn as ``alpha * u``.

    ``u`` is a ``Uniform(0, 1)`` base masked with ``.mask(False)``, so the
    base site contributes no log-density. ``loc`` is reparameterized with
    ``TransformReparam``.
    """
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    masked_unit_uniform = dist.Uniform(0, 1).mask(False)
    reparam_cfg = {"loc": TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_cfg):
        loc = numpyro.sample(
            "loc",
            dist.TransformedDistribution(
                masked_unit_uniform, AffineTransform(0, alpha)),
        )
    numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
def actual_model(data):
    """Observe each element of ``data`` around a shared ``loc = alpha * u``.

    ``u ~ Uniform(0, 1)`` is the base of a ``TransformReparam``-ed site.
    Observations are conditionally independent under the ``"N"`` plate,
    sized ``len(data)``.
    """
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
    stretch = transforms.AffineTransform(0, alpha)
    reparam_cfg = {"loc": TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_cfg):
        loc = numpyro.sample(
            "loc", dist.TransformedDistribution(dist.Uniform(0, 1), stretch))
    num_obs = len(data)
    with numpyro.plate("N", num_obs):
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
def model_noncentered(num: int, sigma: np.ndarray, y: Optional[np.ndarray] = None) -> None:
    """Non-centered hierarchical normal model.

    ``theta`` is written as ``mu + tau * z`` with ``z ~ Normal(0, 1)`` inside
    the ``"num"`` plate. It is reparameterized with ``TransformReparam``, so
    the latent site is the standard-normal base rather than ``theta``.

    Args:
        num: size of the ``"num"`` plate (number of groups).
        sigma: observation noise scale(s), broadcast against ``theta``.
        y: observed values, or ``None`` to sample ``obs``.
    """
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    shift_scale = dist.transforms.AffineTransform(mu, tau)
    with numpyro.plate("num", num):
        reparam_cfg = {"theta": TransformReparam()}
        with numpyro.handlers.reparam(config=reparam_cfg):
            theta = numpyro.sample(
                "theta",
                dist.TransformedDistribution(dist.Normal(0.0, 1.0), shift_scale),
            )
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
def model(data):
    """Observe ``data`` around ``loc ~ Uniform(0, alpha)``, with ``loc``
    reparameterized by ``TransformReparam``.

    ``TransformReparam`` only accepts ``TransformedDistribution`` sites. In
    current NumPyro releases a plain ``dist.Uniform`` is not one, so sampling
    ``dist.Uniform(0, alpha)`` under this handler fails. ``Uniform(0, alpha)``
    is therefore written as the equivalent affine push-forward
    ``alpha * Uniform(0, 1)``. The density of ``loc`` is unchanged, and the
    handler can now expose the unit-uniform base as the latent site.
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
        loc = numpyro.sample(
            'loc',
            dist.TransformedDistribution(
                dist.Uniform(0, 1),
                dist.transforms.AffineTransform(0, alpha),
            ),
        )
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting

    Per-store dispersions and per-store feature coefficients are drawn
    non-centrally around global hyper-parameters: each is an
    ``AffineTransform`` of a zero-mean Normal, reparameterized with
    ``TransformReparam``. Each (store, day) target is observed under
    ``GammaPoisson(means * betas, betas)``, where
    ``means = exp(sum(coefs * features))`` and ``betas = exp(-disp_params)``.

    Missing targets (NaN) are neutralised rather than dropped. Their rate is
    set to a tiny ``eps`` with ``betas = 1`` and they are observed as 0, so
    their log-likelihood contribution is negligible. This is done with
    ``jnp.where`` selection rather than multiplying by a 0/1 mask, so an
    ``exp`` that overflows at a masked-out entry cannot become
    ``0 * inf = nan``.

    Args:
        X: input data, unpacked as ``(n_stores, n_days, n_features + 1)``.
            The last slice of the final axis is the target (NaN where
            missing); the rest are features (NaN treated as 0).

    Returns:
        output data: the value at the observed ``Site.days`` site, shaped
        ``(n_stores, n_days)``.
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target
    eps = 1e-12  # rate assigned to missing targets

    # Plates are created in this order on purpose: creating a plate is itself
    # an effect, so the order relative to the sample sites is preserved.
    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    disp_param_mu = numpyro.sample(
        Site.disp_param_mu, dist.Normal(loc=4.0, scale=1.0))
    disp_param_sigma = numpyro.sample(
        Site.disp_param_sigma, dist.HalfNormal(scale=1.0))

    with plate_stores:
        with numpyro.handlers.reparam(
                config={Site.disp_params: TransformReparam()}):
            disp_params = numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1),
                    dist.transforms.AffineTransform(disp_param_mu, disp_param_sigma),
                ),
            )

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)),
        )
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas, dist.HalfNormal(scale=2.0 * jnp.ones(n_features)))

    with plate_stores:
        with numpyro.handlers.reparam(config={Site.coefs: TransformReparam()}):
            coefs = numpyro.sample(
                Site.coefs,
                dist.TransformedDistribution(
                    dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.0),
                    dist.transforms.AffineTransform(coef_mus, coef_sigmas),
                ),
            )

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # padded features to 0
        is_observed = ~jnp.isnan(targets)
        # Select instead of blending with a 0/1 mask: the masked-out branch may
        # overflow to inf, and 0 * inf would be nan.
        means = jnp.where(
            is_observed,
            jnp.exp(jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2)),
            eps,
        )
        betas = jnp.where(is_observed, jnp.exp(-disp_params), 1.0)
        alphas = means * betas
        return numpyro.sample(
            Site.days,
            dist.GammaPoisson(alphas, betas),
            obs=jnp.nan_to_num(targets),
        )
def guide(X: DeviceArray):
    """Variational guide for the Gamma-Poisson sales model.

    The global sites (``Site.disp_param_mu``, ``Site.disp_param_sigma``,
    ``Site.coef_mus``, ``Site.coef_sigmas``) use distributions whose
    parameters are read from ``model_params``. That name is defined outside
    this block (TODO confirm what populates it). The per-store sites
    (``Site.disp_params``, ``Site.coefs``) get learnable ``numpyro.param``
    locations and positively-constrained scales. They are sampled under a
    ``TransformReparam`` config keyed by the same site names.

    Args:
        X: array unpacked as ``(n_stores, n_days, n_features + 1)``; only its
            shape is used here.
    """
    n_stores, _, n_features = X.shape
    n_features -= 1  # remove one dim for target

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)

    def fixed_normal(loc_key, scale_key):
        # Normal whose loc/scale come from ``model_params``.
        return dist.Normal(
            loc=model_params[loc_key],
            scale=model_params[scale_key],
        )

    def fixed_exp_normal(loc_key, scale_key):
        # exp() of ``fixed_normal``: keeps the sampled sigma strictly positive.
        return dist.TransformedDistribution(
            fixed_normal(loc_key, scale_key),
            transforms=dist.transforms.ExpTransform(),
        )

    disp_param_mu = numpyro.sample(
        Site.disp_param_mu,
        fixed_normal(Param.loc_disp_param_mu, Param.scale_disp_param_mu),
    )
    disp_param_sigma = numpyro.sample(
        Site.disp_param_sigma,
        fixed_exp_normal(Param.loc_disp_param_logsigma,
                         Param.scale_disp_param_logsigma),
    )

    with plate_stores, numpyro.handlers.reparam(
            config={Site.disp_params: TransformReparam()}):
        disp_loc = numpyro.param(Param.loc_disp_params, jnp.zeros((n_stores, 1)))
        disp_scale = numpyro.param(
            Param.scale_disp_params,
            0.1 * jnp.ones((n_stores, 1)),
            constraint=dist.constraints.positive,
        )
        numpyro.sample(
            Site.disp_params,
            dist.TransformedDistribution(
                dist.Normal(loc=disp_loc, scale=disp_scale),
                dist.transforms.AffineTransform(disp_param_mu, disp_param_sigma),
            ),
        )

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            fixed_normal(Param.loc_coef_mus, Param.scale_coef_mus),
        )
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas,
            fixed_exp_normal(Param.loc_coef_logsigmas, Param.scale_coef_logsigmas),
        )

    with plate_stores, numpyro.handlers.reparam(
            config={Site.coefs: TransformReparam()}):
        coef_loc = numpyro.param(Param.loc_coefs, jnp.zeros((n_stores, n_features)))
        coef_scale = numpyro.param(
            Param.scale_coefs,
            0.5 * jnp.ones((n_stores, n_features)),
            constraint=dist.constraints.positive,
        )
        numpyro.sample(
            Site.coefs,
            dist.TransformedDistribution(
                dist.Normal(loc=coef_loc, scale=coef_scale),
                dist.transforms.AffineTransform(coef_mus, coef_sigmas),
            ),
        )
def model(X: DeviceArray):
    """Predictive Gamma-Poisson model driven by ``model_params``.

    Every latent site draws from a distribution parameterized by entries of
    ``model_params``, which is defined outside this block (presumably fitted
    guide parameters — TODO confirm). ``Site.days`` is sampled without an
    ``obs`` value, so the returned array is a draw of daily targets.

    Args:
        X: array unpacked as ``(n_stores, n_days, n_features + 1)``. The last
            slice of the final axis (the target) is not used here; the
            remaining features have NaNs replaced by 0.

    Returns:
        The sample drawn at ``Site.days``.
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    def fitted_normal(loc_key, scale_key):
        # Normal whose loc/scale come from ``model_params``.
        return dist.Normal(
            loc=model_params[loc_key],
            scale=model_params[scale_key],
        )

    def fitted_exp_normal(loc_key, scale_key):
        # exp() of ``fitted_normal``: keeps the sampled sigma strictly positive.
        return dist.TransformedDistribution(
            fitted_normal(loc_key, scale_key),
            transforms=dist.transforms.ExpTransform(),
        )

    disp_param_mu = numpyro.sample(
        Site.disp_param_mu,
        fitted_normal(Param.loc_disp_param_mu, Param.scale_disp_param_mu),
    )
    disp_param_sigma = numpyro.sample(
        Site.disp_param_sigma,
        fitted_exp_normal(Param.loc_disp_param_logsigma,
                          Param.scale_disp_param_logsigma),
    )

    with plate_stores, numpyro.handlers.reparam(
            config={Site.disp_params: TransformReparam()}):
        disp_params = numpyro.sample(
            Site.disp_params,
            dist.TransformedDistribution(
                fitted_normal(Param.loc_disp_params, Param.scale_disp_params),
                dist.transforms.AffineTransform(disp_param_mu, disp_param_sigma),
            ),
        )

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            fitted_normal(Param.loc_coef_mus, Param.scale_coef_mus),
        )
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas,
            fitted_exp_normal(Param.loc_coef_logsigmas, Param.scale_coef_logsigmas),
        )

    with plate_stores, numpyro.handlers.reparam(
            config={Site.coefs: TransformReparam()}):
        coefs = numpyro.sample(
            Site.coefs,
            dist.TransformedDistribution(
                fitted_normal(Param.loc_coefs, Param.scale_coefs),
                dist.transforms.AffineTransform(coef_mus, coef_sigmas),
            ),
        )

    with plate_days, plate_stores:
        features = jnp.nan_to_num(X[..., :-1])  # padded features to 0
        linear_predictor = jnp.sum(
            jnp.expand_dims(coefs, axis=1) * features, axis=2)
        betas = jnp.exp(-disp_params)
        alphas = jnp.exp(linear_predictor) * betas
        return numpyro.sample(Site.days, dist.GammaPoisson(alphas, betas))