def test_counterfactual_query(intervene, observe, flip):
    """Check how `handlers.do` and `handlers.condition` compose on a 4-site chain.

    The model is the chain x -> y -> z -> w, sampled under a fixed seed. Depending on
    the flags, the model is wrapped with `handlers.do` (using `interventions`),
    `handlers.condition` (using `observations`), or both:

    * ``flip`` False: ``do`` is applied first (if ``intervene``), then ``condition``
      wraps the result (if ``observe``).
    * ``flip`` True: ``condition`` is applied first and ``do`` wraps it, but only when
      both ``intervene`` and ``observe`` are set; otherwise the model is left unwrapped.

    For every site the test compares the value recorded in the trace
    (``tr[name]['value']``) with the value returned by the model
    (``actual_values[name]``). When neither ``intervene`` nor ``observe`` is set no
    branch below matches and nothing is asserted.

    NOTE(review): the three flags are presumably supplied by a pytest parametrization
    that is not visible here — confirm.
    """
    # x -> y -> z -> w
    sites = ["x", "y", "z", "w"]
    # None means "this site is not observed / not intervened on".
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        # Fixed seed so every invocation draws the same latent values.
        with handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            y = numpyro.sample("y", dist.Normal(x, 1))
            z = numpyro.sample("z", dist.Normal(y, 1))
            w = numpyro.sample("w", dist.Normal(z, 1))
            return dict(x=x, y=y, z=z, w=w)

    if not flip:
        if intervene:
            model = handlers.do(model, data=interventions)
        if observe:
            model = handlers.condition(model, data=observations)
    # Reversed nesting: condition inside, do outside (requires both flags).
    elif flip and intervene and observe:
        model = handlers.do(handlers.condition(model, data=observations), data=interventions)

    with handlers.trace() as tr:
        actual_values = model()
    for name in sites:
        # case 1: purely observational query like handlers.condition
        # Observed sites are marked observed and carry the observed value both in the
        # trace and in the returned dict; intervention values must not show up.
        if not intervene and observe:
            if observations[name] is not None:
                assert tr[name]['is_observed']
                assert_allclose(observations[name], actual_values[name])
                assert_allclose(observations[name], tr[name]['value'])
            if interventions[name] != observations[name]:
                if interventions[name] is not None:
                    assert_raises(AssertionError, assert_allclose, interventions[name], actual_values[name])
        # case 2: purely interventional query like old handlers.do
        # Nothing is observed; the returned dict holds the intervention values, while
        # the trace value is asserted to differ from both the observation and the
        # intervention.
        elif intervene and not observe:
            assert not tr[name]['is_observed']
            if interventions[name] is not None:
                assert_allclose(interventions[name], actual_values[name])
            if observations[name] is not None:
                assert_raises(AssertionError, assert_allclose, observations[name], tr[name]['value'])
            if interventions[name] is not None:
                assert_raises(AssertionError, assert_allclose, interventions[name], tr[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        # The trace records (and marks as observed) the observation, while the returned
        # dict follows the intervention.
        elif intervene and observe:
            if observations[name] is not None:
                assert tr[name]['is_observed']
                assert_allclose(observations[name], tr[name]['value'])
            if interventions[name] is not None:
                assert_allclose(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                if interventions[name] is not None:
                    assert_raises(AssertionError, assert_allclose, interventions[name], tr[name]['value'])
def test_condition():
    """Conditioning pins ``y`` to a value and marks the site as observed."""

    def delta_plus_normal():
        shift = numpyro.sample("x", dist.Delta(0.0))
        noise = numpyro.sample("y", dist.Normal(0.0, 1.0))
        return shift + noise

    seeded = handlers.seed(delta_plus_normal, random.PRNGKey(1))
    conditioned = handlers.condition(seeded, {"y": 2.0})

    y_site = handlers.trace(conditioned).get_trace()["y"]
    assert y_site["value"] == 2.0
    assert y_site["is_observed"]

    # This variant of the test expects the outer condition's value to be returned.
    assert handlers.condition(conditioned, {"y": 3.0})() == 3.0
def test_condition():
    """An inner ``condition`` pins ``y``; re-conditioning the site returns the new value."""

    def model():
        # ``x`` is sampled before ``y`` (left-to-right evaluation), as before.
        return (numpyro.sample('x', dist.Delta(0.))
                + numpyro.sample('y', dist.Normal(0., 1.)))

    model = handlers.condition(handlers.seed(model, random.PRNGKey(1)), {'y': 2.})
    site_y = handlers.trace(model).get_trace()['y']
    assert site_y['value'] == 2.
    assert site_y['is_observed']

    reconditioned_result = handlers.condition(model, {'y': 3.})()
    assert reconditioned_result == 3.
def test_condition():
    """``condition`` marks ``y`` observed; conditioning it a second time raises."""

    def model():
        x = numpyro.sample('x', dist.Delta(0.))
        y = numpyro.sample('y', dist.Normal(0., 1.))
        return x + y

    conditioned_model = handlers.condition(
        handlers.seed(model, random.PRNGKey(1)), {'y': 2.}
    )
    y_site = handlers.trace(conditioned_model).get_trace()['y']
    assert y_site['value'] == 2.
    assert y_site['is_observed']

    # Raise ValueError when site is already observed.
    with pytest.raises(ValueError):
        handlers.condition(conditioned_model, {'y': 3.})()
def fit(self, df, iter=500, seed=42, **kwargs):
    """Fit the model to a table of match results with NUTS.

    Args:
        df: DataFrame with ``home_team``, ``away_team``, ``home_goals``,
            ``away_goals`` and a datetime-like ``date`` column.
        iter: number of posterior samples to draw; ``iter // 2`` warmup steps
            are run first. (The name shadows the builtin but is part of the
            public signature, so it is kept.)
        seed: integer seed for the PRNG key passed to ``MCMC.run``.
        **kwargs: forwarded unchanged to ``MCMC``.

    Returns:
        ``self``, with ``team_to_index``, ``index_to_team``, ``n_teams``,
        ``min_date`` and ``samples`` populated; a posterior summary is printed.
    """
    # Computed once and reused for the gameweek offsets below.
    self.min_date = df["date"].min()

    teams = sorted(set(df["home_team"]) | set(df["away_team"]))
    self.team_to_index = {team: i for i, team in enumerate(teams)}
    self.index_to_team = {i: team for team, i in self.team_to_index.items()}
    self.n_teams = len(teams)

    home_team = df["home_team"].values
    away_team = df["away_team"].values
    observed_goals = {
        "home_goals": df["home_goals"].values,
        "away_goals": df["away_goals"].values,
    }
    # Whole weeks elapsed since the earliest match in ``df``.
    gameweek = ((df["date"] - self.min_date).dt.days // 7).values

    # Pass the observations positionally: numpyro renamed this keyword from
    # ``param_map`` to ``data``, and positional use works with either release.
    conditioned_model = condition(self.model, observed_goals)

    nuts_kernel = NUTS(conditioned_model)
    mcmc = MCMC(nuts_kernel, num_warmup=iter // 2, num_samples=iter, **kwargs)
    mcmc.run(random.PRNGKey(seed), home_team, away_team, gameweek)

    self.samples = mcmc.get_samples()
    mcmc.print_summary()
    return self
def predict(
    rng_key: np.ndarray, post_samples: np.ndarray, model: Callable, *args: Any, **kwargs: Any
) -> np.ndarray:
    """Run ``model`` once with ``post_samples`` conditioned and return the ``obs`` value.

    ``rng_key`` seeds any randomness left after conditioning; ``*args`` and
    ``**kwargs`` are forwarded to ``model``.

    NOTE(review): ``post_samples`` is annotated as an array but is used as a
    site-name -> value mapping by ``handlers.condition`` — confirm with callers.
    """
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    obs_site = handlers.trace(seeded).get_trace(*args, **kwargs)["obs"]
    return obs_site["value"]
def body_fn(wrapped_carry, x):
    """Run one scan step of ``f`` and capture its trace.

    ``wrapped_carry`` is unpacked as ``(step index, rng_key, carry)``. If an
    ``rng_key`` is present it is split, and the fresh subkey seeds this step.
    Every entry of ``substitute_stack`` is re-applied for the current index.
    The step runs under ``handlers.block()`` and its sites are recorded by a
    local ``handlers.trace``.

    Returns ``((index + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    step, key, state = wrapped_carry
    if key is None:
        step_key = None
    else:
        key, step_key = random.split(key)

    wrap_with = {
        "condition": lambda g, site_fn: handlers.condition(g, condition_fn=site_fn),
        "substitute": lambda g, site_fn: handlers.substitute(g, substitute_fn=site_fn),
    }

    with handlers.block():
        # Tell the unconstrained messenger used in potential-energy computation
        # that only the element at position ``step`` is needed when transforming.
        step_fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": step}
        )
        if step_key is not None:
            step_fn = handlers.seed(step_fn, step_key)
        for kind, site_values in substitute_stack:
            wrap = wrap_with.get(kind)
            if wrap is not None:
                step_fn = wrap(step_fn, partial(_subs_wrapper, site_values, step, length))
        with handlers.trace() as step_trace:
            state, y = step_fn(state, x)

    return (step + 1, key, state), (PytreeTrace(step_trace), y)
def body_fn(wrapped_carry, x, prefix=None): i, rng_key, carry = wrapped_carry init = True if (not_jax_tracer(i) and i == 0) else False rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None) seeded_fn = handlers.seed(f, subkey) if subkey is not None else f for subs_type, subs_map in substitute_stack: subs_fn = partial(_subs_wrapper, subs_map, i, length) if subs_type == 'condition': seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn) elif subs_type == 'substitute': seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) if init: with handlers.scope(prefix="_init"): new_carry, y = seeded_fn(carry, x) trace = {} else: with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov(): # Like scan_wrapper, we collect the trace of scan's transition function # `seeded_fn` here. To put time dimension to the correct position, we need to # promote shapes to make `fn` and `value` # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`, # and value's batch_shape is (3,), then we promote shape of # value so that its batch shape is (1, 3)). new_carry, y = config_enumerate(seeded_fn)(carry, x) # store shape of new_carry at a global variable nonlocal carry_shape_at_t1 carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]] # make new_carry have the same shape as carry # FIXME: is this rigorous? new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry) return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def test_no_split_deterministic():
    """A model whose every sample site is conditioned runs without any seed handler."""

    def two_normals():
        total = numpyro.sample('x', dist.Normal(0., 1.))
        total = total + numpyro.sample('y', dist.Normal(0., 1.))
        return total

    observed = {'x': 1., 'y': 2.}
    assert handlers.condition(two_normals, observed)() == 3.
def test_no_split_deterministic():
    """Conditioning every site lets the model run unseeded with a fixed result."""
    fixed_values = {"x": 1.0, "y": 2.0}

    def model():
        # Sites are sampled in the order "x", then "y".
        draws = [numpyro.sample(site, dist.Normal(0.0, 1.0)) for site in ("x", "y")]
        return draws[0] + draws[1]

    conditioned = handlers.condition(model, fixed_values)
    assert conditioned() == 3.0
def log_likelihood(
    rng_key: np.ndarray, params: np.ndarray, model: Callable, *args: Any, **kwargs: Any
) -> np.ndarray:
    """Return the log-density of the ``obs`` site of ``model`` with ``params`` fixed.

    Args:
        rng_key: PRNG key used to seed ``model``, so that sample sites not fixed
            by ``params`` can still draw values. If ``None``, the model runs
            unseeded as before.
        params: site-name -> value mapping used to condition ``model``
            (NOTE(review): annotated as an array but used as a mapping).
        model: the model callable; it must contain a site named ``"obs"``.
        *args, **kwargs: forwarded to ``model``.

    Returns:
        ``obs_site["fn"].log_prob(obs_site["value"])``.
    """
    model = handlers.condition(model, params)
    # ``rng_key`` used to be accepted but ignored, so a sample site missing from
    # ``params`` had no key to sample with. Seed the conditioned model the same
    # way ``predict`` does; seeding leaves observed sites unchanged.
    if rng_key is not None:
        model = handlers.seed(model, rng_key)
    obs_node = handlers.trace(model).get_trace(*args, **kwargs)["obs"]
    return obs_node["fn"].log_prob(obs_node["value"])
def single_prediction(rng, samples):
    """Trace ``model`` once with ``samples`` conditioned and collect site values.

    Uses the enclosing ``model``, ``args``, ``kwargs`` and ``return_sites``. By
    default, every traced site not already present in ``samples`` is returned.
    If ``return_sites`` is given, only those names are kept.
    """
    conditioned = condition(model, samples)
    model_tr = trace(seed(conditioned, rng)).get_trace(*args, **kwargs)

    if return_sites is not None:
        wanted = return_sites
    else:
        wanted = model_tr.keys() - samples.keys()

    result = {}
    for site_name, site in model_tr.items():
        if site_name in wanted:
            result[site_name] = site['value']
    return result
def body_fn(wrapped_carry, x, prefix=None):
    """One step of the enumerated scan transition with ``unroll_steps`` eager steps.

    ``wrapped_carry`` is ``(i, rng_key, carry)``. While ``i`` is concrete (not a
    JAX tracer) and lies in ``range(unroll_steps)``, the step runs eagerly under
    a ``_PREV_``-prefixed scope and records no trace. Otherwise ``f`` runs
    inside ``_promote_fn_shapes()`` and ``packed_trace()``. At most
    ``history + 1`` carry-shape snapshots are appended to the enclosing
    ``carry_shapes`` list, and the new carry is reshaped to the input carry's
    shapes. ``prefix`` is accepted but not used here.

    Returns ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    i, rng_key, carry = wrapped_carry
    init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (
        None, None)

    # we need to tell unconstrained messenger in potential energy computation
    # that only the item at time `i` is needed when transforming
    fn = handlers.infer_config(
        f, config_fn=lambda msg: {"_scan_current_index": i})

    seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
    # Re-apply every enclosing condition/substitute handler for the current index.
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == "condition":
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == "substitute":
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # Rename sites with one "_PREV_" prefix per remaining unrolled step (no
        # divider), so the names match the pattern used by the Sarkka-Bilmes product.
        with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
            new_carry, y = config_enumerate(seeded_fn)(carry, x)
            trace = {}
    else:
        # Like scan_wrapper, we collect the trace of scan's transition function
        # `seeded_fn` here. To put time dimension to the correct position, we need to
        # promote shapes to make `fn` and `value`
        # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
        # and value's batch_shape is (3,), then we promote shape of
        # value so that its batch shape is (1, 3)).
        # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
        # We don't promote `value` shape here because we need to store carry shape
        # at this step. If we reshape the `value` here, output carry might get wrong shape.
        with _promote_fn_shapes(), packed_trace() as trace:
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # Append the leaf shapes of new_carry to the enclosing `carry_shapes` list,
        # keeping at most `history + 1` snapshots.
        if len(carry_shapes) < (history + 1):
            carry_shapes.append(
                [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(
            lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
    return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
def body_fn(wrapped_carry, x):
    """Run one scan step of ``f`` inside a blocked scope and record its trace.

    ``wrapped_carry`` is ``(step index, rng_key, carry)``. A present ``rng_key``
    is split, and the subkey seeds this step. Returns
    ``((index + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    step, key, state = wrapped_carry
    step_key = None
    if key is not None:
        key, step_key = random.split(key)

    with handlers.block():
        step_fn = f if step_key is None else handlers.seed(f, step_key)
        # Re-apply each enclosing condition/substitute for the current index.
        for kind, site_values in substitute_stack:
            site_fn = partial(_subs_wrapper, site_values, step, length)
            if kind == 'condition':
                step_fn = handlers.condition(step_fn, condition_fn=site_fn)
            elif kind == 'substitute':
                step_fn = handlers.substitute(step_fn, substitute_fn=site_fn)
        with handlers.trace() as step_trace:
            state, y = step_fn(state, x)

    return (step + 1, key, state), (PytreeTrace(step_trace), y)
def single_prediction(val):
    """Produce predictions for one ``(rng_key, samples)`` pair.

    If ``infer_discrete`` is set, values come from ``_sample_posterior`` applied
    to the enumerated, conditioned ``model`` with temperature 1, and site
    metadata comes from ``prototype_trace``. Otherwise ``masked_model`` is
    traced with ``samples`` substituted.

    The result is filtered by ``return_sites``:
    - ``None`` keeps sample sites absent from ``samples`` plus deterministic sites;
    - ``""`` keeps every non-plate site;
    - any other value is used as the collection of names to keep.
    """
    rng_key, samples = val

    if not infer_discrete:
        model_trace = trace(seed(substitute(masked_model, samples), rng_key)).get_trace(
            *model_args, **model_kwargs
        )
        pred_samples = {}
        for site_name, site in model_trace.items():
            pred_samples[site_name] = site["value"]
    else:
        # Imported lazily; NOTE(review): presumably because this path needs the
        # optional funsor backend.
        from numpyro.contrib.funsor import config_enumerate
        from numpyro.contrib.funsor.discrete import _sample_posterior

        model_trace = prototype_trace
        temperature = 1
        enumerated_model = config_enumerate(condition(model, samples))
        pred_samples = _sample_posterior(
            enumerated_model,
            first_available_dim,
            temperature,
            rng_key,
            *model_args,
            **model_kwargs,
        )

    if return_sites is None:
        sites = {
            site_name
            for site_name, site in model_trace.items()
            if site["type"] == "deterministic"
            or (site["type"] == "sample" and site_name not in samples)
        }
    elif return_sites == "":
        sites = {
            site_name for site_name, site in model_trace.items() if site["type"] != "plate"
        }
    else:
        sites = return_sites

    return {
        site_name: value for site_name, value in pred_samples.items() if site_name in sites
    }
def wrapper(wrapped_operand):
    """Run ``fn`` on an operand inside a blocked scope and capture its trace.

    ``wrapped_operand`` is ``(rng_key, operand)``. A non-``None`` key seeds
    ``fn``, and every enclosing condition/substitute in ``substitute_stack`` is
    re-applied. Returns ``(fn(operand), PytreeTrace(trace))``.
    """
    rng_key, operand = wrapped_operand
    wrap_with = {
        "condition": lambda g, site_fn: handlers.condition(g, condition_fn=site_fn),
        "substitute": lambda g, site_fn: handlers.substitute(g, substitute_fn=site_fn),
    }
    with handlers.block():
        wrapped_fn = fn if rng_key is None else handlers.seed(fn, rng_key)
        for kind, site_values in substitute_stack:
            wrap = wrap_with.get(kind)
            if wrap is not None:
                wrapped_fn = wrap(wrapped_fn, partial(_subs_wrapper, site_values))
        with handlers.trace() as fn_trace:
            result = wrapped_fn(operand)
    return result, PytreeTrace(fn_trace)
def test_mcmc_model_side_enumeration(model, temperature):
    """``infer_discrete`` on a model conditioned at one MCMC draw keeps the model's sites.

    A single NUTS draw (no warmup) supplies ``loc`` and ``scale``. The
    enumerated model is conditioned on that draw and passed through
    ``infer_discrete``. The resulting trace must contain the same site names as
    a plain trace of the seeded model.
    """
    # Pass num_warmup/num_samples by keyword: newer numpyro makes them
    # keyword-only, and keywords are also accepted by older releases.
    mcmc = infer.MCMC(infer.NUTS(model), num_warmup=0, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    mcmc_data = {
        k: v[0] for k, v in mcmc.get_samples().items() if k in ["loc", "scale"]
    }

    # Infer the discrete sites (MAP or sampled, depending on `temperature`),
    # conditioned on the posterior-sampled continuous latents.
    model = handlers.seed(model, rng_seed=1)
    actual_trace = handlers.trace(
        infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(config_enumerate(model), mcmc_trace),
            handlers.condition(config_enumerate(model), mcmc_data),
            temperature=temperature,
            rng_key=random.PRNGKey(1),
        )
    ).get_trace()

    # Check that both traces contain the same site names.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace) == set(expected_trace)
def transformed_model_fn(*args, **kwargs):
    """Call ``model`` with remapped arguments and with observations fixed via ``condition``.

    ``obs_to_model_args_fn`` splits the incoming call into the model's
    positional arguments, keyword arguments and a mapping of observed values.
    """
    model_args, model_kwargs, observed = obs_to_model_args_fn(*args, **kwargs)
    conditioned = condition(model, data=observed)
    return conditioned(*model_args, **model_kwargs)
def _wrap_model(model, *args, **kwargs):
    """Call ``model`` with the ``_gibbs_sites`` values both conditioned and substituted.

    The reserved ``_gibbs_sites`` keyword is removed before ``model`` is
    called; when it is absent, an empty mapping is used. Returns the model's
    output.
    """
    fixed_sites = kwargs.pop("_gibbs_sites", {})
    with condition(data=fixed_sites), substitute(data=fixed_sites):
        output = model(*args, **kwargs)
    return output
def fn(*args, **kwargs):
    """Run the enclosing ``model`` with ``_gibbs_sites`` values conditioned and substituted.

    The ``_gibbs_sites`` keyword is removed before the call (empty mapping if
    absent). The model's return value is discarded and this returns ``None``;
    NOTE(review): presumably it is only invoked under a tracing handler — confirm.
    """
    gibbs_fixed = kwargs.pop("_gibbs_sites", {})
    with condition(data=gibbs_fixed):
        with substitute(data=gibbs_fixed):
            model(*args, **kwargs)