Example #1
0
def test_scan_enum_plate():
    """Enumerated `scan` with a plate must give the same log joint as the unrolled `markov` loop."""
    num_steps, plate_size = 10, 3
    data = random.normal(random.PRNGKey(0), (num_steps, plate_size))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def _state_probs(prev_state):
        # the first step has no previous state, so it uses the initial distribution
        if prev_state is None:
            return init_probs
        return transition_probs[prev_state]

    def model(data):
        # reference: a python loop with uniquely named sites per time step
        state = None
        plate = numpyro.plate("D", plate_size, dim=-1)
        for step, obs in markov(enumerate(data)):
            with plate:
                state = numpyro.sample(f"x_{step}", dist.Categorical(_state_probs(state)))
                numpyro.sample(f"y_{step}", dist.Normal(locs[state], 1), obs=obs)

    def fun_model(data):
        # same model expressed through `scan`, reusing one site name per variable
        def transition_fn(x, y):
            probs = _state_probs(x)
            with numpyro.plate("D", plate_size, dim=-1):
                x = numpyro.sample("x", dist.Categorical(probs))
                numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, None, data)

    def _log_joint(fn):
        return log_density(enum(config_enumerate(fn), -2), (data,), {}, {})[0]

    actual_log_joint = _log_joint(fun_model)
    expected_log_joint = _log_joint(model)
    assert_allclose(actual_log_joint, expected_log_joint)
Example #2
0
def test_scan_enum_separated_plate_discrete():
    """Compare `scan` with the unrolled loop when an extra discrete site sits inside a plate."""
    num_steps, plate_size = 10, 3
    data = random.normal(random.PRNGKey(0), (num_steps, plate_size))
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        # reference: explicit python loop, the chain state `x` lives outside the plate
        state = 0
        plate = numpyro.plate("D", plate_size, dim=-1)
        for step, obs in markov(enumerate(data)):
            state = numpyro.sample(f"x_{step}", dist.Categorical(transition_probs[state]))
            with plate:
                w = numpyro.sample(f"w_{step}", dist.Bernoulli(0.6))
                obs_dist = dist.Normal(Vindex(locs)[state, w], 1)
                numpyro.sample(f"y_{step}", obs_dist, obs=obs)

    def fun_model(data):
        def transition_fn(x, y):
            x = numpyro.sample("x", dist.Categorical(transition_probs[x]))
            with numpyro.plate("D", plate_size, dim=-1):
                w = numpyro.sample("w", dist.Bernoulli(0.6))
                obs_dist = dist.Normal(Vindex(locs)[x, w], 1)
                numpyro.sample("y", obs_dist, obs=y)
            return x, None

        scan(transition_fn, 0, data)

    def _log_joint(fn):
        return log_density(enum(config_enumerate(fn), -2), (data,), {}, {})[0]

    actual_log_joint = _log_joint(fun_model)
    expected_log_joint = _log_joint(model)
    assert_allclose(actual_log_joint, expected_log_joint)
Example #3
0
def test_scan_enum_discrete_outside():
    """Compare `scan` with the unrolled loop when a discrete site is sampled before the chain."""
    data = random.normal(random.PRNGKey(0), (10,))
    probs = jnp.array([
        [[0.8, 0.2], [0.1, 0.9]],
        [[0.7, 0.3], [0.6, 0.4]],
    ])
    locs = jnp.array([-1.0, 1.0])

    def model(data):
        # `w` is drawn once and selects the transition matrix for every step
        w = numpyro.sample("w", dist.Bernoulli(0.6))
        state = 0
        for step, obs in markov(enumerate(data)):
            state = numpyro.sample(f"x_{step}", dist.Categorical(probs[w, state]))
            numpyro.sample(f"y_{step}", dist.Normal(locs[state], 1), obs=obs)

    def fun_model(data):
        w = numpyro.sample("w", dist.Bernoulli(0.6))

        def transition_fn(x, y):
            transition = probs[w, x]
            x = numpyro.sample("x", dist.Categorical(transition))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        scan(transition_fn, 0, data)

    def _log_joint(fn):
        return log_density(enum(config_enumerate(fn)), (data,), {}, {})[0]

    actual_log_joint = _log_joint(fun_model)
    expected_log_joint = _log_joint(model)
    assert_allclose(actual_log_joint, expected_log_joint)
Example #4
0
def test_scan_enum_two_latents():
    """Compare `scan` with the unrolled loop for two independent latent chains per step."""
    num_steps = 11
    data = random.normal(random.PRNGKey(0), (num_steps,))
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs = jnp.array([[-1.0, 1.0], [2.0, 3.0]])

    def model(data):
        x, w = 0, 0
        for step, obs in markov(enumerate(data)):
            x = numpyro.sample(f"x_{step}", dist.Categorical(probs_x[x]))
            w = numpyro.sample(f"w_{step}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_{step}", dist.Normal(locs[w, x], 1), obs=obs)

    def fun_model(data):
        def transition_fn(carry, y):
            prev_x, prev_w = carry
            x = numpyro.sample("x", dist.Categorical(probs_x[prev_x]))
            w = numpyro.sample("w", dist.Categorical(probs_w[prev_w]))
            numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y)
            # also emit `x` as a per-step output to test that scan's `ys` are recorded correctly
            return (x, w), x

        scan(transition_fn, (0, 0), data)

    def _log_joint(fn):
        return log_density(enum(config_enumerate(fn)), (data,), {}, {})[0]

    actual_log_joint = _log_joint(fun_model)
    expected_log_joint = _log_joint(model)
    assert_allclose(actual_log_joint, expected_log_joint)
Example #5
0
def test_scan_enum_scan_enum():
    """Compare two consecutive enumerated `scan` calls with two consecutive unrolled loops."""
    num_steps = 11
    data_x = random.normal(random.PRNGKey(0), (num_steps,))
    # the second chain is one step shorter than the first
    data_w = data_x[:-1] + 1
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs_x = jnp.array([-1.0, 1.0])
    locs_w = jnp.array([2.0, 3.0])

    def model(data_x, data_w):
        x = 0
        w = 0
        for step, obs in markov(enumerate(data_x)):
            x = numpyro.sample(f"x_{step}", dist.Categorical(probs_x[x]))
            numpyro.sample(f"y_x_{step}", dist.Normal(locs_x[x], 1), obs=obs)

        for step, obs in markov(enumerate(data_w)):
            w = numpyro.sample(f"w{step}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_w_{step}", dist.Normal(locs_w[w], 1), obs=obs)

    def fun_model(data_x, data_w):
        def transition_fn(name, probs, locs, x, y):
            # `name`, `probs` and `locs` are bound per chain via `partial` below
            x = numpyro.sample(name, dist.Categorical(probs[x]))
            numpyro.sample("y_" + name, dist.Normal(locs[x], 1), obs=y)
            return x, None

        step_x = partial(transition_fn, "x", probs_x, locs_x)
        scan(step_x, 0, data_x)
        step_w = partial(transition_fn, "w", probs_w, locs_w)
        scan(step_w, 0, data_w)

    def _log_joint(fn):
        return log_density(enum(config_enumerate(fn)), (data_x, data_w), {}, {})[0]

    actual_log_joint = _log_joint(fun_model)
    expected_log_joint = _log_joint(model)
    assert_allclose(actual_log_joint, expected_log_joint)
Example #6
0
def test_scan_enum_one_latent(num_steps):
    """Compare `scan` with the unrolled loop for one latent chain: log joint and final state."""
    data = random.normal(random.PRNGKey(0), (num_steps,))
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def _state_probs(prev_state):
        # no previous state on the first step, so use the initial distribution
        if prev_state is None:
            return init_probs
        return transition_probs[prev_state]

    def model(data):
        state = None
        for step, obs in markov(enumerate(data)):
            state = numpyro.sample(f"x_{step}", dist.Categorical(_state_probs(state)))
            numpyro.sample(f"y_{step}", dist.Normal(locs[state], 1), obs=obs)
        return state

    def fun_model(data):
        def transition_fn(x, y):
            x = numpyro.sample("x", dist.Categorical(_state_probs(x)))
            numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
            return x, None

        last_x, collections = scan(transition_fn, None, data)
        # the transition emits no per-step outputs, so nothing should be collected
        assert collections is None
        return last_x

    def _enumerated(fn):
        return enum(config_enumerate(fn))

    actual_log_joint = log_density(_enumerated(fun_model), (data,), {}, {})[0]
    expected_log_joint = log_density(_enumerated(model), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    actual_last_x = _enumerated(fun_model)(data)
    expected_last_x = _enumerated(model)(data)
    assert_allclose(actual_last_x, expected_last_x)
Example #7
0
    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of the wrapped transition function `f`.

        `wrapped_carry` is `(step index, rng key or None, user carry)`. Returns the
        advanced wrapped carry together with `(PytreeTrace(trace), y)`, where the
        trace is empty for the unrolled initial steps.
        """
        step, rng_key, carry = wrapped_carry
        # only a concrete (non-traced) step index can be recognised as an unrolled step
        is_unrolled = bool(not_jax_tracer(step) and step in range(unroll_steps))
        if rng_key is None:
            subkey = None
        else:
            rng_key, subkey = random.split(rng_key)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `step` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": step})

        seeded_fn = fn
        if subkey is not None:
            seeded_fn = handlers.seed(fn, subkey)
        # re-apply the condition/substitute handlers, sliced to the current step
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, step, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        enumerated_fn = config_enumerate(seeded_fn)
        if is_unrolled:
            # prefix site names with one "_PREV_" per remaining unrolled step so that
            # they match the naming pattern of the sarkka_bilmes product
            scope_prefix = "_PREV_" * (unroll_steps - step)
            with handlers.scope(prefix=scope_prefix, divider=""):
                new_carry, y = enumerated_fn(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put the time dimension at the correct position, `fn`
            # and `value` at each site need the same batch dims (e.g. if
            # `fn.batch_shape = (2, 3)` and value's batch_shape is (3,), value is
            # promoted to batch shape (1, 3)).
            # Only `fn` shapes are promoted here; `value` shapes are promoted after
            # scanning, because the carry shape must be recorded at this step and
            # reshaping `value` now could give the output carry a wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = enumerated_fn(carry, x)

            # store shape of new_carry at a global variable
            if len(carry_shapes) < (history + 1):
                leaf_shapes = [
                    jnp.shape(leaf) for leaf in tree_flatten(new_carry)[0]
                ]
                carry_shapes.append(leaf_shapes)
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda new, old: jnp.reshape(new, jnp.shape(old)),
                new_carry, carry)
        next_wrapped_carry = (step + 1, rng_key, new_carry)
        return next_wrapped_carry, (PytreeTrace(trace), y)
Example #8
0
def test_scan_hmm_smoke(length, temperature):
    """Smoke test: `infer_discrete` decodes a `scan`-based HMM and returns one state per step."""

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))

        def transition_fn(state, y):
            state_dist = dist.Categorical(transition[state])
            state = numpyro.sample("states", state_dist)
            y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y)
            return state, (state, y)

        _, (states, data) = scan(transition_fn, 0, data, length=length)

        # prepend the fixed initial state to the scanned states
        return [0, *states], data

    # generate data by running the seeded model with no observations
    true_states, data = handlers.seed(hmm, 0)(None)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    decoder = infer_discrete(config_enumerate(hmm),
                             temperature=temperature,
                             rng_key=random.PRNGKey(1))
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
Example #9
0
def test_hmm_smoke(length, temperature):
    """Smoke test: `infer_discrete` decodes a loop-based HMM and returns one state per step."""

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))
        states = [0]
        for t in markov(range(len(data))):
            prev_state = states[-1]
            state = numpyro.sample(
                "states_{}".format(t), dist.Categorical(transition[prev_state])
            )
            states.append(state)
            # observed entries are kept, missing (None) entries are filled in place
            data[t] = numpyro.sample(
                "obs_{}".format(t), dist.Normal(means[state], 1.0), obs=data[t]
            )
        return states, data

    # generate data by running the seeded model on all-missing observations
    missing_data = [None] * length
    true_states, data = handlers.seed(hmm, 0)(missing_data)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    decoder = infer_discrete(
        config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1)
    )
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
Example #10
0
    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
Example #11
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Run one sweep of discrete updates over `gibbs_sites`, visiting them in random order.

        :param rng_key: PRNG key driving the visiting order, the proposals and the
            accept/reject decisions.
        :param dict gibbs_sites: current values of the discrete sites to update.
        :param dict hmc_sites: current (constrained) values of the sites handled by HMC.
        :return: the updated `gibbs_sites`.
        """
        # convert to unconstrained values
        z_hmc = {
            k: biject_to(prototype_trace[k]["fn"].support).inv(v)
            for k, v in hmc_sites.items()
            if k in prototype_trace and prototype_trace[k]["type"] == "sample"
        }
        # discrete sites that are not updated by Gibbs have to be enumerated out
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        wrapped_model = _wrap_model(model)
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model = enum(config_enumerate(wrapped_model),
                                 -max_plate_nesting - 1)

        def potential_fn(z_discrete):
            # copy so the caller's kwargs are not mutated by the injected gibbs values
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Use the dedicated subkey for the permutation: `rng_key` is consumed again
        # by the sweep below, and a JAX key must not feed two random operations.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Example #12
0
def test_scan_history(history, T):
    """Compare `scan(..., history=history)` with `markov(..., history=history)`."""

    def model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))
        x_prev, x_curr = 0, 0
        for t in markov(range(T), history=history):
            # the next state depends on the two previous states and on `z`
            probs = p[x_prev, x_curr, z]
            x_next = numpyro.sample("x_{}".format(t), dist.Bernoulli(probs))
            x_prev, x_curr = x_curr, x_next
            numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0)
        return x_prev, x_curr

    def fun_model():
        p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2)))
        q = numpyro.param("q", 0.25 * jnp.ones(2))
        z = numpyro.sample("z", dist.Bernoulli(0.5))

        def transition_fn(carry, y):
            x_prev, x_curr = carry
            probs = p[x_prev, x_curr, z]
            x_next = numpyro.sample("x", dist.Bernoulli(probs))
            x_prev, x_curr = x_curr, x_next
            numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
            return (x_prev, x_curr), None

        observations = jnp.zeros(T)
        final_carry, _ = scan(transition_fn, (0, 0), observations, history=history)
        x_prev, x_curr = final_carry
        return x_prev, x_curr

    def _enumerated(fn):
        return enum(config_enumerate(fn))

    expected_log_joint = log_density(_enumerated(model), (), {}, {})[0]
    actual_log_joint = log_density(_enumerated(fun_model), (), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

    # the final carried states must agree as well
    expected_x_prev, expected_x_curr = _enumerated(model)()
    actual_x_prev, actual_x_curr = _enumerated(fun_model)()
    assert_allclose(actual_x_prev, expected_x_prev)
    assert_allclose(actual_x_curr, expected_x_curr)
Example #13
0
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        """Run one sweep of discrete updates over `gibbs_sites`, visiting them in random order.

        :param rng_key: PRNG key driving the visiting order, the proposals and the
            accept/reject decisions.
        :param dict gibbs_sites: current values of the discrete sites to update.
        :param dict hmc_sites: current values of the sites handled by HMC; passed
            unchanged to `potential_energy`.
        :return: the updated `gibbs_sites`.
        """
        z_hmc = hmc_sites
        # discrete sites that are not updated by Gibbs have to be enumerated out
        use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
        if use_enum:
            from numpyro.contrib.funsor import config_enumerate, enum

            wrapped_model_ = enum(config_enumerate(wrapped_model),
                                  -max_plate_nesting - 1)
        else:
            wrapped_model_ = wrapped_model

        def potential_fn(z_discrete):
            # copy so the caller's kwargs are not mutated by the injected gibbs values
            model_kwargs_ = model_kwargs.copy()
            model_kwargs_["_gibbs_sites"] = z_discrete
            return potential_energy(wrapped_model_,
                                    model_args,
                                    model_kwargs_,
                                    z_hmc,
                                    enum=use_enum)

        # get support_sizes of gibbs_sites
        support_sizes_flat, _ = ravel_pytree(
            {k: support_sizes[k]
             for k in gibbs_sites})
        num_discretes = support_sizes_flat.shape[0]

        # Use the dedicated subkey for the permutation: `rng_key` is consumed again
        # by the sweep below, and a JAX key must not feed two random operations.
        rng_key, rng_permute = random.split(rng_key)
        idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

        def body_fn(i, val):
            idx = idxs[i]
            support_size = support_sizes_flat[idx]
            rng_key, z, pe = val
            rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
                rng_key,
                z,
                pe,
                potential_fn=potential_fn,
                idx=idx,
                support_size=support_size)
            rng_key, rng_accept = random.split(rng_key)
            # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
            # and -log(u) ~ exponential(1)
            z, pe = cond(
                random.exponential(rng_accept) > -log_accept_ratio,
                (z_new, pe_new), identity, (z, pe), identity)
            return rng_key, z, pe

        init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
        _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
        return gibbs_sites
Example #14
0
def test_scan_enum_separated_plates_same_dim():
    """Compare `scan` with the unrolled loop when two different plates share dim -1."""
    num_steps, size1, size2 = 10, 3, 4
    data = random.normal(random.PRNGKey(0), (num_steps, size1 + size2))
    # split the columns into one observation block per plate
    data1 = data[:, :size1]
    data2 = data[:, size1:]
    init_probs = jnp.array([0.6, 0.4])
    transition_probs = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    locs = jnp.array([-1.0, 1.0])

    def _state_probs(prev_state):
        # no previous state on the first step, so use the initial distribution
        if prev_state is None:
            return init_probs
        return transition_probs[prev_state]

    def model(data1, data2):
        state = None
        plate1 = numpyro.plate("D1", size1, dim=-1)
        plate2 = numpyro.plate("D2", size2, dim=-1)
        for step, (obs1, obs2) in markov(enumerate(zip(data1, data2))):
            state = numpyro.sample(f"x_{step}", dist.Categorical(_state_probs(state)))
            with plate1:
                numpyro.sample(f"y1_{step}", dist.Normal(locs[state], 1), obs=obs1)
            with plate2:
                numpyro.sample(f"y2_{step}", dist.Normal(locs[state], 1), obs=obs2)

    def fun_model(data1, data2):
        def transition_fn(x, y):
            obs1, obs2 = y
            x = numpyro.sample("x", dist.Categorical(_state_probs(x)))
            with numpyro.plate("D1", size1, dim=-1):
                numpyro.sample("y1", dist.Normal(locs[x], 1), obs=obs1)
            with numpyro.plate("D2", size2, dim=-1):
                numpyro.sample("y2", dist.Normal(locs[x], 1), obs=obs2)
            return x, None

        scan(transition_fn, None, (data1, data2))

    def _log_joint(fn):
        return log_density(enum(config_enumerate(fn), -2), (data1, data2), {}, {})[0]

    actual_log_joint = _log_joint(fun_model)
    expected_log_joint = _log_joint(model)
    assert_allclose(actual_log_joint, expected_log_joint)
Example #15
0
def test_mcmc_model_side_enumeration(model, temperature):
    """Decode discrete sites with `infer_discrete`, conditioned on one NUTS sample of "loc"/"scale"."""
    kernel = infer.NUTS(config_enumerate(model))
    mcmc = infer.MCMC(kernel, num_warmup=0, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    # keep only the first (and only) draw of the "loc" and "scale" sites
    conditioned_names = ["loc", "scale"]
    mcmc_data = {}
    for name, draws in mcmc.get_samples().items():
        if name in conditioned_names:
            mcmc_data[name] = draws[0]

    # MAP estimate discretes, conditioned on posterior sampled continuous latents.
    model = handlers.seed(model, rng_seed=1)
    # TODO support replayed sites in infer_discrete.
    # handlers.replay(config_enumerate(model), mcmc_trace),
    conditioned_model = handlers.condition(config_enumerate(model), mcmc_data)
    decoder = infer_discrete(
        conditioned_model,
        temperature=temperature,
        rng_key=random.PRNGKey(1),
    )
    actual_trace = handlers.trace(decoder).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace) == set(expected_trace)
Example #16
0
def main(_args):
    """Fit the enumerated `gmm` model with SVI, showing the loss on a progress bar."""
    data = generate_data()
    init_rng_key = PRNGKey(1273)
    # trace one seeded run of the model to find how deeply its plates are nested
    prototype_trace = trace(seed(gmm, init_rng_key)).get_trace(data)
    max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
    # enumeration dims start just to the left of the plate dims
    first_available_dim = -max_plate_nesting - 1
    enum_gmm = enum(config_enumerate(gmm), first_available_dim)
    svi = SVI(enum_gmm, gmm_guide, Adam(0.1), RenyiELBO(-10.))
    svi_state = svi.init(init_rng_key, data)
    update = jax.jit(svi.update)
    with tqdm.trange(100_000) as pbar:
        for _ in pbar:
            svi_state, loss = update(svi_state, data)
            pbar.set_description(f"SVI {loss}", True)
Example #17
0
    def single_prediction(val):
        """Produce predictions for one `(rng_key, samples)` pair, filtered to the requested sites."""
        rng_key, samples = val
        if infer_discrete:
            from numpyro.contrib.funsor import config_enumerate
            from numpyro.contrib.funsor.discrete import _sample_posterior

            # site metadata comes from the prototype trace in this branch
            model_trace = prototype_trace
            temperature = 1
            conditioned_model = config_enumerate(condition(model, samples))
            pred_samples = _sample_posterior(
                conditioned_model,
                first_available_dim,
                temperature,
                rng_key,
                *model_args,
                **model_kwargs,
            )
        else:
            seeded_model = seed(substitute(masked_model, samples), rng_key)
            model_trace = trace(seeded_model).get_trace(*model_args,
                                                        **model_kwargs)
            pred_samples = {}
            for name, site in model_trace.items():
                pred_samples[name] = site["value"]

        if return_sites is None:
            # default: sample sites not supplied in `samples`, plus deterministic sites
            sites = {
                k
                for k, site in model_trace.items()
                if site["type"] == "deterministic" or (
                    site["type"] == "sample" and k not in samples)
            }
        elif return_sites == "":
            # empty string selects every site except plates
            sites = {
                k
                for k, site in model_trace.items()
                if site["type"] != "plate"
            }
        else:
            sites = return_sites

        selected = {}
        for name, value in pred_samples.items():
            if name in sites:
                selected[name] = value
        return selected
Example #18
0
    def init(self, rng_key, *args, **kwargs):
        """
        Trace the model and guide once, collect and unconstrain their parameters,
        set up the transform helpers stored on ``self``, and initialize the kernel
        and the optimizer state.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: initial :data:`SteinVIState`
        :raises Exception: if the model has an unobserved discrete sample site and
            either that site has no enumerate support or ``self.enum`` is falsy.
        """
        rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(
            rng_key, 4)
        model_init = handlers.seed(self.model, model_seed)
        guide_init = handlers.seed(self.guide, guide_seed)
        guide_trace = handlers.trace(guide_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        model_trace = handlers.trace(model_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        rng_key, particle_seed = jax.random.split(rng_key)
        guide_init_params = self._find_init_params(particle_seed, self.guide,
                                                   guide_trace)
        params = {}
        transforms = {}
        inv_transforms = {}
        particle_transforms = {}
        guide_param_names = set()
        should_enum = False
        # look for unobserved discrete sample sites in the model: they either
        # switch on enumeration or abort initialization
        for site in model_trace.values():
            if ("fn" in site and site["type"] == "sample"
                    and not site["is_observed"]
                    and isinstance(site["fn"], Distribution)
                    and site["fn"].is_discrete):
                if site["fn"].has_enumerate_support and self.enum:
                    should_enum = True
                else:
                    raise Exception(
                        "Cannot enumerate model with discrete variables without enumerate support"
                    )
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in chain(model_trace.values(), guide_trace.values()):
            if site["type"] == "param":
                transform = get_parameter_transform(site)
                # note the naming: `inv_transforms` holds the transform itself,
                # `transforms` holds its inverse
                inv_transforms[site["name"]] = transform
                transforms[site["name"]] = transform.inv
                particle_transforms[site["name"]] = site.get(
                    "particle_transform", IdentityTransform())
                if site["name"] in guide_init_params:
                    pval, _ = guide_init_params[site["name"]]
                    if self.classic_guide_params_fn(site["name"]):
                        # classic guide params keep only the first entry along the leading axis
                        pval = tree_map(lambda x: x[0], pval)
                else:
                    pval = site["value"]
                # parameters are stored after applying the inverse transform
                params[site["name"]] = transform.inv(pval)
                if site["name"] in guide_trace:
                    guide_param_names.add(site["name"])

        if should_enum:
            # enumerate starting from the first dim to the left of all plates
            mpn = _guess_max_plate_nesting(model_trace)
            self._inference_model = enum(config_enumerate(self.model),
                                         -mpn - 1)
        self.guide_param_names = guide_param_names
        self.constrain_fn = partial(transform_fn, inv_transforms)
        self.uconstrain_fn = partial(transform_fn, transforms)
        self.particle_transforms = particle_transforms
        self.particle_transform_fn = partial(transform_fn, particle_transforms)
        # ravel the guide params that received initial values; only the resulting
        # shape is used, to initialize the kernel
        stein_particles, _, _ = batch_ravel_pytree(
            {
                k: params[k]
                for k, site in guide_trace.items() if site["type"] == "param"
                and site["name"] in guide_init_params
            },
            nbatch_dims=1,
        )

        self.kernel_fn.init(kernel_seed, stein_particles.shape)
        return SteinVIState(self.optim.init(params), rng_key)
Example #19
0
def initialize_model(rng_key,
                     model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters are found (checked only
        when the validity flag is not a JAX tracer).
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # a batch of keys is allowed; only the first one seeds the prototype run
    substituted_model = substitute(seed(
        model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
                                   substitute_fn=init_strategy)
    inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    # values of the unobserved, non-discrete sample sites from the prototype trace
    constrained_values = {
        k: v['value']
        for k, v in model_trace.items() if v['type'] == 'sample'
        and not v['is_observed'] and not v['fn'].is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        # wrap the model for enumeration unless the caller already did so
        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    enum=has_enumerate_support,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    # normalize to a `partial`: a strategy that is not already one is called with
    # no arguments, which is expected to yield a `partial` (`.func` is read below)
    init_strategy = init_strategy if isinstance(init_strategy,
                                                partial) else init_strategy()
    if (init_strategy.func is init_to_value) and not replay_model:
        # user-provided initial values are mapped through `transform_fn` with
        # `invert=True` and wrapped in the unconstrained-value strategy
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        model,
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params)

    # the validity flag can only be inspected when it is a concrete value
    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
Example #20
0
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Prepare a model for inference by combining
    :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params`.

    The model is first run once (seeded, with `init_strategy` substituted in)
    to obtain a prototype trace. Values of `param` sites from that trace are
    substituted back into the model, a potential function is built, and a
    valid set of initial parameters is searched for.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``. When `rng_key` is not
        one-dimensional, only ``rng_key[0]`` seeds the prototype trace.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: forwarded to
        :func:`~numpyro.infer.util.get_potential_fn`. If `True`, the
        `potential_fn` and `constraints_fn` are themselves dependent on model
        arguments. When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model. ``None`` is
        treated as an empty dict.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode but the forward
        mode can be useful in some cases to improve the performance. In addition, some
        control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only supports forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters are found (only
        checked when the validity flag is not a JAX tracer). Before raising,
        the model is re-run with validation enabled and any sample-value
        warnings are shown, prefixed with the name of the offending site.
    """
    if model_kwargs is None:
        model_kwargs = {}

    # A batch of keys may be supplied; the prototype run uses the first one.
    trace_key = rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]
    substituted_model = substitute(seed(model, trace_key),
                                   substitute_fn=init_strategy)
    transforms_info = _get_model_transforms(substituted_model, model_args,
                                            model_kwargs)
    inv_transforms, replay_model, has_enumerate_support, model_trace = transforms_info

    # Feed the traced `param` values back into the model so that parameters
    # of `numpyro.module` do not have to be generated again.
    traced_params = {}
    for name, site in model_trace.items():
        if site["type"] == "param":
            traced_params[name] = site["value"]
    model = substitute(model, data=traced_params)

    # Constrained values of the unobserved, non-discrete sample sites.
    constrained_values = {}
    for name, site in model_trace.items():
        if site["type"] != "sample":
            continue
        if site["is_observed"] or site["fn"].is_discrete:
            continue
        constrained_values[name] = site["value"]

    if has_enumerate_support:
        # Imported lazily: only needed when the model has enumerable sites.
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            first_available_dim = -max_plate_nesting - 1
            model = enum(config_enumerate(model), first_available_dim)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    # A strategy passed as a plain factory is called to obtain its `partial`.
    if not isinstance(init_strategy, partial):
        init_strategy = init_strategy()
    if init_strategy.func is init_to_value and not replay_model:
        # User-supplied initial values are given in constrained space;
        # map them through the inverse of `inv_transforms` first.
        init_values = init_strategy.keywords.get("values")
        init_strategy = _init_to_unconstrained_value(
            values=transform_fn(inv_transforms, init_values, invert=True))

    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)

    # Reuse the plate site values recorded in the prototype trace while
    # searching for initial parameters.
    traced_plates = {}
    for name, site in model_trace.items():
        if site["type"] == "plate":
            traced_plates[name] = site["value"]

    (init_params, potential_energy, z_grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(model, data=traced_plates),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    # The validity flag can only be inspected when it is a concrete value.
    if not_jax_tracer(is_valid) and device_get(~jnp.all(is_valid)):
        with numpyro.validation_enabled(), trace() as tr:
            # Re-run with validation enabled to validate parameters.
            substituted_model(*model_args, **model_kwargs)
            # Validate the sampled values site by site.
            for site in tr.values():
                if site["type"] != "sample":
                    continue
                with warnings.catch_warnings(record=True) as caught:
                    site["fn"]._validate_sample(site["value"])
                for w in caught:
                    # Add site information to the warning message.
                    labelled = "Site {}: {}".format(site["name"],
                                                    w.message.args[0])
                    w.message.args = (labelled, ) + w.message.args[1:]
                    warnings.showwarning(
                        w.message,
                        w.category,
                        w.filename,
                        w.lineno,
                        file=w.file,
                        line=w.line,
                    )
        raise RuntimeError(
            "Cannot find valid initial parameters. Please check your model again."
        )

    param_info = ParamInfo(init_params, potential_energy, z_grad)
    return ModelInfo(param_info, potential_fn, postprocess_fn, model_trace)