def _get_matched_trace(self, model_trace, *args, **kwargs):
    """
    Build a guide trace whose sample/observe values are taken from a model trace.

    :param model_trace: a trace recorded from the model
    :type model_trace: pyro.poutine.trace_struct.Trace
    :returns: guide trace with sampled values matched to model_trace
    :rtype: pyro.poutine.trace_struct.Trace

    Every stochastic or observation site of ``model_trace`` whose ``infer``
    dict contains ``"was_observed"`` is flagged ``is_observed`` in place, and
    its value is handed to the guide through an ``observations`` keyword
    argument. ``args`` and ``kwargs`` are otherwise passed to the guide.
    """
    observations = {}
    kwargs["observations"] = observations
    candidate_names = itertools.chain(model_trace.stochastic_nodes,
                                      model_trace.observation_nodes)
    for name in candidate_names:
        site = model_trace.nodes[name]
        if "was_observed" not in site["infer"]:
            continue
        site["is_observed"] = True
        observations[name] = site["value"]

    replayed_guide = poutine.replay(self.guide, model_trace)
    guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)
    check_model_guide_match(model_trace, guide_trace)
    return prune_subsample_sites(guide_trace)
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Generate weighted (model_trace, guide_trace) pairs built as dense trace graphs.

    For each of ``self.num_particles`` particles the guide is traced, the
    model is replayed against the guide's samples, the two traces are checked
    for consistency, and both are pruned of subsample sites.

    :raises NotImplementedError: when ``self.enum_discrete`` is set
        (see https://github.com/uber/pyro/issues/220).
    :yields: ``(weight, model_trace, guide_trace)`` where ``weight`` is
        ``1.0 / self.num_particles``.
    """
    for _ in range(self.num_particles):
        # Checked per particle, so the error only surfaces once the
        # generator is advanced (and never when num_particles is 0).
        if self.enum_discrete:
            raise NotImplementedError(
                "https://github.com/uber/pyro/issues/220")
        raw_guide = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
        replayed_model = poutine.replay(model, raw_guide)
        raw_model = poutine.trace(replayed_model, graph_type="dense").get_trace(*args, **kwargs)
        check_model_guide_match(raw_model, raw_guide)
        pruned_guide = prune_subsample_sites(raw_guide)
        pruned_model = prune_subsample_sites(raw_model)
        yield 1.0 / self.num_particles, pruned_model, pruned_guide
def _get_traces(self, model, guide, *args, **kwargs):
    """
    runs the guide and runs the model against the guide with
    the result packaged as a trace generator

    :yields: one ``(model_trace, guide_trace)`` pair per particle
        (``self.num_particles`` in total), after subsample-site pruning.
        Score parts are computed on the guide trace; log-probs are NOT
        computed on the model trace here (see the note below).
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                if site["type"] == "sample" and site["infer"].get("enumerate")]
            if enumerated_sites:
                # Three separate lines: header, comma-separated site names, advice.
                # The comma after the header is required; without it the
                # adjacent literals concatenate and the header becomes the
                # join separator.
                warnings.warn('\n'.join([
                    'Trace_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        # NOTE: model_trace.compute_log_prob() is deliberately not called here.
        # The original author's note (in Catalan) says it does not work because
        # there are no decoder parameters. TODO: confirm before re-enabling.
        guide_trace.compute_score_parts()
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_iarange_nesting)

        yield model_trace, guide_trace
def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs):
    """
    Return a single trace from the guide, and the model that is run against it.

    :param graph_type: trace graph type forwarded to ``poutine.trace``.
    :param max_plate_nesting: plate-nesting bound used for validation checks.
    :param model: model callable, replayed against the guide trace.
    :param guide: guide callable.
    :returns: ``(model_trace, guide_trace)``, both pruned of subsample sites.
        The model trace has log-probs computed and the guide trace has score
        parts computed.
    """
    def _record(fn):
        # Trace ``fn`` with the shared graph type and call arguments.
        return poutine.trace(fn, graph_type=graph_type).get_trace(*args, **kwargs)

    guide_trace = _record(guide)
    model_trace = _record(poutine.replay(model, trace=guide_trace))
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        # Model sites are checked first, then guide sites.
        for trace in (model_trace, guide_trace):
            for site in trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, max_plate_nesting)

    return model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    runs the guide and runs the model against the guide with
    the result packaged as a tracegraph generator

    :yields: ``(weight, model_trace, guide_trace)`` per particle, where both
        traces are dense graphs pruned of subsample sites and ``weight`` is
        ``1.0 / self.num_particles``.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(
            *args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                    graph_type="dense").get_trace(
            *args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [
                name for name, site in guide_trace.nodes.items()
                if site["type"] == "sample" and site["infer"].get("enumerate")
            ]
            if enumerated_sites:
                # Three separate lines: header, comma-separated site names,
                # advice. The comma after the header prevents implicit string
                # concatenation from turning the header into the join separator.
                warnings.warn('\n'.join([
                    'TraceGraph_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'
                ]))
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    runs the guide and runs the model against the guide with
    the result packaged as a tracegraph generator

    :yields: ``(weight, model_trace, guide_trace)`` per particle, where both
        traces are dense graphs pruned of subsample sites and ``weight`` is
        ``1.0 / self.num_particles``.
    """
    for _ in range(self.num_particles):
        guide_trace = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                    graph_type="dense").get_trace(*args, **kwargs)
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace)
            enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                if site["type"] == "sample" and site["infer"].get("enumerate")]
            if enumerated_sites:
                # Header, site list and advice are separate list items; the
                # comma after the header keeps ', '.join(...) joining only the
                # site names.
                warnings.warn('\n'.join([
                    'TraceGraph_ELBO found sample sites configured for enumeration:',
                    ', '.join(enumerated_sites),
                    'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        weight = 1.0 / self.num_particles
        yield weight, model_trace, guide_trace
def _get_traces(self, model, guide, args, kwargs):
    """
    Trace the guide and a posterior-predictive run of the model.

    :param args: positional arguments (a sequence) for both model and guide.
    :param kwargs: keyword arguments (a mapping) for both model and guide.
    :returns: ``(guide_trace, model_trace)`` -- note the order. Both are pruned
        of subsample sites. If ``self.prior_scale > 0``, log-probs are
        computed for the model's non-observed sites.
    :raises ValueError: under validation, if a guide sample site, or an
        observed model sample site, has no ``has_rsample`` support.
    """
    if self.max_plate_nesting == float("inf"):
        with validation_enabled(False):  # Avoid calling .log_prob() when undefined.
            # TODO factor this out as a stand-alone helper.
            ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
    # Vectorize over particles with a plate at dim -self.max_plate_nesting.
    vectorize = pyro.plate("num_particles_vectorized", self.num_particles,
                           dim=-self.max_plate_nesting)

    # Trace the guide as in ELBO.
    with poutine.trace() as tr, vectorize:
        guide(*args, **kwargs)
    guide_trace = tr.trace

    # Trace the model, drawing posterior predictive samples.
    with poutine.trace() as tr, poutine.uncondition():
        with poutine.replay(trace=guide_trace), vectorize:
            model(*args, **kwargs)
    model_trace = tr.trace
    # Re-flag as observed every sample site whose infer dict records
    # "was_observed" (presumably set by uncondition()).
    for site in model_trace.nodes.values():
        if site["type"] == "sample" and site["infer"].get("was_observed", False):
            site["is_observed"] = True
    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    if is_validation_enabled():
        for site in guide_trace.nodes.values():
            if site["type"] == "sample":
                warn_if_nan(site["value"], site["name"])
                if not getattr(site["fn"], "has_rsample", False):
                    raise ValueError("EnergyDistance requires fully reparametrized guides")
        # Bind each model node to ``site``. The previous loop variable name
        # made this body re-inspect the last guide site instead.
        for site in model_trace.nodes.values():
            if site["type"] == "sample":
                if site["is_observed"]:
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError("EnergyDistance requires reparametrized likelihoods")

    if self.prior_scale > 0:
        model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"])
        if is_validation_enabled():
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if not site["is_observed"]:
                        check_site_shape(site, self.max_plate_nesting)

    return guide_trace, model_trace
def _get_matched_trace(model_trace, guide, args, kwargs):
    """
    Run ``guide`` replayed against ``model_trace`` and return its pruned trace.

    :param model_trace: a trace recorded from the model. Sites flagged
        ``was_observed`` in their ``infer`` dict are marked ``is_observed``
        in place.
    :param guide: the guide callable.
    :param args: positional arguments (a sequence) passed to the guide.
    :param kwargs: keyword arguments (a mapping) passed to the guide. The
        caller's mapping is not modified. A copy gains an ``"observations"``
        entry mapping each observed site name to its value in ``model_trace``.
    :returns: the guide trace, pruned of subsample sites.
    """
    # Work on a copy so the caller's kwargs dict is not mutated by the
    # injected "observations" entry.
    kwargs = dict(kwargs)
    observations = {}
    kwargs["observations"] = observations
    for node in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site = model_trace.nodes[node]
        if "was_observed" in site["infer"]:
            site["is_observed"] = True
            observations[node] = site["value"]
    guide_trace = poutine.trace(poutine.replay(guide, model_trace)).get_trace(
        *args, **kwargs)
    check_model_guide_match(model_trace, guide_trace)
    guide_trace = prune_subsample_sites(guide_trace)
    return guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Yield matched ``(model_trace, guide_trace)`` pairs for enumeration-based ELBO.

    The guide is wrapped with ``poutine.enum`` to enable parallel
    enumeration. For each of ``self.num_particles`` particles, every guide
    trace produced by ``iter_discrete_traces`` is paired with a model trace
    replayed against it. Both traces are pruned of subsample sites. The model
    trace gets log-probs and the guide trace gets score parts. Under
    validation, extra consistency and shape checks run, and a warning is
    issued (if ``self.strict_enumeration_warning``) when no guide site is
    configured for enumeration.
    """
    enum_guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)
    for _ in range(self.num_particles):
        for raw_guide_trace in iter_discrete_traces("flat", enum_guide, *args, **kwargs):
            replayed_model = poutine.replay(model, trace=raw_guide_trace)
            raw_model_trace = poutine.trace(replayed_model, graph_type="flat").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(raw_model_trace, raw_guide_trace, self.max_iarange_nesting)
            guide_trace = prune_subsample_sites(raw_guide_trace)
            model_trace = prune_subsample_sites(raw_model_trace)
            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                model_samples = (s for s in model_trace.nodes.values() if s["type"] == "sample")
                for s in model_samples:
                    check_site_shape(s, self.max_iarange_nesting)
                saw_enumeration = False
                for s in guide_trace.nodes.values():
                    if s["type"] != "sample":
                        continue
                    check_site_shape(s, self.max_iarange_nesting)
                    if s["infer"].get("enumerate"):
                        saw_enumeration = True
                if self.strict_enumeration_warning and not saw_enumeration:
                    warnings.warn(
                        'TraceEnum_ELBO found no sample sites configured for enumeration. '
                        'If you want to enumerate sites, you need to @config_enumerate or set '
                        'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                        'If you do not want to enumerate, consider using Trace_ELBO instead.'
                    )

            yield model_trace, guide_trace
def _get_matched_cross_trace(self, model_x_trace, model_z_trace, *args, **kwargs):
    """
    Run the guide replayed against ``model_z_trace`` and return its pruned trace.

    The stochastic and observation sites of ``model_x_trace`` are split in
    two:

    * sites whose ``infer`` dict contains ``"was_observed"`` are flagged
      ``is_observed`` in BOTH model traces, and their ``model_x_trace`` values
      go to the guide as the ``observations`` keyword argument;
    * every other site's ``model_x_trace`` value goes to the guide as
      ``truth``.

    The guide trace is checked against both model traces before subsample
    sites are pruned.
    """
    observations = {}
    truth = {}
    kwargs["observations"] = observations
    kwargs["truth"] = truth
    for name in itertools.chain(model_x_trace.stochastic_nodes,
                                model_x_trace.observation_nodes):
        x_site = model_x_trace.nodes[name]
        if "was_observed" in x_site["infer"]:
            x_site["is_observed"] = True
            model_z_trace.nodes[name]["is_observed"] = True
            observations[name] = x_site["value"]
        else:
            truth[name] = x_site["value"]

    replayed_guide = poutine.replay(self.guide, model_z_trace)
    guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)
    for reference_trace in (model_x_trace, model_z_trace):
        check_model_guide_match(reference_trace, guide_trace)
    return prune_subsample_sites(guide_trace)
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Yield ``(weight, model_trace, guide_trace, log_r)`` tuples for each particle.

    When ``self.enum_discrete`` is set, each particle expands into every
    discrete guide trace yielded by ``iter_discrete_traces``. Each such trace
    is weighted ``scale / self.num_particles``, and ``log_r`` is
    ``model_trace.batch_log_pdf() - guide_trace.batch_log_pdf()``.

    Otherwise a single guide trace is drawn per particle, weighted
    ``1.0 / self.num_particles``, and
    ``log_r = model_trace.log_pdf() - guide_trace.log_pdf()``.

    In both modes the model is replayed against the guide trace, the pair is
    checked for consistency, and both traces are pruned of subsample sites
    before ``log_r`` is computed.
    """
    for _ in range(self.num_particles):
        if not self.enum_discrete:
            sampled_guide = poutine.trace(guide).get_trace(*args, **kwargs)
            replayed_model = poutine.replay(model, sampled_guide)
            sampled_model = poutine.trace(replayed_model).get_trace(*args, **kwargs)
            check_model_guide_match(sampled_model, sampled_guide)
            sampled_guide = prune_subsample_sites(sampled_guide)
            sampled_model = prune_subsample_sites(sampled_model)
            log_ratio = sampled_model.log_pdf() - sampled_guide.log_pdf()
            yield 1.0 / self.num_particles, sampled_model, sampled_guide, log_ratio
        else:
            # A bag of enumerated traces per particle, each carrying its own scale.
            for scale, enum_guide in iter_discrete_traces("flat", guide, *args, **kwargs):
                replayed_model = poutine.replay(model, enum_guide)
                enum_model = poutine.trace(replayed_model, graph_type="flat").get_trace(*args, **kwargs)
                check_model_guide_match(enum_model, enum_guide)
                enum_guide = prune_subsample_sites(enum_guide)
                enum_model = prune_subsample_sites(enum_model)
                check_enum_discrete_can_run(enum_model, enum_guide)
                log_ratio = enum_model.batch_log_pdf() - enum_guide.batch_log_pdf()
                yield scale / self.num_particles, enum_model, enum_guide, log_ratio
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Generate ``(model_trace, guide_trace)`` pairs with parallel enumeration enabled.

    For every particle, each discrete guide trace from ``iter_discrete_traces``
    (run on the ``poutine.enum``-wrapped guide) is matched with a model trace
    replayed against it. Both are pruned of subsample sites. Log-probs are
    computed on the model trace and score parts on the guide trace. When
    validation is enabled, shapes and TraceEnum requirements are checked, and
    a warning is raised (if ``self.strict_enumeration_warning``) when no guide
    sample site requests enumeration.
    """
    # Enable parallel enumeration in the guide.
    enumerating_guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

    for _ in range(self.num_particles):
        discrete_guide_traces = iter_discrete_traces("flat", enumerating_guide, *args, **kwargs)
        for guide_trace in discrete_guide_traces:
            model_trace = poutine.trace(
                poutine.replay(model, trace=guide_trace),
                graph_type="flat",
            ).get_trace(*args, **kwargs)
            validating = is_validation_enabled()
            if validating:
                check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)
            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for node in model_trace.nodes.values():
                    if node["type"] == "sample":
                        check_site_shape(node, self.max_iarange_nesting)
                enumeration_found = False
                for node in guide_trace.nodes.values():
                    if node["type"] == "sample":
                        check_site_shape(node, self.max_iarange_nesting)
                        if node["infer"].get("enumerate"):
                            enumeration_found = True
                if self.strict_enumeration_warning and not enumeration_found:
                    warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                  'If you want to enumerate sites, you need to @config_enumerate or set '
                                  'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                  'If you do not want to enumerate, consider using Trace_ELBO instead.')

            yield model_trace, guide_trace
def test_nested_autoguide(Elbo):
    """Train a nested AutoGuideList on a small PyroModule and check its structure."""
    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.x_loc = nn.Parameter(torch.tensor(1.))
            self.x_scale = PyroParam(torch.tensor(0.1), constraints.positive)

        def forward(self):
            pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))
            with pyro.plate("plate", 2):
                pyro.sample("y", dist.Normal(2., 0.1))

    model = Model()
    guide = nested_auto_guide_callable(model)

    # Every nested component must expose a ``master`` callable that returns the
    # top-level guide.
    for _, component in guide.named_modules():
        if component is guide:
            continue
        master_ok = component.master is not None and component.master() is guide
        assert master_ok, "master ref wrong for {}".format(component._pyro_name)

    svi = SVI(model, guide, Adam({'lr': 0.005}), Elbo(strict_enumeration_warning=False))
    for _ in range(20):
        svi.step()

    guide_trace = poutine.trace(guide).get_trace()
    model_trace = poutine.trace(model).get_trace()
    check_model_guide_match(model_trace, guide_trace)

    allowed_param_prefixes = ("AutoGuideList.0", "AutoGuideList.1.z")
    assert all(name.startswith(allowed_param_prefixes) for name in guide_trace.param_nodes)

    stochastic_nodes = set(guide_trace.stochastic_nodes)
    assert "x" in stochastic_nodes
    assert "y" in stochastic_nodes
    # The only auxiliary latent site is the one sampled for the IAF.
    assert "_AutoGuideList.1.z_latent" in stochastic_nodes
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Produce one weighted pair of dense trace graphs per particle.

    Each iteration traces the guide, replays the model on the guide's
    samples, verifies the two traces agree, and strips subsample sites.

    :raises NotImplementedError: if ``self.enum_discrete`` is enabled
        (tracked in https://github.com/uber/pyro/issues/220).
    :yields: ``(weight, model_trace, guide_trace)`` with
        ``weight == 1.0 / self.num_particles``.
    """
    particle = 0
    while particle < self.num_particles:
        particle += 1
        if self.enum_discrete:
            raise NotImplementedError("https://github.com/uber/pyro/issues/220")

        g_trace = poutine.trace(guide, graph_type="dense").get_trace(*args, **kwargs)
        m_trace = poutine.trace(
            poutine.replay(model, g_trace), graph_type="dense"
        ).get_trace(*args, **kwargs)
        check_model_guide_match(m_trace, g_trace)

        g_trace = prune_subsample_sites(g_trace)
        m_trace = prune_subsample_sites(m_trace)
        # Computed inside the loop so num_particles == 0 never divides.
        weight = 1.0 / self.num_particles
        yield weight, m_trace, g_trace