Example #1
0
 def actual_model(data):
     """Model ``data`` as Normal(loc, 0.1) noise around a latent location.

     ``alpha`` ~ Uniform(0, 1) and ``loc`` = 0 + alpha * u with
     u ~ Uniform(0, 1). ``loc`` is sampled under a ``reparam`` handler using
     ``TransformReparam``, so the Uniform(0, 1) base draw is what gets
     sampled and the affine map is applied to it afterwards. Every element
     of ``data`` is observed inside the plate "N" of size ``len(data)``.
     """
     alpha = numpyro.sample("alpha", dist.Uniform(0, 1))

     loc_reparam = {"loc": TransformReparam()}
     with numpyro.handlers.reparam(config=loc_reparam):
         unit_uniform = dist.Uniform(0, 1)
         scale_by_alpha = transforms.AffineTransform(0, alpha)
         loc = numpyro.sample(
             "loc", dist.TransformedDistribution(unit_uniform, scale_by_alpha))

     with numpyro.plate("N", len(data)):
         likelihood = dist.Normal(loc, 0.1)
         numpyro.sample("obs", likelihood, obs=data)
Example #2
0
def model_noncentered(
    num: int,
    sigma: np.ndarray,
    y: Optional[np.ndarray] = None,
) -> None:
    """Non-centered hierarchical model with one effect ``theta`` per group.

    ``mu`` ~ Normal(0, 5) and ``tau`` ~ HalfCauchy(5) are shared. Each
    ``theta[i]`` is ``mu + tau * z[i]`` with ``z[i]`` ~ Normal(0, 1), sampled
    through ``TransformReparam`` so the standard-normal draw is the latent
    variable instead of ``theta`` itself.

    Args:
        num: number of groups, i.e. the size of the "num" plate.
        sigma: scale(s) of the Normal observation noise around ``theta``.
        y: observed values for the "obs" site; ``None`` leaves it unobserved.
    """
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))

    with numpyro.plate("num", num):
        with numpyro.handlers.reparam(config={"theta": TransformReparam()}):
            z_dist = dist.Normal(0.0, 1.0)
            to_theta = dist.transforms.AffineTransform(mu, tau)
            theta = numpyro.sample(
                "theta", dist.TransformedDistribution(z_dist, to_theta))

        # Observation site sits in the same plate but outside the reparam.
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
Example #3
0
    def guide(X: DeviceArray):
        """Variational guide for the dispersion and coefficient latents.

        ``X`` is unpacked as ``(n_stores, n_days, n_columns)``. Its last
        column is the target, so only ``n_columns - 1`` feature columns are
        counted.

        The global sites (``disp_param_mu``, ``disp_param_sigma``,
        ``coef_mus``, ``coef_sigmas``) use Normal parameters read from
        ``model_params`` rather than ``numpyro.param``. This guide therefore
        has no learnable parameters for them.
        NOTE(review): confirm that keeping them fixed is intended.
        The sigma sites are exp-transformed Normals.

        Per-store ``disp_params`` and per-store, per-feature ``coefs`` have
        learnable Normal bases, initialised with shapes ``(n_stores, 1)`` and
        ``(n_stores, n_features)`` and with positive-constrained scales.
        Each base is shifted and scaled by the matching global draws through
        ``AffineTransform`` and sampled via ``TransformReparam``.
        """
        n_stores, _, n_columns = X.shape
        n_features = n_columns - 1  # last column of X is the target

        # Plate construction order is kept as-is: building a plate emits a
        # plate message to the effect stack.
        plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
        plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)

        def normal_from_model_params(loc_key, scale_key):
            # Normal whose fixed loc/scale come from ``model_params``.
            return dist.Normal(
                loc=model_params[loc_key],
                scale=model_params[scale_key],
            )

        exp_transform = dist.transforms.ExpTransform

        disp_param_mu = numpyro.sample(
            Site.disp_param_mu,
            normal_from_model_params(
                Param.loc_disp_param_mu, Param.scale_disp_param_mu),
        )
        disp_param_sigma = numpyro.sample(
            Site.disp_param_sigma,
            dist.TransformedDistribution(
                normal_from_model_params(
                    Param.loc_disp_param_logsigma,
                    Param.scale_disp_param_logsigma),
                transforms=exp_transform(),
            ),
        )

        with plate_stores, numpyro.handlers.reparam(
                config={Site.disp_params: TransformReparam()}):
            # The param calls stay inside both contexts, in loc-then-scale order.
            disp_loc = numpyro.param(
                Param.loc_disp_params, jnp.zeros((n_stores, 1)))
            disp_scale = numpyro.param(
                Param.scale_disp_params,
                0.1 * jnp.ones((n_stores, 1)),
                constraint=dist.constraints.positive,
            )
            numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    dist.Normal(loc=disp_loc, scale=disp_scale),
                    dist.transforms.AffineTransform(
                        disp_param_mu, disp_param_sigma),
                ),
            )

        with plate_features:
            coef_mus = numpyro.sample(
                Site.coef_mus,
                normal_from_model_params(
                    Param.loc_coef_mus, Param.scale_coef_mus),
            )
            coef_sigmas = numpyro.sample(
                Site.coef_sigmas,
                dist.TransformedDistribution(
                    normal_from_model_params(
                        Param.loc_coef_logsigmas,
                        Param.scale_coef_logsigmas),
                    transforms=exp_transform(),
                ),
            )

            with plate_stores, numpyro.handlers.reparam(
                    config={Site.coefs: TransformReparam()}):
                coef_loc = numpyro.param(
                    Param.loc_coefs, jnp.zeros((n_stores, n_features)))
                coef_scale = numpyro.param(
                    Param.scale_coefs,
                    0.5 * jnp.ones((n_stores, n_features)),
                    constraint=dist.constraints.positive,
                )
                numpyro.sample(
                    Site.coefs,
                    dist.TransformedDistribution(
                        dist.Normal(loc=coef_loc, scale=coef_scale),
                        dist.transforms.AffineTransform(
                            coef_mus, coef_sigmas),
                    ),
                )
Example #4
0
    def model(X: DeviceArray):
        """Hierarchical GammaPoisson model over a (store, day) grid.

        ``X`` is unpacked as ``(n_stores, n_days, n_columns)``. Its last
        column is the target and is excluded from the features, so
        ``n_features = n_columns - 1``.

        Priors, all with parameters read from ``model_params``:
          * ``disp_param_mu`` ~ Normal.
          * ``disp_param_sigma`` is an exp-transformed Normal.
          * ``coef_mus`` ~ Normal and ``coef_sigmas`` is an exp-transformed
            Normal, one of each per feature.
          * Per store, ``disp_params`` = ``disp_param_mu`` +
            ``disp_param_sigma`` * base.
          * Per store and feature, ``coefs`` = ``coef_mus`` +
            ``coef_sigmas`` * base.
          The two bases are Normals sampled through ``TransformReparam``.

        Likelihood: NaN/inf features are cleaned with ``jnp.nan_to_num``.
        ``means = exp(sum_f coefs * features)`` is computed per store and day.
        ``Site.days`` is then sampled from
        ``GammaPoisson(means * beta, beta)`` with
        ``beta = exp(-disp_params)``, so concentration / rate == ``means``.
        NOTE(review): no ``obs=`` is passed, so observed values are
        presumably supplied by conditioning the model externally. Confirm.

        Returns:
            The value sampled at ``Site.days``.
        """
        n_stores, n_days, n_columns = X.shape
        n_features = n_columns - 1  # last column of X is the target

        # Plate construction order is kept as-is: building a plate emits a
        # plate message to the effect stack.
        plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
        plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
        plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

        def normal_from_model_params(loc_key, scale_key):
            # Normal whose fixed loc/scale come from ``model_params``.
            return dist.Normal(
                loc=model_params[loc_key],
                scale=model_params[scale_key],
            )

        exp_transform = dist.transforms.ExpTransform

        disp_param_mu = numpyro.sample(
            Site.disp_param_mu,
            normal_from_model_params(
                Param.loc_disp_param_mu, Param.scale_disp_param_mu),
        )
        disp_param_sigma = numpyro.sample(
            Site.disp_param_sigma,
            dist.TransformedDistribution(
                normal_from_model_params(
                    Param.loc_disp_param_logsigma,
                    Param.scale_disp_param_logsigma),
                transforms=exp_transform(),
            ),
        )

        with plate_stores, numpyro.handlers.reparam(
                config={Site.disp_params: TransformReparam()}):
            disp_params = numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    normal_from_model_params(
                        Param.loc_disp_params, Param.scale_disp_params),
                    dist.transforms.AffineTransform(
                        disp_param_mu, disp_param_sigma),
                ),
            )

        with plate_features:
            coef_mus = numpyro.sample(
                Site.coef_mus,
                normal_from_model_params(
                    Param.loc_coef_mus, Param.scale_coef_mus),
            )
            coef_sigmas = numpyro.sample(
                Site.coef_sigmas,
                dist.TransformedDistribution(
                    normal_from_model_params(
                        Param.loc_coef_logsigmas,
                        Param.scale_coef_logsigmas),
                    transforms=exp_transform(),
                ),
            )

            with plate_stores, numpyro.handlers.reparam(
                    config={Site.coefs: TransformReparam()}):
                coefs = numpyro.sample(
                    Site.coefs,
                    dist.TransformedDistribution(
                        normal_from_model_params(
                            Param.loc_coefs, Param.scale_coefs),
                        dist.transforms.AffineTransform(
                            coef_mus, coef_sigmas),
                    ),
                )

        with plate_days, plate_stores:
            covariates = jnp.nan_to_num(X[..., :-1])
            # Insert a day axis so per-store coefficients broadcast over days.
            coefs_by_day = jnp.expand_dims(coefs, axis=1)
            log_means = jnp.sum(coefs_by_day * covariates, axis=2)
            means = jnp.exp(log_means)
            rate = jnp.exp(-disp_params)
            concentration = means * rate
            return numpyro.sample(
                Site.days, dist.GammaPoisson(concentration, rate))