def test_mcmc_model_side_enumeration(model, temperature):
    """Smoke-test infer_discrete on a model conditioned on fake MCMC samples.

    Draws continuous latents from the prior (wrong distribution, right type),
    then MAP/sample-estimates the discrete sites and checks the resulting
    trace structure.
    """
    # Perform fake inference.
    # Draw from prior rather than trying to sample from mcmc posterior.
    # This has the wrong distribution but the right type for tests.
    mcmc_trace = handlers.trace(
        handlers.block(handlers.enum(infer.config_enumerate(model)),
                       expose=["loc", "scale"])).get_trace()
    mcmc_data = {
        name: site["value"]
        for name, site in mcmc_trace.nodes.items()
        if site["type"] == "sample"
    }

    # MAP estimate discretes, conditioned on posterior sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), mcmc_trace),
            handlers.condition(infer.config_enumerate(model), mcmc_data),
            temperature=temperature,
        ),
    ).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    # The enumerated discrete "z1" must have been summed/maxed out of "scale".
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
def test_svi_model_side_enumeration(model, temperature):
    """Smoke-test infer_discrete on a model conditioned on an untrained guide.

    Uses an AutoNormal guide initialized but not trained (wrong distribution,
    right type), then MAP/sample-estimates the discrete sites and checks the
    resulting trace structure.
    """
    # Perform fake inference.
    # This has the wrong distribution but the right type for tests.
    guide = AutoNormal(
        handlers.enum(
            handlers.block(infer.config_enumerate(model),
                           expose=["loc", "scale"])))
    guide()  # Initialize but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()
    guide_data = {
        name: site["value"]
        for name, site in guide_trace.nodes.items()
        if site["type"] == "sample"
    }

    # MAP estimate discretes, conditioned on posterior sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), guide_trace)
            handlers.condition(infer.config_enumerate(model), guide_data),
            temperature=temperature,
        )).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    # The enumerated discrete "z1" must have been summed/maxed out of "scale".
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
def test_model_enumerated_elbo(model, guide, data, history):
    """Check that TraceMarkovEnum_ELBO (vectorized over time) matches
    TraceEnum_ELBO (sequential) in both loss and parameter gradients.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        if history > 1:
            pytest.xfail(
                reason="TraceMarkovEnum_ELBO does not yet support history > 1")

        model = infer.config_enumerate(model, default="parallel")
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=4)
        expected_loss = elbo.loss_and_grads(model, guide, data, history, False)
        # NOTE(review): these are lazy generators; .grad is read only when
        # zipped below, i.e. after the second loss_and_grads call — confirm
        # this is intentional (grads accumulate on the same parameters).
        expected_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        vectorized_elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4)
        actual_loss = vectorized_elbo.loss_and_grads(model, guide, data,
                                                     history, True)
        actual_grads = (
            value.grad
            for name, value in pyro.get_param_store().named_parameters())

        assert_close(actual_loss, expected_loss)
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_close(actual_grad, expected_grad)
def test_hmm_smoke(length, temperature):
    """Smoke-test infer_discrete on a simple discrete HMM.

    This should match the example in the infer_discrete docstring.
    """

    def hmm(data, hidden_dim=10):
        # Sticky transition matrix: mostly stay in the current state.
        transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
        means = torch.arange(float(hidden_dim))
        states = [0]
        for t in pyro.markov(range(len(data))):
            states.append(
                pyro.sample("states_{}".format(t),
                            dist.Categorical(transition[states[-1]])))
            data[t] = pyro.sample("obs_{}".format(t),
                                  dist.Normal(means[states[-1]], 1.0),
                                  obs=data[t])
        return states, data

    # Generate data from the prior (data entries are None, so they get sampled).
    true_states, data = hmm([None] * length)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)  # includes the initial state 0

    decoder = infer.infer_discrete(infer.config_enumerate(hmm),
                                   temperature=temperature)
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format(list(map(int, true_states))))
    logger.info("inferred states: {}".format(list(map(int, inferred_states))))
def _guide_from_model(model):
    """Build a guide by enumerating the model with observed sites blocked.

    Falls back to returning the model unchanged when the funsor backend is
    unavailable (KeyError), so that test collection still succeeds.
    """
    try:
        with pyro_backend("contrib.funsor"):
            return handlers.block(
                infer.config_enumerate(model, default="parallel"),
                # hide observed sites from the guide
                lambda msg: msg.get("is_observed", False))
    except KeyError:  # for test collection without funsor
        return model
def test_tmc_categoricals(depth, max_plate_nesting, num_samples, tmc_strategy):
    """Compare TraceTMC_ELBO between the "pyro" and "contrib.funsor" backends
    on a chain of Categoricals, checking loss and gradients in linear space.
    """

    def model():
        x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
        with pyro.plate("local", 3):
            for i in range(1, depth):
                x = pyro.sample(
                    "x{}".format(i),
                    dist.Categorical(pyro.param("q{}".format(i))[..., x, :]))
            with pyro.plate("data", 4):
                pyro.sample("y",
                            dist.Bernoulli(pyro.param("qy")[..., x]),
                            obs=data)

    with pyro_backend("pyro"):
        # initialize
        qs = [pyro.param("q0", torch.tensor([0.4, 0.6], requires_grad=True))]
        for i in range(1, depth):
            qs.append(
                pyro.param("q{}".format(i),
                           torch.randn(2, 2).abs().detach().requires_grad_(),
                           constraint=constraints.simplex))
        qs.append(
            pyro.param("qy", torch.tensor([0.75, 0.25], requires_grad=True)))
        qs = [q.unconstrained() for q in qs]
        data = (torch.rand(4, 3) > 0.5).to(dtype=qs[-1].dtype,
                                           device=qs[-1].device)

    with pyro_backend("pyro"):
        elbo = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        enum_model = infer.config_enumerate(model,
                                            default="parallel",
                                            expand=False,
                                            num_samples=num_samples,
                                            tmc=tmc_strategy)
        # exp() converts the ELBO back to linear space for unbiased comparison
        expected_loss = (
            -elbo.differentiable_loss(enum_model, lambda: None)).exp()
        expected_grads = grad(expected_loss, qs)

    with pyro_backend("contrib.funsor"):
        tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        tmc_model = infer.config_enumerate(model,
                                           default="parallel",
                                           expand=False,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)
        actual_loss = (-tmc.differentiable_loss(tmc_model, lambda: None)).exp()
        actual_grads = grad(actual_loss, qs)

        prec = 0.05
        assert_equal(actual_loss, expected_loss, prec=prec, msg="".join([
            "\nexpected loss = {}".format(expected_loss),
            "\n actual loss = {}".format(actual_loss),
        ]))
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            assert_equal(actual_grad, expected_grad, prec=prec, msg="".join([
                "\nexpected grad = {}".format(
                    expected_grad.detach().cpu().numpy()),
                "\n actual grad = {}".format(
                    actual_grad.detach().cpu().numpy()),
            ]))
def test_tmc_normals_chain_gradient(depth, num_samples, max_plate_nesting,
                                    expand, guide_type, reparameterized,
                                    tmc_strategy):
    """Check TMC gradient estimates on a Gaussian chain against gold values,
    for factorized / nonfactorized / prior guides, with and without
    reparameterization.
    """

    def model(reparameterized):
        Normal = (dist.Normal if reparameterized
                  else dist.testing.fakes.NonreparameterizedNormal)
        x = pyro.sample("x0", Normal(pyro.param("q2"),
                                     math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i),
                            Normal(x, math.sqrt(1. / depth)))
        pyro.sample("y", Normal(x, 1.), obs=torch.tensor(float(1)))

    def factorized_guide(reparameterized):
        Normal = (dist.Normal if reparameterized
                  else dist.testing.fakes.NonreparameterizedNormal)
        pyro.sample("x0", Normal(pyro.param("q2"), math.sqrt(1. / depth)))
        for i in range(1, depth):
            # marginal variance of x_i grows linearly along the chain
            pyro.sample("x{}".format(i),
                        Normal(0., math.sqrt(float(i + 1) / depth)))

    def nonfactorized_guide(reparameterized):
        Normal = (dist.Normal if reparameterized
                  else dist.testing.fakes.NonreparameterizedNormal)
        x = pyro.sample("x0", Normal(pyro.param("q2"),
                                     math.sqrt(1. / depth)))
        for i in range(1, depth):
            x = pyro.sample("x{}".format(i),
                            Normal(x, math.sqrt(1. / depth)))

    with pyro_backend("contrib.funsor"):
        # compare reparameterized and nonreparameterized gradient estimates
        q2 = pyro.param("q2", torch.tensor(0.5, requires_grad=True))
        qs = (q2.unconstrained(),)

        tmc = infer.TraceTMC_ELBO(max_plate_nesting=max_plate_nesting)
        tmc_model = infer.config_enumerate(model,
                                           default="parallel",
                                           expand=expand,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)
        guide = (factorized_guide if guide_type == "factorized"
                 else nonfactorized_guide if guide_type == "nonfactorized"
                 else lambda *args: None)  # default: prior as guide
        tmc_guide = infer.config_enumerate(guide,
                                           default="parallel",
                                           expand=expand,
                                           num_samples=num_samples,
                                           tmc=tmc_strategy)

        # convert to linear space for unbiasedness
        actual_loss = (-tmc.differentiable_loss(
            tmc_model, tmc_guide, reparameterized)).exp()
        actual_grads = grad(actual_loss, qs)

        # gold values from Funsor
        expected_grads = (torch.tensor({
            1: 0.0999,
            2: 0.0860,
            3: 0.0802,
            4: 0.0771,
        }[depth]),)

        # nonreparameterized estimators are noisier, so loosen the tolerance
        grad_prec = 0.05 if reparameterized else 0.1
        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            print(actual_loss)  # debug output retained from original
            assert_equal(actual_grad, expected_grad, prec=grad_prec,
                         msg="".join([
                             "\nexpected grad = {}".format(
                                 expected_grad.detach().cpu().numpy()),
                             "\n actual grad = {}".format(
                                 actual_grad.detach().cpu().numpy()),
                         ]))
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2,
                           history, use_replay):
    """Check that a vectorized (pyro.markov) enumerated trace agrees with the
    sequential enumerated trace: per-site log-prob factors, the inferred
    markov "step" sets for the weeks and days dimensions, and measure_vars.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(
                    _guide_from_model(model)).get_trace(
                        weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(
                    trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(
                    trace.nodes["{}_{}".format(v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors: the first `history` sites are unrolled,
        # the remainder are read out of one slice-indexed vectorized site by
        # substituting concrete time indices for the slice-named inputs.
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(weeks_data)):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history, len(weeks_data)))]
                    ["funsor"]["log_prob"](**{
                        "weeks": i - history
                    }, **{
                        "{}_{}".format(
                            k, slice(history - j, len(weeks_data) - j)):
                        "{}_{}".format(k, i - j)
                        for j in range(history + 1) for k in vars1
                    }))
        # vectorized days factors (same scheme as weeks)
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history, len(days_data)))]
                    ["funsor"]["log_prob"](**{
                        "days": i - history
                    }, **{
                        "{}_{}".format(
                            k, slice(history - j, len(days_data) - j)):
                        "{}_{}".format(k, i - j)
                        for j in range(history + 1) for k in vars2
                    }))

        # assert correct factors
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step
        expected_measure_vars = frozenset()

        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume that all but the last var is markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                + tuple("{}_{}".format(v, slice(j, len(weeks_data) - history + j))
                        for j in range(history + 1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume that all but the last var is markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                + tuple("{}_{}".format(v, slice(j, len(days_data) - history + j))
                        for j in range(history + 1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(
            vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars