def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """Check _compute_downstream_costs against the brute-force reference on
    a model/guide pair that reuses an iarange."""
    guide_trace = poutine.trace(
        iarange_reuse_model_guide, graph_type="dense").get_trace(
            include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = poutine.trace(
        poutine.replay(iarange_reuse_model_guide, trace=guide_trace),
        graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    # Drop subsample sites and populate log_prob entries before costing.
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam = set(guide_trace.nonreparam_stochastic_nodes)
    costs, cost_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam)
    brute_costs, brute_nodes = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam)

    assert cost_nodes == brute_nodes
    for name in costs:
        assert guide_trace.nodes[name]['log_prob'].size() == costs[name].size()
        assert_equal(costs[name], brute_costs[name])

    # Hand-computed downstream cost for site 'c1'.
    expected_c1 = model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']
    expected_c1 += (model_trace.nodes['b1']['log_prob'] -
                    guide_trace.nodes['b1']['log_prob']).sum()
    expected_c1 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, costs['c1'])
def test_compute_downstream_costs_plate_reuse(dim1, dim2):
    """Validate the optimized downstream-cost estimator on a model/guide
    pair that reuses a plate, comparing against the brute-force version."""
    guide_trace = poutine.trace(plate_reuse_model_guide, graph_type="dense").get_trace(
        include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = poutine.trace(
        poutine.replay(plate_reuse_model_guide, trace=guide_trace),
        graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    nonreparam = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, nonreparam)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, nonreparam)

    assert dc_nodes == dc_nodes_brute
    for site in dc:
        assert guide_trace.nodes[site]['log_prob'].size() == dc[site].size()
        assert_equal(dc[site], dc_brute[site])

    # Downstream cost of 'c1' computed by hand.
    mt, gt = model_trace.nodes, guide_trace.nodes
    expected_c1 = mt['c1']['log_prob'] - gt['c1']['log_prob']
    expected_c1 += (mt['b1']['log_prob'] - gt['b1']['log_prob']).sum()
    expected_c1 += mt['c2']['log_prob'] - gt['c2']['log_prob']
    expected_c1 += mt['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    """Compare fast vs brute-force downstream costs for a guide with an
    irange nested inside an iarange."""
    guide_trace = poutine.trace(
        nested_model_guide2, graph_type="dense").get_trace(
            include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = poutine.trace(
        poutine.replay(nested_model_guide2, trace=guide_trace),
        graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam)

    assert dc_nodes == dc_nodes_brute
    for name in dc:
        assert guide_trace.nodes[name]['log_prob'].size() == dc[name].size()
        assert_equal(dc[name], dc_brute[name])

    mt, gt = model_trace.nodes, guide_trace.nodes

    # 'b1' sees its own elbo term plus its observation.
    expected_b1 = mt['b1']['log_prob'] - gt['b1']['log_prob']
    expected_b1 += mt['obs1']['log_prob']
    assert_equal(expected_b1, dc['b1'])

    # 'c' accumulates every b_i / obs_i pair downstream of it.
    expected_c = mt['c']['log_prob'] - gt['c']['log_prob']
    for i in range(dim2):
        b_name = 'b{}'.format(i)
        obs_name = 'obs{}'.format(i)
        expected_c += mt[b_name]['log_prob'] - \
            gt[b_name]['log_prob']
        expected_c += mt[obs_name]['log_prob']
    assert_equal(expected_c, dc['c'])

    # 'a1' collapses the cost of 'c' with a sum over its dimension.
    expected_a1 = mt['a1']['log_prob'] - gt['a1']['log_prob']
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Yield (weight, model_trace, guide_trace) triples, one per particle.

    The guide is traced first and the model is then replayed against it,
    both with dense graph types so dependency information is recorded.
    Discrete enumeration is not implemented for this estimator.
    """
    for _ in range(self.num_particles):
        if self.enum_discrete:
            raise NotImplementedError(
                "https://github.com/uber/pyro/issues/220")
        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, guide_trace),
            graph_type="dense").get_trace(*args, **kwargs)
        check_model_guide_match(model_trace, guide_trace)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        yield 1.0 / self.num_particles, model_trace, guide_trace
def _update_weights(self, model_trace, guide_trace):
    """Update particle log-weights:
    w_t <- w_{t-1} * p(y_t|z_t) * p(z_t|z_{t-1}) / q(z_t)."""
    model_trace = prune_subsample_sites(model_trace)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    # Latent sites: add log p - log q, summed within each particle.
    for name, guide_site in guide_trace.nodes.items():
        if guide_site["type"] != "sample":
            continue
        model_site = model_trace.nodes[name]
        model_lp = model_site["log_prob"].reshape(self.num_particles, -1).sum(-1)
        guide_lp = guide_site["log_prob"].reshape(self.num_particles, -1).sum(-1)
        self.state._log_weights += model_lp - guide_lp

    # Observed sites: add the likelihood term.
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["is_observed"]:
            obs_lp = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
            self.state._log_weights += obs_lp

    # Shift by the max for numerical stability.
    self.state._log_weights -= self.state._log_weights.max()
def _update_weights(self, model_trace, guide_trace):
    """Update particle log-weights:
    w_t <- w_{t-1} * p(y_t|z_t) * p(z_t|z_{t-1}) / q(z_t),
    raising SMCFailed as soon as every particle becomes infeasible
    (all log-weights -inf)."""
    model_trace = prune_subsample_sites(model_trace)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    # Latent sites: add log p - log q per particle, checking feasibility.
    for name, guide_site in guide_trace.nodes.items():
        if guide_site["type"] != "sample":
            continue
        model_site = model_trace.nodes[name]
        model_lp = model_site["log_prob"].reshape(self.num_particles, -1).sum(-1)
        guide_lp = guide_site["log_prob"].reshape(self.num_particles, -1).sum(-1)
        self.state._log_weights += model_lp - guide_lp
        # `not (max > -inf)` also catches NaN, unlike `max == -inf`.
        if not (self.state._log_weights.max() > -math.inf):
            raise SMCFailed(
                "Failed to find feasible hypothesis after site {}".format(name))

    # Observed sites: add the likelihood term, checking feasibility.
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["is_observed"]:
            obs_lp = site["log_prob"].reshape(self.num_particles, -1).sum(-1)
            self.state._log_weights += obs_lp
            if not (self.state._log_weights.max() > -math.inf):
                raise SMCFailed(
                    "Failed to find feasible hypothesis after site {}".format(site["name"]))

    # Shift by the max for numerical stability.
    self.state._log_weights -= self.state._log_weights.max()
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Yield (weight, model_trace, guide_trace) triples, one per particle.

    Runs the guide, replays the model against it (dense graph type so the
    dependency structure is available), optionally validates the pair, and
    prunes subsample sites from both traces.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide,
                                    graph_type="dense").get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                    graph_type="dense").get_trace(*args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                if site["type"] == "sample" and site["infer"].get("enumerate")]
            if enumerated_sites:
                # BUG FIX: a comma was missing after the first string, so
                # Python's implicit literal concatenation merged the label
                # with ', ' and the site names were joined by that whole
                # label string instead of being listed after it.
                warnings.warn('\n'.join([
                    'TraceGraph_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs):
    """
    Returns a single trace from the guide, and the model that is run
    against it, as a (model_trace, guide_trace) pair.
    """
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace),
        graph_type=graph_type).get_trace(*args, **kwargs)
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    # The guide needs full score parts (not just log_prob) for
    # score-function gradient terms.
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        # Validate shapes of all sample sites in both traces.
        for trace in (model_trace, guide_trace):
            for site in trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, max_plate_nesting)
    return model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with the
    result packaged as a trace generator yielding
    (model_trace, guide_trace) pairs, one per particle.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                if site["type"] == "sample" and site["infer"].get("enumerate")]
            if enumerated_sites:
                # BUG FIX: a missing comma after the label string caused
                # implicit string concatenation with ', ', so the site names
                # were joined by the whole label instead of listed after it.
                warnings.warn('\n'.join([
                    'Trace_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        # model_trace.compute_log_prob()  # TODO: disabled — fails because the
        # decoder parameters are not available here (original note in Catalan).
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
        yield model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide with the
    result packaged as a tracegraph generator yielding
    (weight, model_trace, guide_trace) triples, one per particle.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(
            *args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                    graph_type="dense").get_trace(
                                        *args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [
                name for name, site in guide_trace.nodes.items()
                if site["type"] == "sample" and site["infer"].get("enumerate")
            ]
            if enumerated_sites:
                # BUG FIX: the original list was missing a comma after the
                # first element, so implicit string-literal concatenation
                # merged the label with ', ' and mangled the warning text.
                warnings.warn('\n'.join([
                    'TraceGraph_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'
                ]))
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
def _get_traces(self, model, guide, args, kwargs):
    """
    Trace the guide and the model (unconditioned, replayed against the
    guide under a vectorized particle plate) and return the pruned
    (guide_trace, model_trace) pair.

    Raises ValueError if any guide site, or any observed model site, is
    not fully reparametrized (EnergyDistance requires rsample support).
    """
    if self.max_plate_nesting == float("inf"):
        # Avoid calling .log_prob() when undefined.
        with validation_enabled(False):
            # TODO factor this out as a stand-alone helper.
            ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
    vectorize = pyro.plate("num_particles_vectorized", self.num_particles,
                           dim=-self.max_plate_nesting)

    # Trace the guide as in ELBO.
    with poutine.trace() as tr, vectorize:
        guide(*args, **kwargs)
    guide_trace = tr.trace

    # Trace the model, drawing posterior predictive samples.
    with poutine.trace() as tr, poutine.uncondition():
        with poutine.replay(trace=guide_trace), vectorize:
            model(*args, **kwargs)
    model_trace = tr.trace
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["infer"].get("was_observed", False):
            site["is_observed"] = True
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    if is_validation_enabled():
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError(
                        "EnergyDistance requires fully reparametrized guides"
                    )
        # BUG FIX: the original wrote `for trace in model_trace.nodes.values()`
        # while the loop body tested `site`, so it repeatedly re-checked the
        # stale loop variable left over from the guide loop above and never
        # actually validated the model's observed sites.
        for site in model_trace.nodes.values():
            if site["type"] == "sample":
                if site["is_observed"]:
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires reparametrized likelihoods"
                        )

    if self.prior_scale > 0:
        model_trace.compute_log_prob(
            site_filter=lambda name, site: not site["is_observed"])
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if not site["is_observed"]:
                        check_site_shape(site, self.max_plate_nesting)

    return guide_trace, model_trace
def test_compute_downstream_costs_plate_in_iplate(dim1):
    """Compare fast vs brute-force downstream costs for a guide with a
    plate nested inside an iplate."""
    guide_trace = poutine.trace(nested_model_guide, graph_type="dense").get_trace(
        include_obs=False, dim1=dim1)
    model_trace = poutine.trace(
        poutine.replay(nested_model_guide, trace=guide_trace),
        graph_type="dense").get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    nonreparam = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, nonreparam)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, nonreparam)
    assert dc_nodes == dc_nodes_brute

    mt, gt = model_trace.nodes, guide_trace.nodes

    def elbo_term(name):
        # model log-prob minus guide log-prob for a latent site
        return mt[name]['log_prob'] - gt[name]['log_prob']

    expected_c1 = elbo_term('c1') + mt['obs1']['log_prob']
    expected_b1 = elbo_term('b1') + elbo_term('c1').sum() + mt['obs1']['log_prob'].sum()
    expected_c0 = elbo_term('c0') + mt['obs0']['log_prob']
    expected_b0 = elbo_term('b0') + elbo_term('c0').sum() + mt['obs0']['log_prob'].sum()

    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c0, dc['c0'], prec=1.0e-6)
    assert_equal(expected_b0, dc['b0'], prec=1.0e-6)

    for name in dc:
        assert gt[name]['log_prob'].size() == dc[name].size()
        assert_equal(dc[name], dc_brute[name])
def test_compute_downstream_costs_duplicates(dim):
    """Brute-force comparison on a diamond-shaped dependency graph where
    several nodes share the same downstream cost terms."""
    guide_trace = poutine.trace(diamond_guide, graph_type="dense").get_trace(dim=dim)
    model_trace = poutine.trace(
        poutine.replay(diamond_model, trace=guide_trace),
        graph_type="dense").get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    nonreparam = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, nonreparam)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, nonreparam)
    assert dc_nodes == dc_nodes_brute

    mt, gt = model_trace.nodes, guide_trace.nodes

    # 'a1' is upstream of every b_i, c1, and the observation.
    expected_a1 = mt['a1']['log_prob'] - gt['a1']['log_prob']
    for d in range(dim):
        expected_a1 += mt['b{}'.format(d)]['log_prob']
        expected_a1 -= gt['b{}'.format(d)]['log_prob']
    expected_a1 += mt['c1']['log_prob'] - gt['c1']['log_prob']
    expected_a1 += mt['obs']['log_prob']

    # 'b1' subtracts only its own guide term but sees all model b_i terms.
    expected_b1 = -gt['b1']['log_prob']
    for d in range(dim):
        expected_b1 += mt['b{}'.format(d)]['log_prob']
    expected_b1 += mt['c1']['log_prob'] - gt['c1']['log_prob']
    expected_b1 += mt['obs']['log_prob']

    # 'c1' sees all model b_i terms plus the observation.
    expected_c1 = mt['c1']['log_prob'] - gt['c1']['log_prob']
    for d in range(dim):
        expected_c1 += mt['b{}'.format(d)]['log_prob']
    expected_c1 += mt['obs']['log_prob']

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
    for name in dc:
        assert gt[name]['log_prob'].size() == dc[name].size()
        assert_equal(dc[name], dc_brute[name])
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs...

    Traces the same (model, guide) pair under both the "pyro" and
    "contrib.funsor" backends, exhausting the queue of discrete
    enumeration branches, and checks that the resulting traces agree.
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()
    if guide is None:
        # default to a trivial guide that samples nothing
        guide = lambda **kwargs: None  # noqa: E731
    # One LIFO queue of partial traces per backend; both start from an
    # empty trace and are drained in lockstep by handlers.queue below.
    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())
    while not q_pyro.empty() and not q_funsor.empty():
        # Trace one enumeration branch under the standard pyro backend.
        with pyro_backend("pyro"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_pyro = handlers.trace(
                    handlers.queue(
                        guide,
                        q_pyro,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_pyro = handlers.trace(
                    handlers.replay(model,
                                    trace=guide_tr_pyro)).get_trace(**kwargs)
        # Trace the corresponding branch under the funsor backend.
        with pyro_backend("contrib.funsor"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_funsor = handlers.trace(
                    handlers.queue(
                        guide,
                        q_funsor,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_funsor = handlers.trace(
                    handlers.replay(model,
                                    trace=guide_tr_funsor)).get_trace(**kwargs)

        # make sure all dimensions were cleaned up
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert (not _DIM_STACK.global_frame.name_to_dim
                and not _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
def _get_matched_trace(self, model_trace, *args, **kwargs):
    """
    Run the guide with its sample and observe statements pinned to the
    values in ``model_trace`` and return the pruned guide trace.

    :param model_trace: a trace from the model
    :type model_trace: pyro.poutine.trace_struct.Trace
    :returns: guide trace with sampled values matched to model_trace
    :rtype: pyro.poutine.trace_struct.Trace

    `args` and `kwargs` are passed to the guide.
    """
    kwargs["observations"] = {}
    # Mark every site the model observed and forward its value to the guide.
    for name in itertools.chain(model_trace.stochastic_nodes,
                                model_trace.observation_nodes):
        if "was_observed" in model_trace.nodes[name]["infer"]:
            model_trace.nodes[name]["is_observed"] = True
            kwargs["observations"][name] = model_trace.nodes[name]["value"]

    guide_trace = poutine.trace(
        poutine.replay(self.guide, model_trace)).get_trace(*args, **kwargs)
    check_model_guide_match(model_trace, guide_trace)
    return prune_subsample_sites(guide_trace)
def forward(self, *args, **kwargs):
    """Run the guide, replay the model against it, and return the values
    of the model's stochastic sites, ordered by site name."""
    guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(self.model, guide_trace)).get_trace(*args, **kwargs)
    # Collect values only for non-subsample stochastic sites.
    sample_values = {
        name: model_trace.nodes[name]['value']
        for name in prune_subsample_sites(model_trace).stochastic_nodes
    }
    return tuple(value for _, value in sorted(sample_values.items()))
def _setup_prototype(self, *args, **kwargs):
    """Trace the enumerated model once and record its discrete sample
    sites, conditional-independence stacks, and vectorized plates."""
    # run the model so we can inspect its structure
    model = config_enumerate(self.model)
    prototype = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(prototype)
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    self._discrete_sites = []
    self._cond_indep_stacks = {}
    self._plates = {}
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        if site["infer"].get("enumerate") != "parallel":
            raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                      'configured for parallel enumeration'.format(name))

        # collect discrete sample sites
        fn = site["fn"]
        Dist = type(fn)
        if Dist not in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
            raise NotImplementedError("{} is not supported".format(Dist.__name__))
        params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
        self._discrete_sites.append((site, Dist, params))

        # collect independence contexts
        self._cond_indep_stacks[name] = site["cond_indep_stack"]
        for frame in site["cond_indep_stack"]:
            if not frame.vectorized:
                raise NotImplementedError("AutoDiscreteParallel does not support sequential pyro.plate")
            self._plates[frame.name] = frame
def _setup_prototype(self, *args, **kwargs):
    """Trace the model (with parallel enumeration configured by default)
    and record its discrete sites, cond-indep stacks, and iaranges."""
    # run the model so we can inspect its structure
    enum_model = config_enumerate(self.model, default="parallel")
    self.prototype_trace = prune_subsample_sites(
        poutine.block(poutine.trace(enum_model).get_trace)(*args, **kwargs))
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    self._discrete_sites = []
    self._cond_indep_stacks = {}
    self._iaranges = {}
    for name, site in self.prototype_trace.nodes.items():
        # skip non-sample nodes and observed sites
        if site["type"] != "sample" or site["is_observed"]:
            continue
        if site["infer"].get("enumerate") != "parallel":
            raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                      'configured for parallel enumeration'.format(name))

        # collect discrete sample sites
        fn = site["fn"]
        Dist = type(fn)
        if Dist not in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
            raise NotImplementedError("{} is not supported".format(Dist.__name__))
        params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
        self._discrete_sites.append((site, Dist, params))

        # collect independence contexts
        self._cond_indep_stacks[name] = site["cond_indep_stack"]
        for frame in site["cond_indep_stack"]:
            if not frame.vectorized:
                raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
            self._iaranges[frame.name] = frame
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide, yielding
    (model_trace, guide_trace) pairs over particles and over the bag
    of discretely-enumerated guide traces.
    """
    # enable parallel enumeration
    guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

    for _ in range(self.num_particles):
        for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
            model_trace = poutine.trace(
                poutine.replay(model, trace=guide_trace),
                graph_type="flat").get_trace(*args, **kwargs)

            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)
            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)
                any_enumerated = False
                for site in guide_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            any_enumerated = True
                if self.strict_enumeration_warning and not any_enumerated:
                    warnings.warn(
                        'TraceEnum_ELBO found no sample sites configured for enumeration. '
                        'If you want to enumerate sites, you need to @config_enumerate or set '
                        'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                        'If you do not want to enumerate, consider using Trace_ELBO instead.'
                    )

            yield model_trace, guide_trace
def test_compute_downstream_costs_duplicates(dim):
    """Diamond-shaped dependency graph: several nodes share the same
    downstream terms; compare the fast estimator against brute force."""
    guide_trace = poutine.trace(diamond_guide, graph_type="dense").get_trace(dim=dim)
    model_trace = poutine.trace(poutine.replay(diamond_model, trace=guide_trace),
                                graph_type="dense").get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    b_names = ['b{}'.format(d) for d in range(dim)]

    expected_a1 = (model_trace.nodes['a1']['log_prob'] -
                   guide_trace.nodes['a1']['log_prob'])
    for b in b_names:
        expected_a1 += model_trace.nodes[b]['log_prob']
        expected_a1 -= guide_trace.nodes[b]['log_prob']
    expected_a1 += (model_trace.nodes['c1']['log_prob'] -
                    guide_trace.nodes['c1']['log_prob'])
    expected_a1 += model_trace.nodes['obs']['log_prob']

    expected_b1 = -guide_trace.nodes['b1']['log_prob']
    for b in b_names:
        expected_b1 += model_trace.nodes[b]['log_prob']
    expected_b1 += (model_trace.nodes['c1']['log_prob'] -
                    guide_trace.nodes['c1']['log_prob'])
    expected_b1 += model_trace.nodes['obs']['log_prob']

    expected_c1 = (model_trace.nodes['c1']['log_prob'] -
                   guide_trace.nodes['c1']['log_prob'])
    for b in b_names:
        expected_c1 += model_trace.nodes[b]['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
    for k in dc:
        assert guide_trace.nodes[k]['log_prob'].size() == dc[k].size()
        assert_equal(dc[k], dc_brute[k])
def _predictive(model, posterior_samples, num_samples, return_sites=(),
                return_trace=False, parallel=False, model_args=(), model_kwargs=None):
    """
    Draw predictive samples from ``model`` conditioned on ``posterior_samples``.

    Returns a dict mapping site name to a tensor of shape
    ``(num_samples, ...)`` (or the raw trace when ``return_trace=True``).
    When ``parallel=False`` sampling is delegated to ``_predictive_sequential``.
    """
    # BUG FIX (anti-pattern): the original used a mutable default
    # ``model_kwargs={}``; use a None sentinel instead.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    vectorize = pyro.plate("_num_predictive_samples", num_samples,
                           dim=-max_plate_nesting - 1)
    model_trace = prune_subsample_sites(
        poutine.trace(model).get_trace(*model_args, **model_kwargs))

    # Reshape each posterior sample to broadcast against the model's plates.
    reshaped_samples = {}
    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples,) + (1,) *
                                (max_plate_nesting - len(sample_shape)) + sample_shape)
        reshaped_samples[name] = sample

    if return_trace:
        trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    # Decide which sites to return and their expected broadcast shapes.
    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = () — include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args,
                                      model_kwargs, num_samples,
                                      return_site_shapes, return_trace=False)

    trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            # non-tensor _RETURN values are passed through unchanged
            predictions[site] = value
            continue
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)
    return predictions
def test_compute_downstream_costs_iarange_in_irange(dim1):
    """Compare fast vs brute-force downstream costs for a guide with an
    iarange nested inside an irange."""
    guide_trace = poutine.trace(nested_model_guide,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1)
    model_trace = poutine.trace(poutine.replay(nested_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # Both branches (index 0 and 1) have the same cost structure:
    # c_i sees its own elbo term plus obs_i; b_i additionally sums
    # c_i's elbo term and obs_i over the inner dimension.
    expected = {}
    for idx in ('0', '1'):
        c, b, obs = 'c' + idx, 'b' + idx, 'obs' + idx
        expected[c] = (model_trace.nodes[c]['log_prob'] -
                       guide_trace.nodes[c]['log_prob'])
        expected[c] += model_trace.nodes[obs]['log_prob']
        expected[b] = (model_trace.nodes[b]['log_prob'] -
                       guide_trace.nodes[b]['log_prob'])
        expected[b] += (model_trace.nodes[c]['log_prob'] -
                        guide_trace.nodes[c]['log_prob']).sum()
        expected[b] += model_trace.nodes[obs]['log_prob'].sum()

    for site in ('c1', 'b1', 'c0', 'b0'):
        assert_equal(expected[site], dc[site], prec=1.0e-6)

    for k in dc:
        assert guide_trace.nodes[k]['log_prob'].size() == dc[k].size()
        assert_equal(dc[k], dc_brute[k])
def _guess_max_plate_nesting(self, model, guide, args, kwargs):
    """
    Guesses max_plate_nesting by running the (model, guide) pair once
    without enumeration. This optimistically assumes static model structure.
    """
    # Ignore validation to allow model-enumerated sites absent from the guide.
    with poutine.block():
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    sites = [site
             for trace in (model_trace, guide_trace)
             for site in trace.nodes.values()
             if site["type"] == "sample"]

    # Validate shapes now, since shape constraints will be weaker once
    # max_plate_nesting is changed from float('inf') to some finite value.
    # Here we know the traces are not enumerated, but later we'll need to
    # allow broadcasting of dims to the left of max_plate_nesting.
    if is_validation_enabled():
        guide_trace.compute_log_prob()
        model_trace.compute_log_prob()
        for site in sites:
            check_site_shape(site, max_plate_nesting=float("inf"))

    # Deepest vectorized plate dim determines the nesting level.
    dims = [frame.dim
            for site in sites
            for frame in site["cond_indep_stack"]
            if frame.vectorized]
    self.max_plate_nesting = max((-dim for dim in dims), default=0)
    if self.vectorize_particles and self.num_particles > 1:
        # Reserve one extra dim for the vectorized particle plate.
        self.max_plate_nesting += 1
    logging.info("Guessed max_plate_nesting = {}".format(self.max_plate_nesting))
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide, yielding
    (weight, model_trace, guide_trace, log_r) tuples — one per particle,
    or one per enumerated trace per particle when ``enum_discrete`` is set.

    XXX support for automatically settings args/kwargs to volatile?
    """
    for i in range(self.num_particles):
        if self.enum_discrete:
            # This iterates over a bag of traces, for each particle.
            for scale, guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                model_trace = poutine.trace(
                    poutine.replay(model, guide_trace),
                    graph_type="flat").get_trace(*args, **kwargs)

                check_model_guide_match(model_trace, guide_trace)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                check_enum_discrete_can_run(model_trace, guide_trace)

                # NOTE(review): this branch uses batch_log_pdf (vs log_pdf
                # below) — presumably to retain per-batch granularity for
                # the enumerated case; confirm against the Trace API.
                log_r = model_trace.batch_log_pdf() - guide_trace.batch_log_pdf()
                # each enumerated trace carries its enumeration weight `scale`
                weight = scale / self.num_particles
                yield weight, model_trace, guide_trace, log_r
            continue

        # Non-enumerated path: one plain trace per particle.
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, guide_trace)).get_trace(*args, **kwargs)

        check_model_guide_match(model_trace, guide_trace)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        log_r = model_trace.log_pdf() - guide_trace.log_pdf()
        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace, log_r
def _get_traces(self, model, guide, *args, **kwargs): """ runs the guide and runs the model against the guide with the result packaged as a trace generator """ # enable parallel enumeration guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting) for i in range(self.num_particles): for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs): model_trace = poutine.trace(poutine.replay(model, trace=guide_trace), graph_type="flat").get_trace(*args, **kwargs) if is_validation_enabled(): check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting) guide_trace = prune_subsample_sites(guide_trace) model_trace = prune_subsample_sites(model_trace) if is_validation_enabled(): check_traceenum_requirements(model_trace, guide_trace) model_trace.compute_log_prob() guide_trace.compute_score_parts() if is_validation_enabled(): for site in model_trace.nodes.values(): if site["type"] == "sample": check_site_shape(site, self.max_iarange_nesting) any_enumerated = False for site in guide_trace.nodes.values(): if site["type"] == "sample": check_site_shape(site, self.max_iarange_nesting) if site["infer"].get("enumerate"): any_enumerated = True if self.strict_enumeration_warning and not any_enumerated: warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. ' 'If you want to enumerate sites, you need to @config_enumerate or set ' 'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? ' 'If you do not want to enumerate, consider using Trace_ELBO instead.') yield model_trace, guide_trace
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """Internal helper for ``infer_discrete``: enumerate, then sample from the trace."""
    # Record an enumerated execution of the model, hidden from outer handlers.
    with poutine.block(), EnumMessenger(first_available_dim):
        traced = poutine.trace(model).get_trace(*args, **kwargs)
    traced = prune_subsample_sites(traced)
    # Precompute packed log-densities for the downstream sampling pass.
    traced.compute_log_prob()
    traced.pack_tensors()
    return _sample_posterior_from_trace(model, traced, temperature, *args, **kwargs)
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once so its (vectorized) plate structure can be inspected."""
    initialized_model = InitMessenger(self.init)(self.model)
    trace_fn = poutine.block(poutine.trace(initialized_model).get_trace)
    self.prototype_trace = prune_subsample_sites(trace_fn(*args, **kwargs))

    # Record every vectorized plate frame; sequential plates are unsupported.
    for _, site in self.prototype_trace.iter_stochastic_nodes():
        for frame in site["cond_indep_stack"]:
            if not frame.vectorized:
                raise NotImplementedError("EasyGuide does not support sequential pyro.plate")
            self.frames[frame.name] = frame
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Runs the guide and runs the model against the guide, yielding
    ``(weight, model_trace, guide_trace)`` triples — one per particle —
    as dense tracegraphs.

    XXX support for automatically settings args/kwargs to volatile?
    """
    particle_weight = 1.0 / self.num_particles
    for _ in range(self.num_particles):
        if self.enum_discrete:
            raise NotImplementedError("https://github.com/uber/pyro/issues/220")

        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
        replayed_model = poutine.replay(model, guide_trace)
        model_trace = poutine.trace(replayed_model, graph_type="dense").get_trace(*args, **kwargs)

        # Verify correspondence before discarding subsample sites.
        check_model_guide_match(model_trace, guide_trace)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        yield particle_weight, model_trace, guide_trace
def _get_matched_trace(model_trace, guide, args, kwargs):
    """
    Run ``guide`` replayed against ``model_trace`` and return its pruned trace.

    Sites marked ``was_observed`` in the model trace are flagged as observed
    and passed to the guide through ``kwargs["observations"]``.
    """
    observations = {}
    for name in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site = model_trace.nodes[name]
        if "was_observed" in site["infer"]:
            site["is_observed"] = True
            observations[name] = site["value"]
    kwargs["observations"] = observations

    replayed_guide = poutine.replay(guide, model_trace)
    guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)

    check_model_guide_match(model_trace, guide_trace)
    return prune_subsample_sites(guide_trace)
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and record every vectorized plate it uses."""
    traced = poutine.block(poutine.trace(self.model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(traced)
    if self.master is not None:
        # Let the parent guide verify this prototype is compatible.
        self.master()._check_prototype(self.prototype_trace)

    self._plates = {}
    for _, site in self.prototype_trace.iter_stochastic_nodes():
        for frame in site["cond_indep_stack"]:
            if not frame.vectorized:
                raise NotImplementedError("AutoGuideList does not support sequential pyro.plate")
            self._plates[frame.name] = frame
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and record every vectorized iarange it uses."""
    traced = poutine.block(poutine.trace(self.model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(traced)
    if self.master is not None:
        # Let the parent guide verify this prototype is compatible.
        self.master()._check_prototype(self.prototype_trace)

    self._iaranges = {}
    for site in self.prototype_trace.nodes.values():
        # Only latent sample sites contribute iarange frames.
        if site["type"] == "sample" and not site["is_observed"]:
            for frame in site["cond_indep_stack"]:
                if not frame.vectorized:
                    raise NotImplementedError("AutoGuideList does not support pyro.irange")
                self._iaranges[frame.name] = frame
def test_predictive(auto_class):
    """Test that a Predictive wrapping an autoguide survives jit trace/serialize."""
    N, D = 3, 2

    # Linear layer whose weight/bias are latent (PyroSample) sites.
    class RandomLinear(nn.Linear, PyroModule):
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features)
            self.weight = PyroSample(
                dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.bias = PyroSample(
                dist.Normal(0., 10.).expand([out_features]).to_event(1))

    class LinearRegression(PyroModule):
        def __init__(self):
            super().__init__()
            self.linear = RandomLinear(D, 1)

        def forward(self, x, y=None):
            mean = self.linear(x).squeeze(-1)
            sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
            with pyro.plate('plate', N):
                return pyro.sample('obs', dist.Normal(mean, sigma), obs=y)

    x, y = torch.randn(N, D), torch.randn(N)
    model = LinearRegression()
    guide = auto_class(model)
    # XXX: Record `y` as observed in the prototype trace
    # Is there a better pattern to follow?
    guide(x, y=y)

    # Test predictive module
    model_trace = poutine.trace(model).get_trace(x, y=None)
    predictive = Predictive(model, guide=guide, num_samples=10)
    pyro.set_rng_seed(0)
    samples = predictive(x)
    # Every stochastic site of the model should appear in the samples.
    for site in prune_subsample_sites(model_trace).stochastic_nodes:
        assert site in samples

    # Trace-compile the predictive module and round-trip it through
    # torch.jit serialization; TracerWarnings from sampling are expected.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        traced_predictive = torch.jit.trace_module(predictive, {"call": (x,)})
    f = io.BytesIO()
    torch.jit.save(traced_predictive, f)
    f.seek(0)
    predictive_deser = torch.jit.load(f)
    pyro.set_rng_seed(0)
    samples_deser = predictive_deser.call(x)
    # Note that the site values are different in the serialized guide
    assert len(samples) == len(samples_deser)
def forward(self, *args, **kwargs):
    """Draw ``num_samples`` posterior-predictive samples per stochastic site,
    returning a dict of stacked tensors keyed by site name."""
    collected = {}
    for draw in range(self.num_samples):
        if draw % 50 == 0:
            # Coarse progress indicator for long sampling runs.
            print("done with {}".format(draw))
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        replayed_model = poutine.replay(self.model, guide_trace)
        model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        for site in prune_subsample_sites(model_trace).stochastic_nodes:
            collected.setdefault(site, []).append(model_trace.nodes[site]["value"])
    # Stack each site's samples along a new leading dimension.
    return {site: torch.stack(values) for site, values in collected.items()}
def _get_matched_cross_trace(self, model_x_trace, model_z_trace, *args, **kwargs):
    """
    Replay the guide against ``model_z_trace`` while feeding it the observed
    values and latent "truth" recorded in ``model_x_trace``; return the
    pruned guide trace after checking it against both model traces.
    """
    observations = {}
    truth = {}
    for name in itertools.chain(model_x_trace.stochastic_nodes,
                                model_x_trace.observation_nodes):
        x_site = model_x_trace.nodes[name]
        if "was_observed" in x_site["infer"]:
            # Mark the site observed in both traces so the match checks agree.
            x_site["is_observed"] = True
            model_z_trace.nodes[name]["is_observed"] = True
            observations[name] = x_site["value"]
        else:
            truth[name] = x_site["value"]
    kwargs["observations"] = observations
    kwargs["truth"] = truth

    replayed_guide = poutine.replay(self.guide, model_z_trace)
    guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)

    check_model_guide_match(model_x_trace, guide_trace)
    check_model_guide_match(model_z_trace, guide_trace)
    return prune_subsample_sites(guide_trace)
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23,
                                                       include_triple, include_z1):
    """Check _compute_downstream_costs against a brute-force reference on a
    large model/guide pair with optional plates and site orderings."""
    guide_trace = poutine.trace(big_model_guide,
                                graph_type="dense").get_trace(include_obs=False, include_inner_1=include_inner_1,
                                                              include_single=include_single, flip_c23=flip_c23,
                                                              include_triple=include_triple, include_z1=include_z1)
    model_trace = poutine.trace(poutine.replay(big_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, include_inner_1=include_inner_1,
                                                              include_single=include_single, flip_c23=flip_c23,
                                                              include_triple=include_triple, include_z1=include_z1)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    # Fast implementation vs brute-force reference must agree on node sets.
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
                                                                     non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # Expected downstream-node sets for the full (untripled, unflipped) model.
    expected_nodes_full_model = {'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'},
                                 'd2': {'obs', 'd2'}, 'd1': {'obs', 'd1', 'd2'},
                                 'c3': {'d2', 'obs', 'd1', 'c3'},
                                 'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
                                 'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
                                 'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
                                 'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}}
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert(dc_nodes == expected_nodes_full_model)

    # b1's cost: own term plus everything downstream, with inner plate dims summed out.
    expected_b1 = (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
    expected_b1 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
    expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False)
    if include_inner_1:
        expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        # b0 sits outside the plates, so all downstream terms are fully summed.
        expected_b0 = (model_trace.nodes['b0']['log_prob'] - guide_trace.nodes['b0']['log_prob'])
        expected_b0 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum()
        expected_b0 += model_trace.nodes['obs']['log_prob'].sum()
        if include_inner_1:
            expected_b0 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5,)

    if include_inner_1:
        expected_c3 = (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob'])
        expected_c3 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c3 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c2 = (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob'])
        expected_c2 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c2 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])

        # The c2/c3 cross terms depend on the sampling order of c2 and c3.
        if flip_c23:
            expected_c3 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
            expected_c2 += model_trace.nodes['c3']['log_prob']
        else:
            expected_c2 += model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']
            # NOTE(review): c2's own (model - guide) term also appears in the
            # initialization of expected_c2 above, so this looks like a double
            # count in the unflipped branch — confirm against
            # _compute_downstream_costs and the original (unflattened) source.
            expected_c2 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
        expected_c1 += expected_c3

        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']
    expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']
    expected_d1 += model_trace.nodes['obs']['log_prob']

    expected_d2 = (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob'])
    expected_d2 += model_trace.nodes['obs']['log_prob']

    if include_triple:
        # z0 inherits a1's entire downstream cost plus its own term.
        expected_z0 = dc['a1'] + model_trace.nodes['z0']['log_prob'] - guide_trace.nodes['z0']['log_prob']
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)

    assert dc['b1'].size() == (2,)
    assert dc['d2'].size() == (4, 2)

    # Every cost tensor keeps its site's log_prob shape and matches brute force.
    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
def _predictive(model, posterior_samples, num_samples, return_sites=None,
                return_trace=False, parallel=False, model_args=(), model_kwargs=None):
    """
    Run ``model`` forward conditioned on ``posterior_samples`` and collect
    values at sample sites.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: samples keyed by site name; each value's
        leading dim is the sample dimension.
    :param int num_samples: number of predictive samples to draw.
    :param return_sites: sites to return; by default, sites not present in
        ``posterior_samples`` (plus observation sites) are returned.
    :param bool return_trace: if True, return a single vectorized trace instead
        of a dict of samples.
    :param bool parallel: vectorize over samples with an outermost plate; when
        False, fall back to sequential sampling.
    :param tuple model_args: positional args for ``model``.
    :param dict model_kwargs: keyword args for ``model``.
    :return: dict of predictions keyed by site name, or a trace.
    """
    # Fix: avoid a shared mutable default argument ({}); normalize None here.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*model_args, **model_kwargs))

    # Insert singleton dims so each sample broadcasts against the model plates.
    reshaped_samples = {}
    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape)
        reshaped_samples[name] = sample

    def _vectorized_fn(fn):
        """
        Wraps a callable inside an outermost :class:`~pyro.plate` to parallelize
        sampling from the posterior predictive.

        :param fn: arbitrary callable containing Pyro primitives.
        :return: wrapped callable.
        """
        def wrapped_fn(*args, **kwargs):
            with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1):
                return fn(*args, **kwargs)
        return wrapped_fn

    if return_trace:
        trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    # Decide which sites to return and their expected (vectorized) shapes.
    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site_shape = (num_samples,) + model_trace.nodes[site]['value'].shape
        if isinstance(return_sites, (list, tuple, set)):
            if site in return_sites:
                return_site_shapes[site] = site_shape
        else:
            # Default: return sites not already provided in posterior_samples.
            if (return_sites is not None) or (site not in reshaped_samples):
                return_site_shapes[site] = site_shape

    # handle _RETURN site
    if isinstance(return_sites, (list, tuple, set)) and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        # Non-tensor return values get shape None (returned as-is below).
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
                                      return_site_shapes.keys(), return_trace=False)

    trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        if value.numel() < reduce((lambda x, y: x * y), shape):
            # Value was broadcast inside the plate; expand to the full shape.
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)
    return predictions
def predictive(model, posterior_samples, *args, **kwargs):
    """
    .. warning::
        This function is deprecated and will be removed in a future release.
        Use the :class:`~pyro.infer.predictive.Predictive` class instead.

    Run model by sampling latent parameters from `posterior_samples`, and return
    values at sample sites from the forward run. By default, only sites not contained in
    `posterior_samples` are returned. This can be modified by changing the `return_sites`
    keyword argument.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param args: model arguments.
    :param kwargs: model kwargs; and other keyword arguments (see below).

    :Keyword Arguments:
        * **num_samples** (``int``) - number of samples to draw from the predictive distribution.
          This argument has no effect if ``posterior_samples`` is non-empty, in which case, the
          leading dimension size of samples in ``posterior_samples`` is used.
        * **return_sites** (``list``) - sites to return; by default only sample sites not present
          in `posterior_samples` are returned.
        * **return_trace** (``bool``) - whether to return the full trace. Note that this is vectorized
          over `num_samples`.
        * **parallel** (``bool``) - predict in parallel by wrapping the existing model
          in an outermost `plate` messenger. Note that this requires that the model has
          all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`.

    :return: dict of samples from the predictive distribution, or a single vectorized
        `trace` (if `return_trace=True`).
    """
    warnings.warn('The `mcmc.predictive` function is deprecated and will be removed in '
                  'a future release. Use the `pyro.infer.Predictive` class instead.',
                  FutureWarning)
    # Predictive-specific options are popped so the rest of kwargs goes to the model.
    num_samples = kwargs.pop('num_samples', None)
    return_sites = kwargs.pop('return_sites', None)
    return_trace = kwargs.pop('return_trace', False)
    parallel = kwargs.pop('parallel', False)

    max_plate_nesting = _guess_max_plate_nesting(model, args, kwargs)
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*args, **kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        batch_size, sample_shape = sample.shape[0], sample.shape[1:]
        # The leading dim of the provided samples always wins over num_samples.
        if num_samples is None:
            num_samples = batch_size
        elif num_samples != batch_size:
            warnings.warn("Sample's leading dimension size {} is different from the "
                          "provided {} num_samples argument. Defaulting to {}."
                          .format(batch_size, num_samples, batch_size), UserWarning)
            num_samples = batch_size
        # Insert singleton dims so samples broadcast against the model plates.
        sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape)
        reshaped_samples[name] = sample

    if num_samples is None:
        raise ValueError("No sample sites in model to infer `num_samples`.")

    # Decide which sites to return and their expected (vectorized) shapes.
    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site_shape = (num_samples,) + model_trace.nodes[site]['value'].shape
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        else:
            # Default: return sites not already provided in posterior_samples.
            if site not in reshaped_samples:
                return_site_shapes[site] = site_shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, args, kwargs, num_samples,
                                      return_site_shapes.keys(), return_trace)

    def _vectorized_fn(fn):
        """
        Wraps a callable inside an outermost :class:`~pyro.plate` to parallelize
        sampling from the posterior predictive.

        :param fn: arbitrary callable containing Pyro primitives.
        :return: wrapped callable.
        """
        def wrapped_fn(*args, **kwargs):
            with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1):
                return fn(*args, **kwargs)
        return wrapped_fn

    trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\
        .get_trace(*args, **kwargs)

    if return_trace:
        return trace

    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if value.numel() < reduce((lambda x, y: x * y), shape):
            # Value was broadcast inside the plate; expand to the full shape.
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    # For internal use by infer_discrete.
    # Enumerates all discrete latents, runs a forward-backward pass over the
    # packed log-prob tensors, then replays the model against the collapsed
    # (sampled/maximized) trace.

    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()  # maps plate-ordinal -> list of packed log-prob factors
    sum_dims = set()           # symbols to be summed (contracted) out
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            sum_dims.update(log_prob._pyro_dims)
            # Plate dims are batch dims, not enumeration dims: keep them.
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        # (queries resolved during this ordinal's backward pass get this ordinal).
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)
                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)
            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)