Example #1
def test_scan_enum_one_latent(num_steps):
    data = random.normal(random.PRNGKey(0), (num_steps, ))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        for i, y in markov(enumerate(data)):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)
        return x

    def fun_model(data):
        def transition_fn(x, y):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, 1

        x, collections = scan(transition_fn, None, data)
        assert collections.shape == data.shape[:1]
        return x

    expected_log_joint = log_density(enum(config_enumerate(model)), (data, ),
                                     {}, {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data, ),
                                   {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    actual_last_x = enum(config_enumerate(fun_model))(data)
    expected_last_x = enum(config_enumerate(model))(data)
    assert_allclose(actual_last_x, expected_last_x)
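The tests in these examples are excerpts, so their imports are not shown. Below is a sketch of the preamble they appear to assume; the module paths follow NumPyro's contrib.funsor API at the time these tests were written, and the location of Vindex in particular has moved between releases, so treat the exact paths as assumptions. Tests that take arguments, such as test_scan_enum_one_latent(num_steps), are presumably driven by a pytest.mark.parametrize decorator that the excerpts omit.

# A sketch of the shared preamble these test excerpts assume; exact module
# paths (especially for Vindex) vary across NumPyro releases.
from functools import partial

from numpy.testing import assert_allclose

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.contrib.funsor import (
    config_enumerate,
    enum,
    markov,
    plate as enum_plate,
)
from numpyro.contrib.funsor.infer_util import log_density
from numpyro.contrib.indexing import Vindex  # newer releases: numpyro.ops.indexing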
Example #2
def test_scan_enum_plate():
    N, D = 10, 3
    data = random.normal(random.PRNGKey(0), (N, D))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        x = None
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            with D_plate:
                probs = init_probs if x is None else transition_probs[x]
                x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
                numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)

    def fun_model(data):
        def transition_fn(x, y):
            probs = init_probs if x is None else transition_probs[x]
            with numpyro.plate("D", D, dim=-1):
                x = numpyro.sample("x", dist.Categorical(probs))
                numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, None, data)

    expected_log_joint = log_density(enum(config_enumerate(model), -2),
                                     (data, ), {}, {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2),
                                   (data, ), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #3
def test_scan_enum_two_latents():
    num_steps = 11
    data = random.normal(random.PRNGKey(0), (num_steps, ))
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = w = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[w, x], 1), obs=y)

    def fun_model(data):
        def transition_fn(carry, y):
            x, w = carry
            x = numpyro.sample("x", dist.Categorical(probs_x[x]))
            w = numpyro.sample("w", dist.Categorical(probs_w[w]))
            numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y)
            # also test that scan's `ys` are recorded correctly
            return (x, w), x

        scan(transition_fn, (0, 0), data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data, ),
                                   {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data, ),
                                     {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #4
def test_scan_enum_scan_enum():
    num_steps = 11
    data_x = random.normal(random.PRNGKey(0), (num_steps, ))
    data_w = data_x[:-1] + 1
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs_x = jnp.array([-1.0, 1.0])
    locs_w = jnp.array([2.0, 3.0])

    def model(data_x, data_w):
        x = w = 0
        for i, y in markov(enumerate(data_x)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            numpyro.sample(f"y_x_{i}", dist.Normal(locs_x[x], 1), obs=y)

        for i, y in markov(enumerate(data_w)):
            w = numpyro.sample(f"w{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_w_{i}", dist.Normal(locs_w[w], 1), obs=y)

    def fun_model(data_x, data_w):
        def transition_fn(name, probs, locs, x, y):
            x = numpyro.sample(name, dist.Categorical(probs[x]))
            numpyro.sample("y_" + name, dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(partial(transition_fn, "x", probs_x, locs_x), 0, data_x)
        scan(partial(transition_fn, "w", probs_w, locs_w), 0, data_w)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)),
                                   (data_x, data_w), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)),
                                     (data_x, data_w), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #5
def test_scan_enum_discrete_outside():
    data = random.normal(random.PRNGKey(0), (10, ))
    probs = jnp.array([[[0.8, 0.2], [0.1, 0.9]], [[0.7, 0.3], [0.6, 0.4]]])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        x = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs[w, x]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[x], 1), obs=y)

    def fun_model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))

        def transition_fn(x, y):
            x = numpyro.sample("x", dist.Categorical(probs[w, x]))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data, ),
                                   {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data, ),
                                     {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #6
def test_scan_enum_separated_plate_discrete():
    N, D = 10, 3
    data = random.normal(random.PRNGKey(0), (N, D))
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x = 0
        D_plate = numpyro.plate("D", D, dim=-1)
        for i, y in markov(enumerate(data)):
            probs = transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D_plate:
                w = numpyro.sample(f"w_{i}", dist.Bernoulli(0.6))
                numpyro.sample(f"y_{i}",
                               dist.Normal(Vindex(locs)[x, w], 1),
                               obs=y)

    def fun_model(data):
        def transition_fn(x, y):
            probs = transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D", D, dim=-1):
                w = numpyro.sample("w", dist.Bernoulli(0.6))
                numpyro.sample("y", dist.Normal(Vindex(locs)[x, w], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2),
                                   (data, ), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2),
                                     (data, ), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #7
    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = not_jax_tracer(i) and i == 0
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put the time dimension in the correct position,
                # we promote shapes so that `fn` and `value` at each site have the same
                # batch dims (e.g. if `fn.batch_shape == (2, 3)` and the value's batch
                # shape is (3,), we promote the value's shape to (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # record the shape of new_carry in the nonlocal variable carry_shape_at_t1
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
Example #8
def test_nested_plate():
    with enum(first_available_dim=-3):
        with enum_plate("a", 5):
            with enum_plate("b", 2):
                x = numpyro.sample("x",
                                   dist.Normal(0, 1),
                                   rng_key=random.PRNGKey(0))
                assert x.shape == (2, 5)
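For orientation, the dimension bookkeeping implied by the assertion above, inferred from the asserted shape rather than from documented behavior:

# plates appear to claim batch dims from the right, in order of entry:
#   enum_plate "a" (size 5) -> dim -1
#   enum_plate "b" (size 2) -> dim -2
# first_available_dim=-3 then reserves dims -3 and beyond for enumeration,
# which is why the drawn sample has x.shape == (2, 5).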
Example #9
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model),
                                 -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Example #10
def test_scan_history(history, T):
    def model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))
        x_prev = 0
        x_curr = 0
        for t in markov(range(T), history=history):
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x_{}".format(t),
                                                    dist.Bernoulli(probs))
            numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0)
        return x_prev, x_curr

    def fun_model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def transition_fn(carry, y):
            x_prev, x_curr = carry
            probs = p[x_prev, x_curr, z]
            x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(probs))
            numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
            return (x_prev, x_curr), None

        (x_prev, x_curr), _ = scan(transition_fn, (0, 0),
                                   jnp.zeros(T),
                                   history=history)
        return x_prev, x_curr

    expected_log_joint = log_density(enum(config_enumerate(model)), (), {},
                                     {})[0]
    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (), {},
                                   {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    expected_x_prev, expected_x_curr = enum(config_enumerate(model))()
    actual_x_prev, actual_x_curr = enum(config_enumerate(fun_model))()
    assert_allclose(actual_x_prev, expected_x_prev)
    assert_allclose(actual_x_curr, expected_x_curr)
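Since test_scan_history takes history and T as arguments, it is presumably parametrized. A hypothetical decorator is sketched below; the grid values are illustrative, not taken from the original test:

import pytest

@pytest.mark.parametrize("history, T", [(2, 5), (2, 10), (3, 10)])
def test_scan_history(history, T):
    ...  # body as above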
Example #11
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        z_hmc = hmc_sites
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model),
                                  -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model_,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Example #12
def test_scan_enum_separated_plates_same_dim():
    N, D1, D2 = 10, 3, 4
    data = random.normal(random.PRNGKey(0), (N, D1 + D2))
    data1, data2 = data[:, :D1], data[:, D1:]
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def model(data1, data2):
        x = None
        D1_plate = numpyro.plate("D1", D1, dim=-1)
        D2_plate = numpyro.plate("D2", D2, dim=-1)
        for i, (y1, y2) in markov(enumerate(zip(data1, data2))):
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs))
            with D1_plate:
                numpyro.sample(f"y1_{i}", dist.Normal(locs[x], 1), obs=y1)
            with D2_plate:
                numpyro.sample(f"y2_{i}", dist.Normal(locs[x], 1), obs=y2)

    def fun_model(data1, data2):
        def transition_fn(x, y):
            y1, y2 = y
            probs = init_probs if x is None else transition_probs[x]
            x = numpyro.sample("x", dist.Categorical(probs))
            with numpyro.plate("D1", D1, dim=-1):
                numpyro.sample("y1", dist.Normal(locs[x], 1), obs=y1)
            with numpyro.plate("D2", D2, dim=-1):
                numpyro.sample("y2", dist.Normal(locs[x], 1), obs=y2)
            return x, None

        scan(transition_fn, None, (data1, data2))

    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2),
                                   (data1, data2), {}, {})[0]
    expected_log_joint = log_density(enum(config_enumerate(model), -2),
                                     (data1, data2), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
Example #13
def main(_args):
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # nuts = NUTS(gmm)
    # mcmc = MCMC(nuts, 100, 1000)
    # mcmc.print_summary()
    seeded_gmm = seed(gmm, init_rng_key)
    model_trace = trace(seeded_gmm).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(model_trace)
    enum_gmm = enum(config_enumerate(gmm), -max_plate_nesting - 1)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    upd_fun = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for i in pbar:
            svi_state, loss = upd_fun(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
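generate_data, gmm, and gmm_guide are defined elsewhere in this script. A minimal sketch of what they might look like for a two-component mixture follows; it is entirely hypothetical and shown only to make the driver self-contained (it assumes jax, jnp, numpyro, and dist are imported as usual):

def generate_data():
    # two well-separated clusters of synthetic points
    pts = jax.random.normal(PRNGKey(0), (100,))
    return jnp.concatenate([pts[:50] - 2.0, pts[50:] + 2.0])

def gmm(data):
    locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
    with numpyro.plate("N", data.shape[0]):
        # the discrete assignment is enumerated out by enum/config_enumerate
        z = numpyro.sample("z", dist.Categorical(jnp.array([0.5, 0.5])))
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

def gmm_guide(data):
    loc = numpyro.param("loc", jnp.array([-1.0, 1.0]))
    scale = numpyro.param("scale", jnp.ones(2),
                          constraint=dist.constraints.positive)
    numpyro.sample("locs", dist.Normal(loc, scale).to_event(1))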
Example #14
    def run(self, rng_key, *args, **kwargs):
        """
        Run the nested samplers and collect weighted samples.

        :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
        :param args: The arguments needed by the `model`.
        :param kwargs: The keyword arguments needed by the `model`.
        """
        rng_sampling, rng_predictive = random.split(rng_key)
        # reparam the model so that latent sites have Uniform(0, 1) priors
        prototype_trace = trace(seed(self.model,
                                     rng_key)).get_trace(*args, **kwargs)
        param_names = [
            site["name"] for site in prototype_trace.values()
            if site["type"] == "sample" and not site["is_observed"]
            and site["infer"].get("enumerate", "") != "parallel"
        ]
        deterministics = [
            site["name"] for site in prototype_trace.values()
            if site["type"] == "deterministic"
        ]
        reparam_model = reparam(
            self.model, config={k: UniformReparam()
                                for k in param_names})

        # enable enumerate if needed
        has_enum = any(site["type"] == "sample"
                       and site["infer"].get("enumerate", "") == "parallel"
                       for site in prototype_trace.values())
        if has_enum:
            from numpyro.contrib.funsor import enum, log_density as log_density_

            max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
            _validate_model(prototype_trace)
            reparam_model = enum(reparam_model, -max_plate_nesting - 1)
        else:
            log_density_ = log_density

        def loglik_fn(**params):
            return log_density_(reparam_model, args, kwargs, params)[0]

        # use NestedSampler with identity prior chain
        prior_chain = PriorChain()
        for name in param_names:
            prior = UniformPrior(name + "_base",
                                 prototype_trace[name]["fn"].shape())
            prior_chain.push(prior)
        # XXX: jaxns's `marginalised` keyword can be used to compute expectations
        # of quantities over posterior samples; it could be helpful to expose it
        # in this wrapper
        ns = OrigNestedSampler(
            loglik_fn,
            prior_chain,
            sampler_name=self.sampler_name,
            sampler_kwargs={
                "depth": self.depth,
                "num_slices": self.num_slices
            },
            max_samples=self.max_samples,
            num_live_points=self.num_live_points,
            collect_samples=True,
        )
        # some parts of jaxns use float64 and raise warnings when the default
        # dtype is float32, so we suppress them here to avoid confusion
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message=".*will be truncated to dtype float32.*")
            results = ns(rng_sampling, termination_frac=self.termination_frac)
        # transform base samples back to original domains
        # Here we only transform the first valid num_samples samples
        # NB: the number of weighted samples obtained from jaxns is results.num_samples
        # and only the first num_samples values of results.samples are valid.
        num_samples = results.num_samples
        samples = tree_util.tree_map(lambda x: x[:num_samples],
                                     results.samples)
        predictive = Predictive(reparam_model,
                                samples,
                                return_sites=param_names + deterministics)
        samples = predictive(rng_predictive, *args, **kwargs)
        # replace base samples in jaxns results by transformed samples
        self._results = results._replace(samples=samples)
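A hypothetical usage sketch for this wrapper. The constructor arguments mirror the attributes run reads (sampler_name, depth, num_slices, max_samples, num_live_points, termination_frac), and get_samples is sketched as an assumption about the rest of the wrapper's API:

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

ns = NestedSampler(model)                # defaults for the jaxns settings above
ns.run(random.PRNGKey(0), jnp.ones(20))  # the method defined above
samples = ns.get_samples(random.PRNGKey(1), num_samples=1000)  # assumed helper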
Example #15
def initialize_model(rng_key,
                     model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When called with `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(seed(
        model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
                                   substitute_fn=init_strategy)
    inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    constrained_values = {
        k: v['value']
        for k, v in model_trace.items() if v['type'] == 'sample'
        and not v['is_observed'] and not v['fn'].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    enum=has_enumerate_support,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    init_strategy = init_strategy if isinstance(init_strategy,
                                                partial) else init_strategy()
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
Example #16
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When called with `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode
        differentiation or reverse-mode differentiation. By default, we use reverse
        mode, but forward mode can be useful in some cases to improve performance.
        In addition, some control flow utilities in JAX, such as `jax.lax.while_loop`
        and `jax.lax.fori_loop`, only support forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace into the model so we don't need
    # to regenerate the parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items() if site["type"] in ["param"]
        },
    )
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items() if v["type"] == "sample"
        and not v["is_observed"] and not v["fn"].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    init_strategy = (init_strategy if isinstance(init_strategy, partial) else
                     init_strategy())
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = ("Site {}: {}".format(
                                    site["name"],
                                    w.message.args[0]), ) + w.message.args[1:]
                                warnings.showwarning(
                                    w.message,
                                    w.category,
                                    w.filename,
                                    w.lineno,
                                    file=w.file,
                                    line=w.line,
                                )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
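A minimal usage sketch for initialize_model. The toy model is hypothetical, and the param_info field names (z, potential_energy, z_grad) are an assumption based on NumPyro's ParamInfo namedtuple:

def toy_model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

model_info = initialize_model(random.PRNGKey(0), toy_model,
                              model_args=(jnp.ones(10),))
init_params = model_info.param_info.z      # unconstrained initial values
pe = model_info.potential_fn(init_params)  # potential energy at those values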
Example #17
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # the number of initial steps to unroll eagerly
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = not_jax_tracer(i) and i in range(unroll_steps)
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            # prefix site names to match the pattern of funsor's sarkka_bilmes_product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i),
                                divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension in the correct position,
            # we promote shapes so that `fn` and `value` at each site have the same
            # batch dims (e.g. if `fn.batch_shape == (2, 3)` and the value's batch
            # shape is (3,), we promote the value's shape to (1, 3)).
            # Here we promote the `fn` shape first; the `value` shape will be
            # promoted after scanning. We don't promote the `value` shape here
            # because we need to store the carry shape at this step, and reshaping
            # the value now could give the output carry the wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # record the shape of new_carry in the enclosing carry_shapes list
            if len(carry_shapes) < (history + 1):
                carry_shapes.append(
                    [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
            hide_fn=lambda site: not site["name"].startswith("_PREV_")), enum(
                first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # we run unroll_steps + 1 iterations; the last one hands the remaining
        # steps off to `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(wrapped_carry,
                                                 tree_map(lambda z: z[i], x0))
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0],
                        y0)
                y0s.append(y0)
                # the carry shapes of the first `history - 1` steps are not needed
                # to interpret the final carry shape, so we skip recording them
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append([
                        jnp.shape(x)
                        for x in tree_flatten(wrapped_carry[-1])[0]
                    ])
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length == unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace,
                                ys) = lax.scan(body_fn, wrapped_carry, xs_,
                                               length - unroll_steps, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample and deterministic sites in the trace;
        # deterministic sites don't need a `dim_to_name` adjustment
        if site["type"] not in ("sample", ):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promote shapes of values yet during `lax.scan`, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"],
                                                      site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(len(site["fn"].batch_shape),
                        jnp.ndim(site["value"]) - site["fn"].event_dim)
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape because shapes alternate across markov steps
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape)
        for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
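To make the unrolling bookkeeping concrete, here is a small worked sketch of the quantities computed at the top of scan_enum; it is plain arithmetic mirroring the definitions above, not an API call:

history, length = 2, 10
history = min(history, length)               # 2
unroll_steps = min(2 * history - 1, length)  # 3 steps run eagerly under "_PREV_" scopes
scanned_steps = length - unroll_steps        # 7 steps handled by lax.scan
assert (unroll_steps, scanned_steps) == (3, 7)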
Example #18
    def init(self, rng_key, *args, **kwargs):
        """
        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: initial :data:`SteinVIState`
        """
        rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(
            rng_key, 4)
        model_init = handlers.seed(self.model, model_seed)
        guide_init = handlers.seed(self.guide, guide_seed)
        guide_trace = handlers.trace(guide_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        model_trace = handlers.trace(model_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        rng_key, particle_seed = jax.random.split(rng_key)
        guide_init_params = self._find_init_params(particle_seed, self.guide,
                                                   guide_trace)
        params = {}
        transforms = {}
        inv_transforms = {}
        particle_transforms = {}
        guide_param_names = set()
        should_enum = False
        for site in model_trace.values():
            if ("fn" in site and site["type"] == "sample"
                    and not site["is_observed"]
                    and isinstance(site["fn"], Distribution)
                    and site["fn"].is_discrete):
                if site["fn"].has_enumerate_support and self.enum:
                    should_enum = True
                else:
                    raise Exception(
                        "Cannot enumerate model with discrete variables without enumerate support"
                    )
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in chain(model_trace.values(), guide_trace.values()):
            if site["type"] == "param":
                transform = get_parameter_transform(site)
                inv_transforms[site["name"]] = transform
                transforms[site["name"]] = transform.inv
                particle_transforms[site["name"]] = site.get(
                    "particle_transform", IdentityTransform())
                if site["name"] in guide_init_params:
                    pval, _ = guide_init_params[site["name"]]
                    if self.classic_guide_params_fn(site["name"]):
                        pval = tree_map(lambda x: x[0], pval)
                else:
                    pval = site["value"]
                params[site["name"]] = transform.inv(pval)
                if site["name"] in guide_trace:
                    guide_param_names.add(site["name"])

        if should_enum:
            mpn = _guess_max_plate_nesting(model_trace)
            self._inference_model = enum(config_enumerate(self.model),
                                         -mpn - 1)
        self.guide_param_names = guide_param_names
        self.constrain_fn = partial(transform_fn, inv_transforms)
        self.uconstrain_fn = partial(transform_fn, transforms)
        self.particle_transforms = particle_transforms
        self.particle_transform_fn = partial(transform_fn, particle_transforms)
        stein_particles, _, _ = batch_ravel_pytree(
            {
                k: params[k]
                for k, site in guide_trace.items() if site["type"] == "param"
                and site["name"] in guide_init_params
            },
            nbatch_dims=1,
        )

        self.kernel_fn.init(kernel_seed, stein_particles.shape)
        return SteinVIState(self.optim.init(params), rng_key)
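A hypothetical construction showing how this init method is reached. The SteinVI constructor signature below is an assumption pieced together from the attributes the method uses (self.model, self.guide, self.optim, self.kernel_fn, self.static_kwargs):

stein = SteinVI(model, guide, Adam(0.05), Trace_ELBO(), RBFKernel())  # assumed signature
state = stein.init(jax.random.PRNGKey(0), data)  # returns the initial SteinVIState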