Example #1
0
    def _get_trace(self, model, guide, args, kwargs):
        """
        Run the guide once and replay the model against it.

        Traces are produced by ``get_importance_trace`` (called with
        ``"flat"`` and ``self.max_plate_nesting``). Before returning, the
        guide trace is packed, then the model trace is packed using the guide
        trace's ``plate_to_symbol`` mapping.

        When validation is enabled the pair is passed through
        ``check_traceenum_requirements`` and ``_check_tmc_elbo_constraint``;
        additionally, if ``self.strict_enumeration_warning`` is set, a warning
        is emitted when no sample site in either trace has a truthy
        ``infer["enumerate"]``.

        :returns: ``(model_trace, guide_trace)``
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs)

        if is_validation_enabled():
            check_traceenum_requirements(model_trace, guide_trace)
            _check_tmc_elbo_constraint(model_trace, guide_trace)

            # Lazily walk the sample sites of the guide, then of the model.
            sample_nodes = (node
                            for tr in (guide_trace, model_trace)
                            for node in tr.nodes.values()
                            if node["type"] == "sample")
            enumerates_something = any(node["infer"].get("enumerate")
                                       for node in sample_nodes)

            if self.strict_enumeration_warning and not enumerates_something:
                # NOTE(review): the message names TraceEnum_ELBO even though this
                # path also runs _check_tmc_elbo_constraint — confirm the wording.
                warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                              'If you want to enumerate sites, you need to @config_enumerate or set '
                              'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                              'If you do not want to enumerate, consider using Trace_ELBO instead.')

        # The guide must be packed first: its plate symbols drive model packing.
        guide_trace.pack_tensors()
        model_trace.pack_tensors(guide_trace.plate_to_symbol)
        return model_trace, guide_trace
Example #2
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generate ``(model_trace, guide_trace)`` pairs.

        The guide is wrapped with ``poutine.enum`` (first available dim
        ``self.max_iarange_nesting``) to enable parallel enumeration. For each
        of ``self.num_particles`` particles, every guide trace produced by
        ``iter_discrete_traces("flat", ...)`` is replayed through the model.
        Both traces then have subsample sites pruned; the model trace gets
        ``compute_log_prob()`` and the guide trace ``compute_score_parts()``
        before the pair is yielded.

        When validation is enabled, the pair is checked with
        ``check_model_guide_match`` (before pruning),
        ``check_traceenum_requirements`` (after pruning) and
        ``check_site_shape`` on every sample site; a warning is emitted if
        ``self.strict_enumeration_warning`` is set and no guide sample site
        has a truthy ``infer["enumerate"]``.
        """
        enum_guide = poutine.enum(guide,
                                  first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            discrete_guide_traces = iter_discrete_traces("flat", enum_guide,
                                                         *args, **kwargs)
            for guide_trace in discrete_guide_traces:
                replayed_model = poutine.replay(model, trace=guide_trace)
                model_trace = poutine.trace(
                    replayed_model, graph_type="flat").get_trace(*args, **kwargs)

                # Structural check happens on the unpruned traces.
                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace,
                                            self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()

                if is_validation_enabled():
                    for node in model_trace.nodes.values():
                        if node["type"] != "sample":
                            continue
                        check_site_shape(node, self.max_iarange_nesting)

                    # Shape-check guide sites while noting whether any of
                    # them asks to be enumerated.
                    found_enumerated = False
                    for node in guide_trace.nodes.values():
                        if node["type"] != "sample":
                            continue
                        check_site_shape(node, self.max_iarange_nesting)
                        if node["infer"].get("enumerate"):
                            found_enumerated = True

                    if self.strict_enumeration_warning and not found_enumerated:
                        warnings.warn(
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )

                yield model_trace, guide_trace
Example #3
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Yield ``(model_trace, guide_trace)`` pairs, one per discrete guide trace per particle.

        The guide runs under ``poutine.enum`` (first available dim ``self.max_iarange_nesting``),
        and ``iter_discrete_traces("flat", ...)`` supplies guide traces which are each replayed
        through the model. After pruning subsample sites, the model trace gets
        ``compute_log_prob()`` and the guide trace ``compute_score_parts()``.

        With validation enabled: ``check_model_guide_match`` runs on the unpruned pair,
        ``check_traceenum_requirements`` on the pruned pair, ``check_site_shape`` on every
        sample site, and a warning is emitted when ``self.strict_enumeration_warning`` is set
        but no guide sample site has a truthy ``infer["enumerate"]``.
        """
        enum_guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", enum_guide, *args, **kwargs):
                replayed_model = poutine.replay(model, trace=guide_trace)
                model_trace = poutine.trace(replayed_model, graph_type="flat").get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()

                if is_validation_enabled():
                    for node in model_trace.nodes.values():
                        if node["type"] != "sample":
                            continue
                        check_site_shape(node, self.max_iarange_nesting)

                    # Guide sites are shape-checked too; remember if any requests enumeration.
                    found_enumerated = False
                    for node in guide_trace.nodes.values():
                        if node["type"] != "sample":
                            continue
                        check_site_shape(node, self.max_iarange_nesting)
                        if node["infer"].get("enumerate"):
                            found_enumerated = True

                    if self.strict_enumeration_warning and not found_enumerated:
                        warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                      'If you want to enumerate sites, you need to @config_enumerate or set '
                                      'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                      'If you do not want to enumerate, consider using Trace_ELBO instead.')

                yield model_trace, guide_trace
Example #4
0
def _paired_sample_nodes(tr_pyro, tr_funsor):
    """
    Yield ``(name, pyro_node, funsor_node)`` for every sample site of ``tr_pyro``.

    Sites are visited in ``tr_pyro.nodes`` order; the Funsor node is looked up
    by the same name, so a name missing from ``tr_funsor`` raises ``KeyError``.
    """
    for name, pyro_node in tr_pyro.nodes.items():
        if pyro_node['type'] != 'sample':
            continue
        yield name, pyro_node, tr_funsor.nodes[name]


def _vectorized_frames(node):
    """Return the vectorized frames of ``node['cond_indep_stack']`` as a frozenset."""
    return frozenset(f for f in node['cond_indep_stack'] if f.vectorized)


def _pyro_dim_names(pyro_node, symbol_to_name):
    """Translate the ``_pyro_dims`` symbols of a site's packed log_prob into names."""
    return frozenset(symbol_to_name[d] for d in pyro_node['packed']['log_prob']._pyro_dims)


def _funsor_dim_names(funsor_node):
    """
    Return the input names of a site's Funsor log_prob with any ``'__PARTICLES'``
    substring removed, so they can be compared against Pyro dim names.
    """
    return frozenset(input_name.replace('__PARTICLES', '')
                     for input_name in funsor_node['funsor']['log_prob'].inputs)


def _check_traces(tr_pyro, tr_funsor):
    """
    Assert that a Pyro trace and a Funsor trace of the same model agree.

    Both traces must contain the same node names. Log-probs are computed on
    both traces and the Pyro trace is packed. How much is compared depends on
    the module-level ``_NAMED_TEST_STRENGTH``:

    * ``>= 1``: both traces pass ``check_traceenum_requirements``; for every
      sample site the packed Pyro log_prob matches the squeezed Funsor log_prob
      in element count and shape, and the vectorized ``cond_indep_stack``
      frames are equal.
    * ``>= 2``: the names behind each packed Pyro log_prob's dims equal the
      Funsor log_prob inputs (with ``'__PARTICLES'`` stripped).
    * ``>= 3``: unpacked Pyro ``log_prob`` and ``value`` shapes equal the Funsor
      ones exactly.

    When a level fails, a line per sample site is printed (mismatching sites
    are prefixed with ``==>``) and the original ``AssertionError`` is re-raised.
    """
    assert tr_pyro.nodes.keys() == tr_funsor.nodes.keys()
    tr_pyro.compute_log_prob()
    tr_funsor.compute_log_prob()
    tr_pyro.pack_tensors()

    # Map the symbols used in packed Pyro tensors back to readable names:
    # enumeration symbols of parallel-enumerated latent sites, plus plate symbols.
    symbol_to_name = {
        node['infer']['_enumerate_symbol']: name
        for name, node in tr_pyro.nodes.items()
        if node['type'] == 'sample' and not node['is_observed']
        and node['infer'].get('enumerate') == 'parallel'
    }
    symbol_to_name.update({
        symbol: name for name, symbol in tr_pyro.plate_to_symbol.items()})

    if _NAMED_TEST_STRENGTH >= 1:
        # coarser check: enumeration requirements satisfied
        check_traceenum_requirements(tr_pyro, Trace())
        check_traceenum_requirements(tr_funsor, Trace())
        try:
            # coarser check: number of elements, squeezed shapes, vectorized plates
            for _, pyro_node, funsor_node in _paired_sample_nodes(tr_pyro, tr_funsor):
                assert pyro_node['packed']['log_prob'].numel() == funsor_node['log_prob'].numel()
                assert pyro_node['packed']['log_prob'].shape == funsor_node['log_prob'].squeeze().shape
                assert _vectorized_frames(pyro_node) == _vectorized_frames(funsor_node)
        except AssertionError:
            for name, pyro_node, funsor_node in _paired_sample_nodes(tr_pyro, tr_funsor):
                pyro_packed_shape = pyro_node['packed']['log_prob'].shape
                funsor_packed_shape = funsor_node['log_prob'].squeeze().shape
                # Equal shapes imply equal element counts, so a shape check also
                # covers the numel assertion; plates are flagged separately.
                if pyro_packed_shape != funsor_packed_shape:
                    err_str = "==> (dep mismatch) {}".format(name)
                elif _vectorized_frames(pyro_node) != _vectorized_frames(funsor_node):
                    err_str = "==> (plate mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(pyro_packed_shape, funsor_packed_shape))
            raise

    if _NAMED_TEST_STRENGTH >= 2:
        try:
            # medium check: unordered packed shapes match
            for _, pyro_node, funsor_node in _paired_sample_nodes(tr_pyro, tr_funsor):
                assert _pyro_dim_names(pyro_node, symbol_to_name) == _funsor_dim_names(funsor_node)
        except AssertionError:
            for name, pyro_node, funsor_node in _paired_sample_nodes(tr_pyro, tr_funsor):
                pyro_names = _pyro_dim_names(pyro_node, symbol_to_name)
                # Use the same '__PARTICLES'-stripped names as the assertion so
                # passing sites are not reported as mismatches.
                funsor_names = _funsor_dim_names(funsor_node)
                if pyro_names != funsor_names:
                    err_str = "==> (packed mismatch) {}".format(name)
                else:
                    err_str = name
                print(err_str, "Pyro: {} vs Funsor: {}".format(sorted(pyro_names), sorted(funsor_names)))
            raise

    if _NAMED_TEST_STRENGTH >= 3:
        try:
            # finer check: exact match with unpacked Pyro shapes
            for _, pyro_node, funsor_node in _paired_sample_nodes(tr_pyro, tr_funsor):
                assert pyro_node['log_prob'].shape == funsor_node['log_prob'].shape
                assert pyro_node['value'].shape == funsor_node['value'].shape
        except AssertionError:
            for name, pyro_node, funsor_node in _paired_sample_nodes(tr_pyro, tr_funsor):
                pyro_shape = pyro_node['log_prob'].shape
                funsor_shape = funsor_node['log_prob'].shape
                pyro_value_shape = pyro_node['value'].shape
                funsor_value_shape = funsor_node['value'].shape
                if pyro_shape != funsor_shape:
                    err_str = "==> (unpacked mismatch) {}".format(name)
                    detail = "Pyro: {} vs Funsor: {}".format(pyro_shape, funsor_shape)
                elif pyro_value_shape != funsor_value_shape:
                    # log_prob shapes agree, so report the value shapes that differ
                    err_str = "==> (value mismatch) {}".format(name)
                    detail = "Pyro value: {} vs Funsor value: {}".format(pyro_value_shape, funsor_value_shape)
                else:
                    err_str = name
                    detail = "Pyro: {} vs Funsor: {}".format(pyro_shape, funsor_shape)
                print(err_str, detail)
            raise