Esempio n. 1
0
def test_counterfactual_query(intervene, observe, flip):
    """Check ``handlers.do`` / ``handlers.condition`` on the chain x -> y -> z -> w.

    ``intervene`` and ``observe`` select which handlers wrap the model.
    With ``flip`` false, ``do`` is applied first and ``condition`` is applied
    around it; with ``flip`` true (and both other flags set), ``condition`` is
    applied first and ``do`` around it.  The three flags are presumably
    supplied by a pytest parametrization that is not visible here.
    """
    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        # Seed inside the model so every handler combination sees the same draws.
        with handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            y = numpyro.sample("y", dist.Normal(x, 1))
            z = numpyro.sample("z", dist.Normal(y, 1))
            w = numpyro.sample("w", dist.Normal(z, 1))
            return {"x": x, "y": y, "z": z, "w": w}

    if flip:
        # The flipped ordering is only exercised for the mixed query.
        if intervene and observe:
            conditioned = handlers.condition(model, data=observations)
            model = handlers.do(conditioned, data=interventions)
    else:
        if intervene:
            model = handlers.do(model, data=interventions)
        if observe:
            model = handlers.condition(model, data=observations)

    with handlers.trace() as tr:
        actual_values = model()

    for name in sites:
        observed = observations[name]
        intervened = interventions[name]

        if observe and not intervene:
            # case 1: purely observational query like handlers.condition
            if observed is not None:
                assert tr[name]['is_observed']
                assert_allclose(observed, actual_values[name])
                assert_allclose(observed, tr[name]['value'])
            if intervened is not None and intervened != observed:
                assert_raises(AssertionError, assert_allclose,
                              intervened, actual_values[name])
        elif intervene and not observe:
            # case 2: purely interventional query like old handlers.do
            assert not tr[name]['is_observed']
            if intervened is not None:
                assert_allclose(intervened, actual_values[name])
            if observed is not None:
                assert_raises(AssertionError, assert_allclose,
                              observed, tr[name]['value'])
            if intervened is not None:
                assert_raises(AssertionError, assert_allclose,
                              intervened, tr[name]['value'])
        elif intervene and observe:
            # case 3: counterfactual query mixing intervention and observation
            if observed is not None:
                assert tr[name]['is_observed']
                assert_allclose(observed, tr[name]['value'])
            if intervened is not None:
                assert_allclose(intervened, actual_values[name])
            if intervened is not None and intervened != observed:
                assert_raises(AssertionError, assert_allclose,
                              intervened, tr[name]['value'])
Esempio n. 2
0
def test_condition():
    """Conditioning on ``y`` fixes its value and flags the site as observed.

    The last assertion expects that wrapping the already-conditioned model in
    a second ``condition`` with ``y = 3.0`` makes the model return ``3.0``.
    """

    def model():
        x = numpyro.sample("x", dist.Delta(0.0))
        y = numpyro.sample("y", dist.Normal(0.0, 1.0))
        return x + y

    seeded = handlers.seed(model, random.PRNGKey(1))
    conditioned = handlers.condition(seeded, {"y": 2.0})

    y_site = handlers.trace(conditioned).get_trace()["y"]
    assert y_site["value"] == 2.0
    assert y_site["is_observed"]

    reconditioned = handlers.condition(conditioned, {"y": 3.0})
    assert reconditioned() == 3.0
Esempio n. 3
0
def test_condition():
    """Verify that ``handlers.condition`` pins site ``y`` and marks it observed.

    Also expects an additional outer ``condition`` with ``y = 3.`` to drive
    the model's return value to ``3.``.
    """

    def model():
        x = numpyro.sample('x', dist.Delta(0.))
        y = numpyro.sample('y', dist.Normal(0., 1.))
        return x + y

    rng_key = random.PRNGKey(1)
    inner = handlers.condition(handlers.seed(model, rng_key), {'y': 2.})

    recorded = handlers.trace(inner).get_trace()
    site_y = recorded['y']
    assert site_y['value'] == 2.
    assert site_y['is_observed']

    outer = handlers.condition(inner, {'y': 3.})
    assert outer() == 3.
Esempio n. 4
0
def test_condition():
    """Conditioning pins ``y``; conditioning an observed site again must fail.

    After the first ``condition`` the traced ``y`` site must hold ``2.`` and
    be flagged as observed.  Running the model under a second ``condition``
    on the same site is expected to raise ``ValueError``.
    """

    def model():
        x = numpyro.sample('x', dist.Delta(0.))
        y = numpyro.sample('y', dist.Normal(0., 1.))
        return x + y

    seeded = handlers.seed(model, random.PRNGKey(1))
    conditioned = handlers.condition(seeded, {'y': 2.})

    y_site = handlers.trace(conditioned).get_trace()['y']
    assert y_site['value'] == 2.
    assert y_site['is_observed']

    # The site is already observed, so a second condition must be rejected.
    with pytest.raises(ValueError):
        handlers.condition(conditioned, {'y': 3.})()
Esempio n. 5
0
    def fit(self, df, iter=500, seed=42, **kwargs):
        """Run NUTS on ``self.model`` conditioned on the goals in ``df``.

        :param df: frame with ``home_team``, ``away_team``, ``home_goals``,
            ``away_goals`` and ``date`` columns (``date`` must support the
            ``.dt`` accessor after subtracting its minimum).
        :param iter: number of posterior samples; warm-up uses ``iter // 2``.
        :param seed: integer seed passed to ``random.PRNGKey``.
        :param kwargs: forwarded unchanged to ``MCMC``.
        :return: ``self``, with ``team_to_index``, ``index_to_team``,
            ``n_teams``, ``min_date`` and ``samples`` populated.
        """
        first_date = df["date"].min()
        team_names = sorted(set(df["home_team"]) | set(df["away_team"]))

        # Bookkeeping used to translate between team names and indices.
        self.team_to_index = {name: idx for idx, name in enumerate(team_names)}
        self.index_to_team = dict(enumerate(team_names))
        self.n_teams = len(team_names)
        self.min_date = first_date

        home_team = df["home_team"].values
        away_team = df["away_team"].values
        # Whole weeks elapsed since the earliest fixture in the data.
        gameweek = ((df["date"] - first_date).dt.days // 7).values

        observed_goals = {
            "home_goals": df["home_goals"].values,
            "away_goals": df["away_goals"].values,
        }
        conditioned_model = condition(self.model, param_map=observed_goals)

        mcmc = MCMC(NUTS(conditioned_model),
                    num_warmup=iter // 2,
                    num_samples=iter,
                    **kwargs)
        mcmc.run(random.PRNGKey(seed), home_team, away_team, gameweek)

        self.samples = mcmc.get_samples()
        mcmc.print_summary()
        return self
Esempio n. 6
0
def predict(
    rng_key: np.ndarray, post_samples: np.ndarray, model: Callable, *args: Any, **kwargs: Any
) -> np.ndarray:
    """Return the value of the ``"obs"`` site for one set of posterior samples.

    The model is conditioned on ``post_samples``, seeded with ``rng_key`` and
    traced with ``*args`` / ``**kwargs`` as its call arguments.
    """
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    recorded = handlers.trace(seeded).get_trace(*args, **kwargs)
    return recorded["obs"]["value"]
Esempio n. 7
0
    def body_fn(wrapped_carry, x):
        """Run one scan step of ``f`` and capture its trace.

        ``wrapped_carry`` is ``(i, rng_key, carry)``.  Returns
        ``((i + 1, rng_key, carry), (PytreeTrace(trace), y))`` where ``carry``
        and ``y`` come from calling the handler-wrapped ``f`` on ``(carry, x)``.
        """
        i, rng_key, carry = wrapped_carry
        if rng_key is None:
            subkey = None
        else:
            rng_key, subkey = random.split(rng_key)

        with handlers.block():
            # Tag every site with the current time index so that the
            # unconstrained messenger in the potential energy computation
            # only transforms the item at time `i`.
            indexed_fn = handlers.infer_config(
                f, config_fn=lambda msg: {"_scan_current_index": i})

            seeded_fn = indexed_fn if subkey is None else handlers.seed(
                indexed_fn, subkey)

            # Re-apply the enclosing condition/substitute handlers, in order.
            for kind, site_values in substitute_stack:
                lookup = partial(_subs_wrapper, site_values, i, length)
                if kind == "condition":
                    seeded_fn = handlers.condition(seeded_fn,
                                                   condition_fn=lookup)
                elif kind == "substitute":
                    seeded_fn = handlers.substitute(seeded_fn,
                                                    substitute_fn=lookup)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        next_state = (i + 1, rng_key, carry)
        return next_state, (PytreeTrace(trace), y)
Esempio n. 8
0
    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of the scanned function ``f`` under the enumeration handlers.

        ``wrapped_carry`` is ``(i, rng_key, carry)``.  The step is treated as
        the initial one when ``i`` is not a JAX tracer (per ``not_jax_tracer``)
        and equals 0; it is then run under the ``"_init"`` scope and an empty
        trace is returned.  Otherwise the step is traced with ``packed_trace``
        and ``new_carry`` is reshaped to the shapes of the incoming ``carry``.

        Returns ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
        ``prefix`` is accepted but not used in this body.
        """
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        # Split off a fresh subkey for this step; without a key nothing is seeded.
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        # Re-apply each entry of `substitute_stack` as a condition/substitute
        # handler whose lookup function is bound to the current index `i`.
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # Sites of the initial step are placed under the "_init" scope and
            # are not collected into the returned trace.
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # Record the leaf shapes of new_carry in the enclosing-scope
            # variable `carry_shape_at_t1`.
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
Esempio n. 9
0
def test_no_split_deterministic():
    """A fully conditioned model runs without a seed and returns the sum of its sites."""

    def model():
        x = numpyro.sample('x', dist.Normal(0., 1.))
        y = numpyro.sample('y', dist.Normal(0., 1.))
        return x + y

    # Both sample sites are conditioned, and no `seed` handler is applied.
    site_values = {'x': 1., 'y': 2.}
    conditioned = handlers.condition(model, site_values)
    assert conditioned() == 3.
Esempio n. 10
0
def test_no_split_deterministic():
    """Conditioning every sample site lets the model run with no ``seed`` handler."""

    def model():
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        y = numpyro.sample("y", dist.Normal(0.0, 1.0))
        return x + y

    fixed_sites = {"x": 1.0, "y": 2.0}
    result = handlers.condition(model, fixed_sites)()
    # 1.0 + 2.0: the returned value is determined by the conditioned sites.
    assert result == 3.0
Esempio n. 11
0
def log_likelihood(
    rng_key: np.ndarray, params: np.ndarray, model: Callable, *args: Any, **kwargs: Any
) -> np.ndarray:
    """Evaluate the log-probability of the ``"obs"`` site under ``params``.

    The model is conditioned on ``params`` and traced with ``*args`` /
    ``**kwargs``; the ``"obs"`` site's distribution is then evaluated at the
    site's own value.  ``rng_key`` is accepted but not used in this body.
    """
    conditioned = handlers.condition(model, params)
    obs_site = handlers.trace(conditioned).get_trace(*args, **kwargs)["obs"]
    obs_dist, obs_value = obs_site["fn"], obs_site["value"]
    return obs_dist.log_prob(obs_value)
Esempio n. 12
0
 def single_prediction(rng, samples):
     """Trace the model conditioned on ``samples`` and return the selected site values.

     When ``return_sites`` is ``None`` the selected sites are those in the
     trace that are not keys of ``samples``; otherwise ``return_sites`` is
     used as given.
     """
     seeded_model = seed(condition(model, samples), rng)
     model_trace = trace(seeded_model).get_trace(*args, **kwargs)

     if return_sites is None:
         sites = model_trace.keys() - samples.keys()
     else:
         sites = return_sites

     return {
         name: model_trace[name]['value']
         for name in model_trace if name in sites
     }
Esempio n. 13
0
    def body_fn(wrapped_carry, x, prefix=None):
        """Run one step of the scanned function ``f`` with enumeration configured.

        ``wrapped_carry`` is ``(i, rng_key, carry)``.  A step counts as an
        initial (unrolled) step when ``i`` is not a JAX tracer (per
        ``not_jax_tracer``) and lies in ``range(unroll_steps)``; such steps run
        under a ``"_PREV_"``-prefixed scope and return an empty trace.  Other
        steps are traced with ``packed_trace`` and ``new_carry`` is reshaped to
        the shapes of the incoming ``carry``.

        Returns ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
        ``prefix`` is accepted but not used in this body.
        """
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i)
                        and i in range(unroll_steps)) else False
        # Split off a fresh subkey for this step; without a key nothing is seeded.
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        # Re-apply each entry of `substitute_stack` as a condition/substitute
        # handler whose lookup function is bound to the current index `i`.
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            # Rename sites to match the naming pattern of the sarkka_bilmes
            # product: the prefix "_PREV_" is repeated (unroll_steps - i)
            # times and joined to the site name with an empty divider.
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i),
                                divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # Record the leaf shapes of new_carry in `carry_shapes`, keeping at
            # most `history + 1` entries.
            if len(carry_shapes) < (history + 1):
                carry_shapes.append(
                    [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
Esempio n. 14
0
    def body_fn(wrapped_carry, x):
        """Run one scan step of ``f`` and capture its trace.

        ``wrapped_carry`` is ``(i, rng_key, carry)``.  Returns
        ``((i + 1, rng_key, carry), (PytreeTrace(trace), y))`` where ``carry``
        and ``y`` come from calling the handler-wrapped ``f`` on ``(carry, x)``.
        """
        i, rng_key, carry = wrapped_carry
        if rng_key is None:
            subkey = None
        else:
            rng_key, subkey = random.split(rng_key)

        with handlers.block():
            seeded_fn = f if subkey is None else handlers.seed(f, subkey)

            # Re-apply the enclosing condition/substitute handlers, in order.
            for kind, site_values in substitute_stack:
                lookup = partial(_subs_wrapper, site_values, i, length)
                if kind == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=lookup)
                elif kind == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=lookup)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        next_state = (i + 1, rng_key, carry)
        return next_state, (PytreeTrace(trace), y)
Esempio n. 15
0
    def single_prediction(val):
        """Produce the predictive site values for one ``(rng_key, samples)`` pair.

        With ``infer_discrete`` set, values come from ``_sample_posterior`` on
        the enumerated, conditioned model and site metadata is read from
        ``prototype_trace``.  Otherwise the masked model is traced with
        ``samples`` substituted.  The result is filtered by ``return_sites``:
        ``None`` keeps sample sites not present in ``samples`` plus
        deterministic sites, ``""`` keeps every non-plate site, and any other
        value is used as the collection of site names to keep.
        """
        rng_key, samples = val

        if infer_discrete:
            from numpyro.contrib.funsor import config_enumerate
            from numpyro.contrib.funsor.discrete import _sample_posterior

            model_trace = prototype_trace
            temperature = 1
            enumerated_model = config_enumerate(condition(model, samples))
            pred_samples = _sample_posterior(
                enumerated_model,
                first_available_dim,
                temperature,
                rng_key,
                *model_args,
                **model_kwargs,
            )
        else:
            seeded_model = seed(substitute(masked_model, samples), rng_key)
            model_trace = trace(seeded_model).get_trace(*model_args,
                                                        **model_kwargs)
            pred_samples = {
                name: site["value"]
                for name, site in model_trace.items()
            }

        if return_sites is None:
            # Default: unseen sample sites and all deterministic sites.
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples)
                or site["type"] == "deterministic"
            }
        elif return_sites == "":
            # Empty string: everything except plates.
            sites = {
                k
                for k, site in model_trace.items()
                if site["type"] != "plate"
            }
        else:
            sites = return_sites

        return {
            name: value
            for name, value in pred_samples.items() if name in sites
        }
Esempio n. 16
0
    def wrapper(wrapped_operand):
        """Call ``fn`` on the operand under the recorded handlers and trace it.

        ``wrapped_operand`` is ``(rng_key, operand)``.  Returns
        ``(value, PytreeTrace(trace))`` where ``value`` is the result of the
        handler-wrapped ``fn(operand)``.
        """
        rng_key, operand = wrapped_operand

        with handlers.block():
            seeded_fn = fn if rng_key is None else handlers.seed(fn, rng_key)

            # Re-apply the enclosing condition/substitute handlers, in order.
            for kind, site_values in substitute_stack:
                lookup = partial(_subs_wrapper, site_values)
                if kind == "condition":
                    seeded_fn = handlers.condition(seeded_fn,
                                                   condition_fn=lookup)
                elif kind == "substitute":
                    seeded_fn = handlers.substitute(seeded_fn,
                                                    substitute_fn=lookup)

            with handlers.trace() as trace:
                value = seeded_fn(operand)

        return value, PytreeTrace(trace)
Esempio n. 17
0
def test_mcmc_model_side_enumeration(model, temperature):
    """``infer_discrete`` on a conditioned model yields the same sites as the plain model.

    One NUTS sample is drawn (no warm-up); its ``loc`` and ``scale`` values
    are used to condition the enumerated model before ``infer_discrete``.
    """
    kernel = infer.NUTS(model)
    mcmc = infer.MCMC(kernel, 0, 1)
    mcmc.run(random.PRNGKey(0))
    posterior = mcmc.get_samples()
    # Keep only the first draw of the continuous latents.
    mcmc_data = {
        k: v[0] for k, v in posterior.items() if k in ["loc", "scale"]
    }

    # MAP estimate discretes, conditioned on posterior sampled continuous latents.
    model = handlers.seed(model, rng_seed=1)
    # TODO support replayed sites in infer_discrete, i.e. use
    # handlers.replay(config_enumerate(model), mcmc_trace) instead of condition.
    conditioned = handlers.condition(config_enumerate(model), mcmc_data)
    discrete_model = infer_discrete(
        conditioned,
        temperature=temperature,
        rng_key=random.PRNGKey(1),
    )
    actual_trace = handlers.trace(discrete_model).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace) == set(expected_trace)
Esempio n. 18
0
 def transformed_model_fn(*args, **kwargs):
     """Map the call arguments, then run ``model`` conditioned on the fixed observations.

     ``obs_to_model_args_fn`` turns the incoming arguments into the model's
     positional arguments, keyword arguments and the observations to condition on.
     """
     model_args, model_kwargs, observed = obs_to_model_args_fn(*args, **kwargs)
     conditioned_model = condition(model, data=observed)
     return conditioned_model(*model_args, **model_kwargs)
Esempio n. 19
0
def _wrap_model(model, *args, **kwargs):
    """Call ``model`` with the ``_gibbs_sites`` values conditioned and substituted.

    The ``_gibbs_sites`` keyword argument is removed from ``kwargs`` (an empty
    dict is used when it is absent) before the remaining arguments are passed
    to ``model``.  Returns the model's return value.
    """
    gibbs_values = kwargs.pop("_gibbs_sites", {})
    with condition(data=gibbs_values):
        with substitute(data=gibbs_values):
            result = model(*args, **kwargs)
            return result
Esempio n. 20
0
 def fn(*args, **kwargs):
     """Call ``model`` with the ``_gibbs_sites`` values conditioned and substituted.

     The ``_gibbs_sites`` keyword argument is removed from ``kwargs`` (an empty
     dict is used when it is absent) before the remaining arguments are passed
     to ``model``.

     :return: whatever ``model`` returns.
     """
     gibbs_values = kwargs.pop("_gibbs_sites", {})
     with condition(data=gibbs_values), substitute(data=gibbs_values):
         # Propagate the model's return value instead of silently dropping it,
         # so the wrapper is transparent to callers that use the result.
         return model(*args, **kwargs)