def _get_trace(self, model, guide, args, kwargs):
    """
    Build one ``(model_trace, guide_trace)`` pair for TraceEnum_ELBO.

    The guide is run and the model is run against it through
    ``get_importance_trace`` using a ``"flat"`` graph type and
    ``self.max_plate_nesting``.

    When validation is enabled, the pair is checked with
    ``check_traceenum_requirements`` and ``_check_tmc_elbo_constraint``.
    If ``self.strict_enumeration_warning`` is set and no sample site in
    either trace has an ``"enumerate"`` entry in its ``infer`` dict, a
    warning is issued.

    Both traces are packed before returning. The model trace is packed
    with the guide trace's ``plate_to_symbol`` mapping.

    Returns:
        tuple: ``(model_trace, guide_trace)``.
    """
    traces = get_importance_trace(
        "flat", self.max_plate_nesting, model, guide, args, kwargs)
    model_trace, guide_trace = traces

    if is_validation_enabled():
        check_traceenum_requirements(model_trace, guide_trace)
        _check_tmc_elbo_constraint(model_trace, guide_trace)

        # Scan sample sites of the guide first, then the model, for any
        # enumeration config.
        sample_sites = (
            site
            for trace in (guide_trace, model_trace)
            for site in trace.nodes.values()
            if site["type"] == "sample"
        )
        has_enumerated_sites = any(
            site["infer"].get("enumerate") for site in sample_sites)

        if self.strict_enumeration_warning and not has_enumerated_sites:
            warnings.warn(
                'TraceEnum_ELBO found no sample sites configured for enumeration. '
                'If you want to enumerate sites, you need to @config_enumerate or set '
                'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                'If you do not want to enumerate, consider using Trace_ELBO instead.'
            )

    # The guide must be packed first: its plate_to_symbol mapping is passed
    # to the model trace's packing.
    guide_trace.pack_tensors()
    model_trace.pack_tensors(guide_trace.plate_to_symbol)
    return model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Lazily yield ``(model_trace, guide_trace)`` pairs.

    The guide is first wrapped in ``poutine.enum`` with
    ``first_available_dim=self.max_iarange_nesting``. For each of
    ``self.num_particles`` particles, every trace produced by
    ``iter_discrete_traces("flat", ...)`` over the wrapped guide becomes a
    guide trace. The model is traced (``graph_type="flat"``) while
    replaying that guide trace.

    Both traces have subsample sites pruned. The model trace gets
    ``compute_log_prob()`` and the guide trace gets
    ``compute_score_parts()`` before the pair is yielded.

    With validation enabled, model/guide matching, TraceEnum requirements
    and per-site shapes are checked. If ``self.strict_enumeration_warning``
    is set, a warning is issued when no guide sample site is marked for
    enumeration.
    """
    # enable parallel enumeration
    guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

    for _ in range(self.num_particles):
        for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
            replayed_model = poutine.replay(model, trace=guide_trace)
            model_tracer = poutine.trace(replayed_model, graph_type="flat")
            model_trace = model_tracer.get_trace(*args, **kwargs)

            # Model/guide matching is checked before subsample pruning.
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)
            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for node in model_trace.nodes.values():
                    if node["type"] != "sample":
                        continue
                    check_site_shape(node, self.max_iarange_nesting)

                # The enumeration scan shares the guide shape-check loop,
                # so the warning below is only reachable under validation.
                saw_enumeration = False
                for node in guide_trace.nodes.values():
                    if node["type"] != "sample":
                        continue
                    check_site_shape(node, self.max_iarange_nesting)
                    if node["infer"].get("enumerate"):
                        saw_enumeration = True

                if self.strict_enumeration_warning and not saw_enumeration:
                    warnings.warn(
                        'TraceEnum_ELBO found no sample sites configured for enumeration. '
                        'If you want to enumerate sites, you need to @config_enumerate or set '
                        'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                        'If you do not want to enumerate, consider using Trace_ELBO instead.'
                    )

            yield model_trace, guide_trace
def _get_traces(self, model, guide, *args, **kwargs):
    """
    Generator over ``(model_trace, guide_trace)`` pairs for ELBO estimation.

    The guide is wrapped by ``poutine.enum`` so that parallel enumeration
    starts at dim ``self.max_iarange_nesting``. For every particle
    (``self.num_particles`` in total) and every discrete guide trace from
    ``iter_discrete_traces("flat", ...)``, the model is replayed against the
    guide trace and traced with ``graph_type="flat"``.

    Subsample sites are pruned from both traces. Then ``compute_log_prob``
    (model) and ``compute_score_parts`` (guide) are called and the pair is
    yielded.

    Under validation, site matching, TraceEnum requirements and site shapes
    are checked. When ``self.strict_enumeration_warning`` is set and no
    guide sample site requests enumeration, a warning is emitted.
    """
    def sample_sites(trace):
        # Lazily select the nodes of ``trace`` whose type is "sample".
        return (site for site in trace.nodes.values() if site["type"] == "sample")

    # enable parallel enumeration
    enumerated_guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

    for _particle in range(self.num_particles):
        for guide_trace in iter_discrete_traces("flat", enumerated_guide, *args, **kwargs):
            model_trace = poutine.trace(
                poutine.replay(model, trace=guide_trace),
                graph_type="flat",
            ).get_trace(*args, **kwargs)

            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            if is_validation_enabled():
                check_traceenum_requirements(model_trace, guide_trace)

            model_trace.compute_log_prob()
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for site in sample_sites(model_trace):
                    check_site_shape(site, self.max_iarange_nesting)

                any_enumerated = False
                for site in sample_sites(guide_trace):
                    check_site_shape(site, self.max_iarange_nesting)
                    if site["infer"].get("enumerate"):
                        any_enumerated = True

                if self.strict_enumeration_warning and not any_enumerated:
                    warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                  'If you want to enumerate sites, you need to @config_enumerate or set '
                                  'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                  'If you do not want to enumerate, consider using Trace_ELBO instead.')

            yield model_trace, guide_trace
def _sample_node_pairs(tr_pyro, tr_funsor):
    """Yield ``(name, pyro_node, funsor_node)`` for each sample site of ``tr_pyro``.

    Raises KeyError if ``tr_funsor`` has no node with the same name.
    """
    for name, pyro_node in tr_pyro.nodes.items():
        if pyro_node['type'] != 'sample':
            continue
        yield name, pyro_node, tr_funsor.nodes[name]


def _vectorized_frames(node):
    """Return the vectorized frames of a site's ``cond_indep_stack`` as a frozenset."""
    return frozenset(f for f in node['cond_indep_stack'] if f.vectorized)


def _pyro_packed_names(pyro_node, symbol_to_name):
    """Map the packed Pyro log-prob's ``_pyro_dims`` symbols back to names."""
    return frozenset(symbol_to_name[d] for d in pyro_node['packed']['log_prob']._pyro_dims)


def _funsor_input_names(funsor_node):
    """Return the funsor log-prob's input names with ``'__PARTICLES'`` removed."""
    return frozenset(name.replace('__PARTICLES', '')
                     for name in funsor_node['funsor']['log_prob'].inputs)


def _check_traces(tr_pyro, tr_funsor):
    """
    Assert that a Pyro trace and a funsor-backed trace have matching sites.

    Strictness is controlled by the module-level ``_NAMED_TEST_STRENGTH``:

    * ``>= 1``: both traces satisfy ``check_traceenum_requirements``. For
      each sample site, the packed Pyro log-prob and the squeezed funsor
      log-prob have equal element counts and shapes, and both sites have
      the same vectorized ``cond_indep_stack`` frames.
    * ``>= 2``: the names of the packed Pyro log-prob dims equal the funsor
      log-prob input names, ignoring a ``'__PARTICLES'`` marker.
    * ``>= 3``: unpacked Pyro ``log_prob`` and ``value`` shapes equal the
      funsor ones exactly.

    Side effects: calls ``compute_log_prob()`` on both traces and
    ``pack_tensors()`` on ``tr_pyro``.

    On failure at any level, a per-site diagnostic is printed (mismatching
    sites are prefixed with ``==>``) and the AssertionError is re-raised.
    """
    assert tr_pyro.nodes.keys() == tr_funsor.nodes.keys()
    tr_pyro.compute_log_prob()
    tr_funsor.compute_log_prob()
    tr_pyro.pack_tensors()

    # Invert Pyro's packing symbols so packed dims can be compared by name.
    # Covers parallel-enumerated latent sample sites and plates.
    symbol_to_name = {
        node['infer']['_enumerate_symbol']: name
        for name, node in tr_pyro.nodes.items()
        if node['type'] == 'sample' and not node['is_observed']
        and node['infer'].get('enumerate') == 'parallel'
    }
    symbol_to_name.update({
        symbol: name for name, symbol in tr_pyro.plate_to_symbol.items()})

    if _NAMED_TEST_STRENGTH >= 1:
        # coarser check: enumeration requirements satisfied
        check_traceenum_requirements(tr_pyro, Trace())
        check_traceenum_requirements(tr_funsor, Trace())
        try:
            # coarser check: number of elements and squeezed shapes
            for _, pyro_node, funsor_node in _sample_node_pairs(tr_pyro, tr_funsor):
                assert pyro_node['packed']['log_prob'].numel() == funsor_node['log_prob'].numel()
                assert pyro_node['packed']['log_prob'].shape == funsor_node['log_prob'].squeeze().shape
                assert _vectorized_frames(pyro_node) == _vectorized_frames(funsor_node)
        except AssertionError:
            for name, pyro_node, funsor_node in _sample_node_pairs(tr_pyro, tr_funsor):
                pyro_packed_shape = pyro_node['packed']['log_prob'].shape
                funsor_packed_shape = funsor_node['log_prob'].squeeze().shape
                if pyro_packed_shape != funsor_packed_shape:
                    err_str = "==> (dep mismatch) {}".format(name)
                elif _vectorized_frames(pyro_node) != _vectorized_frames(funsor_node):
                    # Shapes agree, but the vectorized-frame assertion can
                    # still fail; flag it so the failing site is visible.
                    err_str = "==> (plate mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_packed_shape, funsor_packed_shape))
            raise

    if _NAMED_TEST_STRENGTH >= 2:
        try:
            # medium check: unordered packed shapes match
            for _, pyro_node, funsor_node in _sample_node_pairs(tr_pyro, tr_funsor):
                assert _pyro_packed_names(pyro_node, symbol_to_name) == _funsor_input_names(funsor_node)
        except AssertionError:
            for name, pyro_node, funsor_node in _sample_node_pairs(tr_pyro, tr_funsor):
                pyro_names = _pyro_packed_names(pyro_node, symbol_to_name)
                # Use the same '__PARTICLES'-stripped names the assertion
                # compared. Otherwise the mismatch marker is wrong.
                funsor_names = _funsor_input_names(funsor_node)
                if pyro_names != funsor_names:
                    err_str = "==> (packed mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(sorted(pyro_names), sorted(funsor_names)))
            raise

    if _NAMED_TEST_STRENGTH >= 3:
        try:
            # finer check: exact match with unpacked Pyro shapes
            for _, pyro_node, funsor_node in _sample_node_pairs(tr_pyro, tr_funsor):
                assert pyro_node['log_prob'].shape == funsor_node['log_prob'].shape
                assert pyro_node['value'].shape == funsor_node['value'].shape
        except AssertionError:
            for name, pyro_node, funsor_node in _sample_node_pairs(tr_pyro, tr_funsor):
                pyro_shape = pyro_node['log_prob'].shape
                funsor_shape = funsor_node['log_prob'].shape
                pyro_value_shape = pyro_node['value'].shape
                funsor_value_shape = funsor_node['value'].shape
                if pyro_shape != funsor_shape:
                    err_str = "==> (unpacked mismatch) {}".format(name)
                elif pyro_value_shape != funsor_value_shape:
                    # log_prob shapes agree, but the value-shape assertion
                    # failed for this site.
                    err_str = "==> (value mismatch: Pyro {} vs Funsor {}) {}".format(
                        pyro_value_shape, funsor_value_shape, name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_shape, funsor_shape))
            raise