Exemplo n.º 1
0
 def guide(difficulty=0.0):
     """Chained Normal guide over sites ``loc_latent_N`` down to ``loc_latent_1``.

     Site ``k`` is sampled after site ``k + 1``. Its mean is ``loc_q_k`` when
     ``k == N`` and ``kappa_q_k * previous_sample + loc_q_k`` otherwise, where
     ``previous_sample`` is the value drawn at site ``k + 1``. Its scale is
     ``exp(log_sig_q_k)``. Each site uses ``dist.Normal`` when
     ``which_nodes_reparam[k - 1] == 1.0`` and ``FakeNormal`` otherwise.

     :param float difficulty: size of the perturbation added to every initial
         param value, ``difficulty * (0.1 * eps - 0.53)`` with ``eps`` drawn by
         ``random.normal``; ``0.0`` starts each param at its target value.
     :return: the value sampled at the last visited site (``loc_latent_1`` when
         ``N >= 1``), or ``None`` when the loop does not run.
     """
     previous_sample = None
     for k in reversed(range(1, N + 1)):
         loc_q = numpyro.param(
             f"loc_q_{k}",
             lambda key: target_mus[k] + difficulty *
             (0.1 * random.normal(key) - 0.53),
         )
         # NOTE(review): presumably ``lambda_posts`` holds posterior precisions,
         # so -0.5 * log(lambda) is the log posterior std. dev. -- confirm.
         log_sig_q = numpyro.param(
             f"log_sig_q_{k}",
             lambda key: -0.5 * jnp.log(lambda_posts[k]) + difficulty *
             (0.1 * random.normal(key) - 0.53),
         )
         sig_q = jnp.exp(log_sig_q)
         kappa_q = None
         if k != N:
             # The first visited site (k == N) has no predecessor, so no coupling.
             kappa_q = numpyro.param(
                 f"kappa_q_{k}",
                 lambda key: target_kappas[k] + difficulty *
                 (0.1 * random.normal(key) - 0.53),
             )
         mean_function = loc_q if k == N else kappa_q * previous_sample + loc_q
         # bool() collapses a possibly array-valued comparison to a Python bool.
         node_flagged = bool(which_nodes_reparam[k - 1] == 1.0)
         normal_cls = dist.Normal if node_flagged else FakeNormal
         loc_latent = numpyro.sample(f"loc_latent_{k}",
                                     normal_cls(mean_function, sig_q))
         previous_sample = loc_latent
     return previous_sample
Exemplo n.º 2
0
 def _get_posterior(self):
     """Full-rank Gaussian posterior over the flattened latent vector.

     Registers ``<prefix>_loc`` (initialized from ``self._init_latent``) and
     ``<prefix>_scale_tril`` (initialized to the identity times
     ``self._init_scale`` and constrained to be lower-Cholesky).

     :return: ``dist.MultivariateNormal`` parameterized by ``scale_tril``.
     """
     prefix = self.prefix
     loc = numpyro.param('{}_loc'.format(prefix), self._init_latent)
     initial_tril = jnp.identity(self.latent_dim) * self._init_scale
     tril = numpyro.param('{}_scale_tril'.format(prefix),
                          initial_tril,
                          constraint=constraints.lower_cholesky)
     return dist.MultivariateNormal(loc, scale_tril=tril)
Exemplo n.º 3
0
def flax_module(name, nn_module, *, input_shape=None):
    """
    Declare a :mod:`~flax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Parameters are stored under the param site ``name + '$params'``. If that
    site has no value yet, the module is initialized by feeding it a dummy
    ``jnp.ones(input_shape)`` input with a key from :func:`numpyro.prng_key`.

    :param str name: name of the module to be registered.
    :param flax.nn.Module nn_module: a `flax` Module which has ``.init`` and
        ``.call`` methods; ``.init(rng_key, x)`` must return a pair whose second
        element is the parameter pytree.
    :param tuple input_shape: shape of the input taken by the
        neural network. Only required when the parameters are not yet
        initialized.
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array (``nn_module.call`` with the parameters bound first).
    :raises ImportError: if `flax` is not installed.
    :raises ValueError: if the parameters must be initialized and
        ``input_shape`` is None.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError("Looking like you want to use flax to declare "
                          "nn modules. This is an experimental feature. "
                          "You need to install `flax` to be able to use this feature. "
                          "It can be installed with `pip install flax`.") from e
    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        _, nn_params = nn_module.init(rng_key, jnp.ones(input_shape))
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.call, nn_params)
Exemplo n.º 4
0
 def _get_transform(self):
     """Build the affine transform of the guide from two learnable params.

     ``<prefix>_loc`` starts at ``self._init_latent``; ``<prefix>_scale_tril``
     starts at the identity times ``self._init_scale`` and is constrained to be
     lower-Cholesky. Presumably the transform maps ``z`` to
     ``loc + scale_tril @ z`` -- confirm against ``MultivariateAffineTransform``.
     """
     loc_name = '{}_loc'.format(self.prefix)
     tril_name = '{}_scale_tril'.format(self.prefix)
     loc = numpyro.param(loc_name, self._init_latent)
     tril = numpyro.param(tril_name,
                          np.identity(self.latent_size) * self._init_scale,
                          constraint=constraints.lower_cholesky)
     return MultivariateAffineTransform(loc, tril)
Exemplo n.º 5
0
def haiku_module(name, nn, input_shape=None):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Parameters are stored under the param site ``name + '$params'``. If that
    site has no value yet, they are created with ``nn.init`` on a dummy
    ``jnp.ones(input_shape)`` input, using a key drawn from the
    ``name + '$rng_key'`` sample site.

    :param str name: name of the module to be registered.
    :param haiku.Module nn: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network. Only required when the parameters are not yet
        initialized.
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array. It is ``nn.apply`` with the parameters and ``None`` bound as its
        first two positional arguments (presumably ``None`` fills the rng slot).
    :raises ImportError: if `haiku` is not installed.
    :raises ValueError: if the parameters must be initialized and
        ``input_shape`` is None.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError("Looking like you want to use haiku to declare "
                          "nn modules. This is an experimental feature. "
                          "You need to install `haiku` to be able to use this feature. "
                          "It can be installed with `pip install git+https://github.com/deepmind/dm-haiku`.") from e

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # feed in dummy data to init params
        rng_key = numpyro.sample(name + '$rng_key', PRNGIdentity())
        nn_params = nn.init(rng_key, jnp.ones(input_shape))
        numpyro.param(module_key, nn_params)
    return partial(nn.apply, nn_params, None)
Exemplo n.º 6
0
 def guide():
     """Guide with a scalar Normal latent ``x`` and masked ``y_unobserved`` sites.

     ``y_unobserved`` is drawn from ``Normal(x, 1.0)`` inside a plate over
     ``data``, masked with ``np.invert(mask)`` (presumably ``mask`` is boolean,
     so only the unobserved entries contribute).
     """
     loc_init, scale_init = np.zeros(()), np.ones(())
     loc = numpyro.param("loc", loc_init)
     scale = numpyro.param("scale", scale_init, constraint=constraints.positive)
     x = numpyro.sample("x", dist.Normal(loc, scale))
     missing = np.invert(mask)
     with numpyro.plate("plate", len(data)), handlers.mask(mask=missing):
         numpyro.sample("y_unobserved", dist.Normal(x, 1.0))
Exemplo n.º 7
0
 def guide():
     """Guide with a 3-d MultivariateNormal latent ``x`` and masked ``y_unobserved``.

     ``y_unobserved`` is drawn from ``MultivariateNormal(x, I_3)`` inside a plate
     over ``data``, masked with ``np.invert(mask)`` (presumably ``mask`` is
     boolean, so only the unobserved entries contribute).
     """
     mean = numpyro.param("loc", np.zeros(3))
     covariance = numpyro.param("cov", np.eye(3), constraint=constraints.positive_definite)
     x = numpyro.sample("x", dist.MultivariateNormal(mean, covariance))
     missing = np.invert(mask)
     with numpyro.plate("plate", len(data)), handlers.mask(mask=missing):
         numpyro.sample("y_unobserved", dist.MultivariateNormal(x, np.eye(3)))
Exemplo n.º 8
0
 def guide():
     """Guide for ``lambda_latent`` using ``FakeGamma`` with log-space params.

     The log-concentration and log-rate params start at ``log_alpha_n + 0.17``
     and ``log_beta_n - 0.143`` and are exponentiated before use.
     """
     log_alpha = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
     log_beta = numpyro.param("beta_q_log", log_beta_n - 0.143)
     numpyro.sample("lambda_latent", FakeGamma(jnp.exp(log_alpha), jnp.exp(log_beta)))
     # The "data" plate is entered with an empty body; nothing is sampled in it.
     with numpyro.plate("data", len(data)):
         pass
Exemplo n.º 9
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        For every latent sample site of the model (``PRNGIdentity`` sites and
        observed sites are skipped), this registers a ``{name}_{prefix}_loc``
        param and a positive ``{name}_{prefix}_scale`` param under the same
        plates as the model site, and builds
        ``Normal(loc, scale).to_event(event_dim)``. Sites whose support is
        ``real`` or ``real_vector`` are sampled from that Normal directly. For
        any other support the Normal draw becomes the auxiliary site
        ``{name}_unconstrained``, is mapped into the support with
        ``biject_to(support)``, and is recorded at ``name`` as a ``Delta`` whose
        ``log_density`` is the negated log-abs-det-Jacobian of that map.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        if self.prototype_trace is None:
            # run model to inspect the model structure
            self._setup_prototype(*args, **kwargs)

        # Plate objects keyed by plate name; re-entered per site below.
        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.items():
            if site["type"] != "sample" or isinstance(
                    site["fn"], dist.PRNGIdentity) or site["is_observed"]:
                continue

            event_dim = self._event_dims[name]
            init_loc = self._init_locs[name]
            with ExitStack() as stack:
                # Re-enter every plate the model site was declared in, so the
                # params and samples below are declared under the same plates.
                for frame in site["cond_indep_stack"]:
                    stack.enter_context(plates[frame.name])

                site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix),
                                         init_loc,
                                         event_dim=event_dim)
                site_scale = numpyro.param("{}_{}_scale".format(
                    name, self.prefix),
                                           jnp.full(jnp.shape(init_loc),
                                                    self._init_scale),
                                           constraint=constraints.positive,
                                           event_dim=event_dim)

                site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
                if site["fn"].support in [
                        constraints.real, constraints.real_vector
                ]:
                    # Unconstrained support: the Normal is the guide itself.
                    result[name] = numpyro.sample(name, site_fn)
                else:
                    # Constrained support: draw in unconstrained space as an
                    # auxiliary site, then push the draw into the support.
                    unconstrained_value = numpyro.sample(
                        "{}_unconstrained".format(name),
                        site_fn,
                        infer={"is_auxiliary": True})

                    transform = biject_to(site['fn'].support)
                    value = transform(unconstrained_value)
                    # Presumably the change-of-variables correction, so the
                    # Delta site's log-density accounts for the transform.
                    log_density = -transform.log_abs_det_jacobian(
                        unconstrained_value, value)
                    # Sum the rightmost (ndim(log_density) - ndim(value) +
                    # event_dim) dims of the Jacobian term.
                    log_density = sum_rightmost(
                        log_density,
                        jnp.ndim(log_density) - jnp.ndim(value) +
                        site["fn"].event_dim)
                    delta_dist = dist.Delta(value,
                                            log_density=log_density,
                                            event_dim=site["fn"].event_dim)
                    result[name] = numpyro.sample(name, delta_dist)

        return result
Exemplo n.º 10
0
 def guide(data):
     """Guide for site ``beta``: a Beta distribution with learnable concentrations.

     Both concentration params carry a ``positive`` constraint, so their initial
     values must be strictly positive.

     :param data: unused here; presumably accepted to match the model's signature.
     """
     alpha_q = numpyro.param("alpha_q",
                             # |N(0, 1)| is positive almost surely; a raw normal
                             # draw is negative half the time, which falls
                             # outside the positive constraint.
                             lambda key: abs(random.normal(key)),
                             constraint=constraints.positive)
     beta_q = numpyro.param("beta_q",
                            lambda key: random.exponential(key),
                            constraint=constraints.positive)
     numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
Exemplo n.º 11
0
 def _get_posterior(self):
     """Diagonal Normal posterior over the flattened latent vector.

     ``<prefix>_loc`` starts at ``self._init_latent``; ``<prefix>_scale`` starts
     at ``self._init_scale`` in every coordinate and is constrained by
     ``self.scale_constraint``.
     """
     loc_name = "{}_loc".format(self.prefix)
     scale_name = "{}_scale".format(self.prefix)
     loc = numpyro.param(loc_name, self._init_latent)
     initial_scale = jnp.full(self.latent_dim, self._init_scale)
     scale = numpyro.param(scale_name, initial_scale,
                           constraint=self.scale_constraint)
     return dist.Normal(loc, scale)
Exemplo n.º 12
0
 def _get_posterior(self):
     """Full-rank Gaussian posterior over the flattened latent vector.

     ``<prefix>_loc`` starts at ``self._init_latent``; ``<prefix>_scale_tril``
     starts at the identity times ``self._init_scale`` and is constrained by
     ``self.scale_tril_constraint``.
     """
     loc_name = "{}_loc".format(self.prefix)
     tril_name = "{}_scale_tril".format(self.prefix)
     loc = numpyro.param(loc_name, self._init_latent)
     initial_tril = jnp.identity(self.latent_dim) * self._init_scale
     tril = numpyro.param(tril_name, initial_tril,
                          constraint=self.scale_tril_constraint)
     return dist.MultivariateNormal(loc, scale_tril=tril)
Exemplo n.º 13
0
 def guide():
     """Guide for ``p_latent`` using ``FakeBeta`` with log-space params.

     The log-concentration params start at ``log_alpha_n + 0.17`` and
     ``log_beta_n - 0.143`` and are exponentiated before use.

     :return: the sampled ``p_latent`` value.
     """
     log_alpha = numpyro.param("alpha_q_log", log_alpha_n + 0.17)
     log_beta = numpyro.param("beta_q_log", log_beta_n - 0.143)
     sampled = numpyro.sample("p_latent", FakeBeta(jnp.exp(log_alpha), jnp.exp(log_beta)))
     # The "data" plate is entered with an empty body; nothing is sampled in it.
     with numpyro.plate("data", len(data)):
         pass
     return sampled
Exemplo n.º 14
0
 def guide():
     """Guide sampling ``loc_latent`` from ``FakeNormal`` inside a size-2 plate.

     The loc and log-scale params start offset from ``analytic_loc_n`` and
     ``analytic_log_sig_n`` by fixed per-coordinate amounts.

     :return: the sampled ``loc_latent`` value.
     """
     loc_offset = jnp.array([0.334, 0.334])
     log_sig_offset = jnp.array([-0.29, -0.29])
     loc_q = numpyro.param("loc_q", analytic_loc_n + loc_offset)
     log_sig_q = numpyro.param("log_sig_q", analytic_log_sig_n + log_sig_offset)
     with numpyro.plate("plate", 2):
         sampled = numpyro.sample("loc_latent", FakeNormal(loc_q, jnp.exp(log_sig_q)))
     return sampled
Exemplo n.º 15
0
    def model():
        """Model with two independent Normal latents whose locs/scales are params.

        :return: the value sampled at ``latent2``.
        """
        first_loc = numpyro.param("loc1", 0.)
        first_scale = numpyro.param("scale1", 1., constraint=constraints.positive)
        numpyro.sample("latent1", dist.Normal(first_loc, first_scale))

        second_loc = numpyro.param("loc2", 1.)
        second_scale = numpyro.param("scale2", 2., constraint=constraints.positive)
        return numpyro.sample("latent2", dist.Normal(second_loc, second_scale))
Exemplo n.º 16
0
def test_subsample_param():
    """Params with ``event_dim=0`` declared inside a subsampled plate.

    A scalar param keeps shape ``()``; a param spanning the full plate comes
    back with length ``subsample_size``.
    """
    data = jnp.arange(100.)
    subsample_size = 7
    full_size = len(data)
    with handlers.seed(rng_seed=0), \
            numpyro.plate("a", full_size, subsample_size=subsample_size):
        scalar_param = numpyro.param("p0", 0., event_dim=0)
        assert jnp.shape(scalar_param) == ()
        vector_param = numpyro.param("p", 0.5 * jnp.ones(full_size), event_dim=0)
        assert len(vector_param) == subsample_size
Exemplo n.º 17
0
 def model(z1=None, z2=None):
     """Two-component mixture with discrete assignments ``z1`` and ``z2``.

     ``z1`` is drawn once, outside any plate, and indexes the mean of all three
     ``x1`` observations (``data[0]``). ``z2`` is drawn inside the size-2 plate
     over ``data[1]``, one per ``x2`` observation. Passing ``z1``/``z2`` observes
     those sites instead of sampling them.
     """
     # Use jnp for both params, consistent with ``loc`` (was a mix of np and jnp).
     p = numpyro.param("p", jnp.array([0.25, 0.75]))
     loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
     z1 = numpyro.sample("z1", dist.Categorical(p), obs=z1)
     with numpyro.plate("data[0]", 3):
         numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
     with numpyro.plate("data[1]", 2):
         z2 = numpyro.sample("z2", dist.Categorical(p), obs=z2)
         numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
Exemplo n.º 18
0
def gmm_guide(data, num_components=3):
    """Point-estimate (``Delta``) guide for a Gaussian mixture.

    Registers ``mus_val`` (real), ``sigmas_val`` (positive) and
    ``mixture_probs_val`` (simplex) params and samples ``mus``, ``sigmas`` and
    ``mixture_probs`` as Deltas at those values. The initial means and
    mixture-weight logits come from ``stats.norm.rvs`` without a
    ``random_state``, i.e. from the global NumPy RNG, not from numpyro's keys.

    :param data: unused here; presumably accepted to match the model's signature.
    :param int num_components: number of mixture components.
    """
    # The rvs draws happen in the same order as before (means first, then weights).
    init_mus = jnp.array(stats.norm.rvs(size=num_components) * 1000)
    mus_val = numpyro.param('mus_val', init_mus,
                            constraint=dist.constraints.real)
    init_sigmas = jnp.ones(num_components)
    sigmas_val = numpyro.param('sigmas_val', init_sigmas,
                               constraint=dist.constraints.positive)
    numpyro.sample('mus', dist.Delta(mus_val))
    numpyro.sample('sigmas', dist.Delta(sigmas_val))
    init_probs = jax.nn.softmax(stats.norm.rvs(size=num_components))
    probs_val = numpyro.param('mixture_probs_val', init_probs,
                              constraint=dist.constraints.simplex)
    numpyro.sample('mixture_probs', dist.Delta(probs_val))
Exemplo n.º 19
0
 def _get_posterior(self, *args, **kwargs):
     """Low-rank-plus-diagonal Gaussian posterior over the flattened latents.

     The rank is ``self.rank``, or ``round(sqrt(latent_dim))`` when that is None.
     Both the covariance factor rows and the diagonal are scaled by the positive
     ``<prefix>_scale`` param: ``cov_factor * scale[:, None]`` and ``scale ** 2``.
     ``*args``/``**kwargs`` are accepted but unused.
     """
     if self.rank is None:
         rank = int(round(self.latent_dim ** 0.5))
     else:
         rank = self.rank
     prefix = self.prefix
     loc = numpyro.param('{}_loc'.format(prefix), self._init_latent)
     raw_factor = numpyro.param('{}_cov_factor'.format(prefix),
                                jnp.zeros((self.latent_dim, rank)))
     scale = numpyro.param('{}_scale'.format(prefix),
                           jnp.full(self.latent_dim, self._init_scale),
                           constraint=constraints.positive)
     scaled_factor = raw_factor * scale[..., None]
     diagonal = scale * scale
     return dist.LowRankMultivariateNormal(loc, scaled_factor, diagonal)
Exemplo n.º 20
0
 def _sample_latent(self, base_dist, *args, **kwargs):
     """Sample the flattened latent vector from a low-rank Gaussian posterior.

     The rank is ``self.rank``, or ``round(sqrt(latent_size))`` when that is
     None. The covariance factor rows and the diagonal are scaled by the
     ``<prefix>_scale`` param. ``sample_shape`` is read from ``kwargs``
     (default ``()``); ``base_dist`` and the remaining args are unused.

     NOTE(review): ``<prefix>_scale`` has no positive constraint, so the
     diagonal ``scale ** 2`` can collapse to zero during optimization -- confirm
     this is intended.
     """
     sample_shape = kwargs.pop('sample_shape', ())
     rank = self.rank
     if rank is None:
         rank = int(round(self.latent_size ** 0.5))
     prefix = self.prefix
     loc = numpyro.param('{}_loc'.format(prefix), self._init_latent)
     raw_factor = numpyro.param('{}_cov_factor'.format(prefix), np.zeros((self.latent_size, rank)))
     scale = numpyro.param('{}_scale'.format(prefix), np.ones(self.latent_size))
     posterior = dist.LowRankMultivariateNormal(loc,
                                                raw_factor * scale[..., None],
                                                scale * scale)
     return numpyro.sample("_{}_latent".format(prefix), posterior, sample_shape=sample_shape)
Exemplo n.º 21
0
 def model(z1=None, z2=None):
     """Chained discrete assignments: ``z1 ~ Cat(p[0])`` then ``z2 ~ Cat(p[z1])``.

     ``z1`` and ``z2`` index the means of ``x1`` and ``x2``, which are observed
     as ``data[0]`` and ``data[1]`` inside a size-3 plate. Passing ``z1``/``z2``
     observes those sites instead of sampling them.
     """
     p = numpyro.param("p", jnp.array([[0.25, 0.75], [0.1, 0.9]]))
     loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
     z1 = numpyro.sample("z1", dist.Categorical(p[0]), obs=z1)
     z2 = numpyro.sample("z2", dist.Categorical(p[z1]), obs=z2)
     # Lazy %-style args: the message is only formatted if INFO is enabled.
     logger.info("z1.shape = %s", z1.shape)
     logger.info("z2.shape = %s", z2.shape)
     with numpyro.plate("data", 3):
         numpyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
         numpyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])
Exemplo n.º 22
0
 def model():
     """Bernoulli chain ``x_t`` whose probability depends on two past states and ``z``.

     At step ``t`` the new state is drawn from ``Bernoulli(p[x_prev, x_curr, z])``
     and ``y_t ~ Bernoulli(q[x_t])`` is observed as 0. The loop runs under
     ``markov`` with the given ``history``.

     :return: the final ``(x_prev, x_curr)`` pair.
     """
     p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
     q = numpyro.param("q", 0.25 * jnp.ones(2))
     z = numpyro.sample("z", dist.Bernoulli(0.5))
     x_prev = x_curr = 0
     for t in markov(range(T), history=history):
         x_next = numpyro.sample("x_{}".format(t),
                                 dist.Bernoulli(p[x_prev, x_curr, z]))
         x_prev, x_curr = x_curr, x_next
         numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0)
     return x_prev, x_curr
Exemplo n.º 23
0
    def guide(X: DeviceArray):
        """Guide giving every site its own independent variational distribution.

        Global dispersion and coefficient sites use fixed ``model_params``
        entries (read from the enclosing scope, not registered as params here);
        the ``*_sigma(s)`` sites are Normals pushed through ``exp``. Per-store
        offsets get learnable loc/scale params inside the store (and feature)
        plates.

        :param X: array of shape ``(n_stores, n_days, n_features)``; one feature
            dimension holds the target and is not counted as a covariate.
        """
        n_stores, _, n_covariates = X.shape
        n_covariates -= 1  # one feature dimension is the target

        plate_features = numpyro.plate(Plate.features, n_covariates, dim=-1)
        plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)

        def exp_normal(loc_key, scale_key):
            # Normal over the log value, mapped back through exp.
            return dist.TransformedDistribution(
                dist.Normal(loc=model_params[loc_key],
                            scale=model_params[scale_key]),
                transforms=dist.transforms.ExpTransform())

        numpyro.sample(
            Site.disp_param_mu,
            dist.Normal(loc=model_params[Param.loc_disp_param_mu],
                        scale=model_params[Param.scale_disp_param_mu]))
        numpyro.sample(
            Site.disp_param_sigma,
            exp_normal(Param.loc_disp_param_logsigma,
                       Param.scale_disp_param_logsigma))

        with plate_stores:
            disp_loc = numpyro.param(Param.loc_disp_param_offsets,
                                     jnp.zeros((n_stores, 1)))
            disp_scale = numpyro.param(Param.scale_disp_param_offsets,
                                       0.1 * jnp.ones((n_stores, 1)),
                                       constraint=dist.constraints.positive)
            numpyro.sample(Site.disp_param_offsets,
                           dist.Normal(loc=disp_loc, scale=disp_scale))

        with plate_features:
            numpyro.sample(
                Site.coef_mus,
                dist.Normal(loc=model_params[Param.loc_coef_mus],
                            scale=model_params[Param.scale_coef_mus]))
            numpyro.sample(
                Site.coef_sigmas,
                exp_normal(Param.loc_coef_logsigmas,
                           Param.scale_coef_logsigmas))

            with plate_stores:
                coef_loc = numpyro.param(Param.loc_coef_offsets,
                                         jnp.zeros((n_stores, n_covariates)))
                coef_scale = numpyro.param(
                    Param.scale_coef_offsets,
                    0.5 * jnp.ones((n_stores, n_covariates)),
                    constraint=dist.constraints.positive)
                numpyro.sample(Site.coef_offsets,
                               dist.Normal(loc=coef_loc, scale=coef_scale))
Exemplo n.º 24
0
    def guide():
        """Guide whose site ``x`` comes from one of two learnable Normals.

        ``cluster`` is drawn from ``dist.Normal()``; ``cond`` takes the
        ``(m1, s1)`` branch when ``cluster > 0`` and the ``(m2, s2)`` branch
        otherwise.
        """
        m1 = numpyro.param("m1", 2.0)
        s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
        m2 = numpyro.param("m2", 2.0)
        s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)

        def make_branch(loc, scale):
            # Build a ``cond`` branch that samples site "x" from Normal(loc, scale).
            def branch(_):
                numpyro.sample("x", dist.Normal(loc, scale))
            return branch

        cluster = numpyro.sample("cluster", dist.Normal())
        cond(cluster > 0, make_branch(m1, s1), make_branch(m2, s2), None)
Exemplo n.º 25
0
def guide(
    x: Optional[jnp.ndarray] = None,
    seq_len: int = 0,
    batch: int = 0,
    x_dim: int = 1,
    future_steps: int = 0,
) -> None:
    """Variational guide: ``z ~ Normal(x * phi, sigma)`` with learnable ``phi``/``sigma``.

    :param x: observed array whose last axis has size ``x_dim``. Required, because
        the location of ``z`` is ``x * phi``.
    :param seq_len: unused here; presumably mirrors the model's signature.
    :param batch: unused here; presumably mirrors the model's signature.
    :param x_dim: size of the last axis; taken from ``x.shape[-1]``.
    :param future_steps: unused here; presumably mirrors the model's signature.
    :raises ValueError: if ``x`` is None. Previously ``None * phi`` failed with an
        opaque TypeError.
    """
    if x is None:
        raise ValueError("guide requires `x`: the location of `z` is `x * phi`.")

    *_, x_dim = x.shape

    phi = numpyro.param("phi", jnp.ones(x_dim))
    sigma = numpyro.param("sigma", jnp.ones(x_dim) * 0.05, constraint=constraints.positive)
    numpyro.sample("z", dist.Normal(x * phi, sigma))
Exemplo n.º 26
0
    def guide(self):
        """Variational guide registering loc/scale params for RHO, BETA, B0 and C.

        - RHO (only when ``self.fit_rho``): Normal with a positive scale; both
          params are tiled over ``self.num_ltla`` rows.
        - BETA: MultivariateNormal whose lower-Cholesky ``scale_tril`` starts as a
          stack of ``self.num_ltla_lin`` scaled identity matrices.
        - B0: MultivariateNormal whose lower-Cholesky ``scale_tril`` starts as a
          diagonal matrix.
        - C: Normal with a positive scale.
        """
        if self.fit_rho:
            rho_loc = npy.param(
                Sites.RHO + Sites.LOC,
                jnp.tile(self.rho_loc, (self.num_ltla, 1)),
            )
            rho_scale = npy.param(
                Sites.RHO + Sites.SCALE,
                jnp.tile(self.init_scale * self.rho_scale, (self.num_ltla, 1)),
                constraint=dist.constraints.positive,
            )
            npy.sample(Sites.RHO, dist.Normal(rho_loc, rho_scale))

        # mean / sd for parameter s
        beta_loc = npy.param(
            Sites.BETA + Sites.LOC,
            jnp.tile(self.beta_loc, (self.num_ltla_lin, self.num_basis)),
        )
        beta_scale = npy.param(
            Sites.BETA + Sites.SCALE,
            self.init_scale * self.beta_scale *
            jnp.stack(self.num_ltla_lin * [jnp.eye(self.num_basis)]),
            constraint=dist.constraints.lower_cholesky,
        )

        npy.sample(Sites.BETA,
                   dist.MultivariateNormal(beta_loc, scale_tril=beta_scale))

        # NOTE(review): the B0 params are keyed by ``Sites.BC0`` while the sample
        # site is ``Sites.B0``; confirm this naming is intended.
        b0_loc = npy.param(
            Sites.BC0 + Sites.LOC,
            jnp.repeat(self.b0_loc, self.num_lin),
        )
        b0_scale = npy.param(
            Sites.BC0 + Sites.SCALE,
            jnp.diag(
                jnp.repeat(
                    self.init_scale * self.b0_scale * self.time_scale,
                    self.num_lin,
                )),
            constraint=dist.constraints.lower_cholesky,
        )
        npy.sample(Sites.B0,
                   dist.MultivariateNormal(b0_loc, scale_tril=b0_scale))

        c_loc = npy.param(
            Sites.C + Sites.LOC,
            jnp.tile(self.c_loc, (self.num_ltla_lin, self.num_lin)))

        # ``c_scale`` is a Normal scale, so it is constrained positive like
        # ``rho_scale``; unconstrained, optimization could drive it to <= 0.
        c_scale = npy.param(
            Sites.C + Sites.SCALE,
            jnp.tile(self.init_scale * self.c_scale,
                     (self.num_ltla_lin, self.num_lin)),
            constraint=dist.constraints.positive,
        )
        npy.sample(Sites.C, dist.Normal(c_loc, c_scale))
Exemplo n.º 27
0
    def __call__(self, *args, **kwargs):
        """Point-mass (``Delta``) guide with the same ``*args, **kwargs`` as the model.

        Each non-observed sample site of the model gets a ``{name}_{prefix}_loc``
        param, constrained to that site's support and declared under the same
        plates, and is sampled from ``Delta(loc).to_event(event_dim)``.

        :return: dict mapping each latent site name to its sampled value.
        """
        if self.prototype_trace is None:
            # First call: trace the model once to learn its site structure.
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.items():
            is_latent = site["type"] == "sample" and not site["is_observed"]
            if not is_latent:
                continue

            dims = self._event_dims[name]
            start = self._init_locs[name]
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    stack.enter_context(plates[frame.name])

                point = numpyro.param(
                    "{}_{}_loc".format(name, self.prefix),
                    start,
                    constraint=site["fn"].support,
                    event_dim=dims,
                )
                result[name] = numpyro.sample(name, dist.Delta(point).to_event(dims))

        return result
Exemplo n.º 28
0
    def __call__(self, name, fn, obs):
        """Reparameterize the location-scale site ``name`` by partial decentering.

        With centeredness ``c`` (``self.centered``, or a learnable
        ``{name}_centered`` param in the unit interval when that is None), a
        decentered value is drawn from the same family with ``loc * c`` and
        ``scale ** c``, then mapped back as
        ``value = loc + scale ** (1 - c) * (decentered - c * loc)``.

        :return: ``(fn, obs)`` unchanged when ``centered`` is identically one;
            otherwise ``(None, value)``, which stands in for a deterministic
            site holding ``value``.
        """
        assert obs is None, "LocScaleReparam does not support observe statements"
        centered = self.centered
        if is_identically_one(centered):
            # Fully centered: nothing to do. Return the same 2-tuple shape as the
            # decentered path; the previous ``name, fn, obs`` was a 3-tuple.
            return fn, obs
        event_shape = fn.event_shape
        fn, batch_shape, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param("{}_centered".format(name),
                                     jnp.full(event_shape, 0.5),
                                     constraint=constraints.unit_interval)
        params["loc"] = fn.loc * centered
        params["scale"] = fn.scale**centered
        decentered_fn = self._wrap(type(fn)(**params), batch_shape, event_dim)

        # Draw decentered noise.
        decentered_value = numpyro.sample("{}_decentered".format(name),
                                          decentered_fn)

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        return None, value
Exemplo n.º 29
0
    def fun_model():
        """Bernoulli chain driven by ``scan`` over ``T`` steps.

        At each step the new state ``x`` is drawn from
        ``Bernoulli(p[x_prev, x_curr, z])`` and ``y`` is observed from
        ``Bernoulli(q[x])`` using the scanned inputs ``jnp.zeros(T)``.

        :return: the final ``(x_prev, x_curr)`` carry.
        """
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def transition_fn(carry, y):
            older, newer = carry
            drawn = numpyro.sample("x", dist.Bernoulli(p[older, newer, z]))
            numpyro.sample("y", dist.Bernoulli(q[drawn]), obs=y)
            return (newer, drawn), None

        final_carry, _ = scan(transition_fn, (0, 0), jnp.zeros(T), history=history)
        x_prev, x_curr = final_carry
        return x_prev, x_curr
Exemplo n.º 30
0
 def model(z=None):
     """Two-component model: ``z ~ Cat(p)`` selects the mean (0.0 or 1.0) of ``x``.

     ``x`` is observed as ``data`` inside a size-3 plate. Passing ``z`` observes
     the assignment instead of sampling it.
     """
     # jnp for consistency with the rest of the block (was np.array).
     p = numpyro.param("p", jnp.array([0.75, 0.25]))
     iz = numpyro.sample("z", dist.Categorical(p), obs=z)
     z = jnp.array([0.0, 1.0])[iz]
     # Lazy %-style args: the message is only formatted if INFO is enabled.
     logger.info("z.shape = %s", z.shape)
     with numpyro.plate("data", 3):
         numpyro.sample("x", dist.Normal(z, 1.0), obs=data)