def test_gmm_iter_discrete_traces(data_size, graph_type, model):
    """Sequential enumeration of a non-vectorized GMM yields one trace per
    joint assignment of the per-datum discrete choices."""
    pyro.clear_param_store()
    observations = torch.arange(0, data_size)
    enumerated_model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, enumerated_model,
                                       data=observations, verbose=True))
    # Each data point contributes an independent binary choice, so the
    # non-vectorized trace count grows exponentially in data_size.
    assert len(traces) == 2**data_size
def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type):
    """Vectorized enumeration of a batched GMM is independent of data size."""
    pyro.clear_param_store()
    observations = torch.arange(0, data_size)
    enumerated_model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, enumerated_model,
                                       data=observations))
    # The batched model enumerates a single (vectorized) binary site, so
    # exactly two traces are produced no matter how large data_size is.
    assert len(traces) == 2
def test_iter_discrete_traces_vector(graph_type):
    """Enumerating a vectorized model yields one weighted trace per joint value.

    Checks both the trace count (2 Bernoulli values x 4 Categorical values)
    and that each trace's scale equals the joint probability of the
    enumerated values.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * ps.size(-1)
    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # BUGFIX: the trace weight is the joint probability
        # p(x) * p(y) = exp(log_pdf(x) + log_pdf(y)).  The original code
        # multiplied the log-pdfs inside the exp, computing
        # exp(log p(x) * log p(y)), which is not a probability.
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) +
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
def test_iter_discrete_traces_vector(graph_type):
    """Enumerating a vectorized model yields one weighted trace per joint value.

    Verifies the number of traces and that each trace's scale equals the
    joint probability of its enumerated (x, y) pair.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param(
            "ps",
            Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                   [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * ps.size(-1)
    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # BUGFIX: log-densities add (probabilities multiply), so the joint
        # weight is exp(log p(x) + log p(y)); the original code used `*`
        # between the log-pdfs, which is incorrect.
        expected_scale = torch.exp(
            dist.Bernoulli(p).log_pdf(x) +
            dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
def test_iter_discrete_traces_order(depth, graph_type):
    """Enumerated traces preserve the declaration order of sample sites."""

    @config_enumerate
    def model(depth):
        for i in range(depth):
            pyro.sample("x{}".format(i), dist.Bernoulli(0.5))

    traces = list(iter_discrete_traces(graph_type, model, depth))
    # One trace per joint assignment of `depth` binary choices.
    assert len(traces) == 2 ** depth

    expected_order = ["x{}".format(i) for i in range(depth)]
    for trace in traces:
        actual_order = [name
                        for name, site in trace.nodes.items()
                        if site["type"] == "sample"]
        assert actual_order == expected_order
def test_iter_discrete_traces_scalar(graph_type):
    """A scalar Bernoulli x Categorical model enumerates to
    2 * num_categories traces."""
    pyro.clear_param_store()

    @config_enumerate
    def model():
        p = pyro.param("p", torch.tensor(0.05))
        probs = pyro.param("probs", torch.tensor([0.1, 0.2, 0.3, 0.4]))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(probs))
        return dict(x=x, y=y)

    enumerated = list(iter_discrete_traces(graph_type, model))
    probs = pyro.param("probs")
    # Cross product of the two discrete supports: 2 Bernoulli values
    # times len(probs) Categorical values.
    assert len(enumerated) == 2 * len(probs)
def _get_traces(self, model, guide, *args, **kwargs): """ runs the guide and runs the model against the guide with the result packaged as a trace generator """ # enable parallel enumeration guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting) for i in range(self.num_particles): for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs): model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), graph_type="flat").get_trace( *args, **kwargs) if is_validation_enabled(): check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) if is_validation_enabled(): check_traceenum_requirements(model_trace, guide_trace) model_trace.compute_log_prob() guide_trace.compute_score_parts() if is_validation_enabled(): for site in model_trace.nodes.values(): if site["type"] == "sample": check_site_shape(site, self.max_iarange_nesting) any_enumerated = False for site in guide_trace.nodes.values(): if site["type"] == "sample": check_site_shape(site, self.max_iarange_nesting) if site["infer"].get("enumerate"): any_enumerated = True if self.strict_enumeration_warning and not any_enumerated: warnings.warn( 'TraceEnum_ELBO found no sample sites configured for enumeration. ' 'If you want to enumerate sites, you need to @config_enumerate or set ' 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' 'If you do not want to enumerate, consider using Trace_ELBO instead.' ) yield model_trace, guide_trace
def test_iter_discrete_traces_vector(graph_type):
    """Vectorized enumeration over an iarange yields 2 * num_categories traces."""
    pyro.clear_param_store()

    @config_enumerate
    def model():
        p = pyro.param("p", torch.tensor([0.05, 0.15]))
        probs = pyro.param("probs", torch.tensor([[0.1, 0.2, 0.3, 0.4],
                                                  [0.4, 0.3, 0.2, 0.1]]))
        with pyro.iarange("iarange", 2):
            x = pyro.sample("x", dist.Bernoulli(p))
            y = pyro.sample("y", dist.Categorical(probs))
        assert x.size() == (2,)
        assert y.size() == (2,)
        return dict(x=x, y=y)

    enumerated = list(iter_discrete_traces(graph_type, model))
    probs = pyro.param("probs")
    # Two Bernoulli values times the number of Categorical values;
    # the batch dimension inside the iarange is vectorized away.
    assert len(enumerated) == 2 * probs.size(-1)
def test_iter_discrete_traces_scalar(graph_type):
    """Each enumerated trace carries a scale equal to its joint probability."""
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([0.05])))
        ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * len(ps)
    for scale, trace in traces:
        x_val = trace.nodes["x"]["value"].data.long().view(-1)[0]
        y_val = trace.nodes["y"]["value"].data.long().view(-1)[0]
        # The joint probability factors as p(x) * p(y); pick the Bernoulli
        # mass by indexing [1 - p, p] with the sampled value.
        bern_prob = [1 - p[0], p[0]][x_val]
        expected_scale = Variable(torch.Tensor([bern_prob * ps[y_val]]))
        assert_equal(scale, expected_scale)
def _get_traces(self, model, guide, *args, **kwargs):
    """
    runs the guide and runs the model against the guide with
    the result packaged as a trace generator

    XXX support for automatically settings args/kwargs to volatile?
    """
    for i in range(self.num_particles):
        if self.enum_discrete:
            # This iterates over a bag of traces, for each particle.
            # Each enumerated guide trace comes with its probability `scale`.
            for scale, guide_trace in iter_discrete_traces(
                    "flat", guide, *args, **kwargs):
                # Replay the model against the enumerated guide values.
                model_trace = poutine.trace(
                    poutine.replay(model, guide_trace),
                    graph_type="flat").get_trace(*args, **kwargs)

                check_model_guide_match(model_trace, guide_trace)
                # Subsample sites carry no probability mass; drop them.
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                check_enum_discrete_can_run(model_trace, guide_trace)

                # Per-batch log density ratio for the surrogate objective.
                log_r = model_trace.batch_log_pdf(
                ) - guide_trace.batch_log_pdf()
                # Each enumerated branch is weighted by its probability,
                # averaged over particles.
                weight = scale / self.num_particles
                yield weight, model_trace, guide_trace, log_r
            # Enumerated path fully handled; skip the sampling path below.
            continue

        # Non-enumerated path: a single sampled guide trace per particle.
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace(
            *args, **kwargs)

        check_model_guide_match(model_trace, guide_trace)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        log_r = model_trace.log_pdf() - guide_trace.log_pdf()
        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace, log_r
def _get_traces(self, model, guide, *args, **kwargs): """ runs the guide and runs the model against the guide with the result packaged as a trace generator """ # enable parallel enumeration guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting) for i in range(self.num_particles): for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs): model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), graph_type="flat").get_trace(*args, **kwargs) if is_validation_enabled(): check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) if is_validation_enabled(): check_traceenum_requirements(model_trace, guide_trace) model_trace.compute_log_prob() guide_trace.compute_score_parts() if is_validation_enabled(): for site in model_trace.nodes.values(): if site["type"] == "sample": check_site_shape(site, self.max_iarange_nesting) any_enumerated = False for site in guide_trace.nodes.values(): if site["type"] == "sample": check_site_shape(site, self.max_iarange_nesting) if site["infer"].get("enumerate"): any_enumerated = True if self.strict_enumeration_warning and not any_enumerated: warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. ' 'If you want to enumerate sites, you need to @config_enumerate or set ' 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' 'If you do not want to enumerate, consider using Trace_ELBO instead.') yield model_trace, guide_trace
def lp_fn(input_dict):
    """Total model log-probability, summed over enumerated discrete traces.

    Writes the supplied site values into the closed-over ``model_trace``
    before replaying the closed-over ``model`` against it.
    """
    skip = {"_INPUT", "_RETURN"}
    # Overwrite trace values with the requested inputs.
    for site_name, site_value in input_dict.items():
        model_trace.nodes[site_name]["value"] = site_value
    replayed = pyro.poutine.replay(model, model_trace)

    log_p = 0
    for enum_trace in iter_discrete_traces("flat", fn=replayed):
        enum_trace.compute_log_prob()
        for name, node in enum_trace.nodes.items():
            if name in skip:
                continue
            lp = node["log_prob"]
            # Reduce any extra batch dimension so all terms are addable.
            log_p += lp if lp.ndim == 1 else lp.sum(dim=1)
    return log_p