コード例 #1
0
    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.

        Both traces are produced by ``get_importance_trace`` with the
        ``"flat"`` graph type and ``self.max_plate_nesting``.

        When validation is enabled, the traces are checked with
        ``check_traceenum_requirements`` and ``_check_tmc_elbo_constraint``.
        If ``self.strict_enumeration_warning`` is set and no sample site in
        either trace has a truthy ``"enumerate"`` entry in its ``infer`` dict,
        a warning suggesting ``Trace_ELBO`` is emitted.

        Before returning, the guide trace is packed with ``pack_tensors()``.
        The model trace is then packed using the guide's ``plate_to_symbol``
        mapping.

        :returns: a ``(model_trace, guide_trace)`` pair.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs)

        if is_validation_enabled():
            check_traceenum_requirements(model_trace, guide_trace)
            _check_tmc_elbo_constraint(model_trace, guide_trace)

            # Only sample sites are considered; the site name itself is not needed.
            has_enumerated_sites = any(site["infer"].get("enumerate")
                                       for trace in (guide_trace, model_trace)
                                       for site in trace.nodes.values()
                                       if site["type"] == "sample")

            if self.strict_enumeration_warning and not has_enumerated_sites:
                warnings.warn(
                    'TraceEnum_ELBO found no sample sites configured for enumeration. '
                    'If you want to enumerate sites, you need to @config_enumerate or set '
                    'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}. '
                    'If you do not want to enumerate, consider using Trace_ELBO instead.'
                )

        # The guide is packed before the model because the model's packing reads
        # guide_trace.plate_to_symbol.
        # NOTE(review): presumably that mapping is populated by the guide's
        # pack_tensors() call, so this order matters — confirm.
        guide_trace.pack_tensors()
        model_trace.pack_tensors(guide_trace.plate_to_symbol)
        return model_trace, guide_trace
コード例 #2
0
ファイル: infer.py プロジェクト: dimarkov/pybefit
    def get_log_evidence_per_subject(self,
                                     *args,
                                     num_particles=100,
                                     max_plate_nesting=1,
                                     **kwargs):
        """Return subject specific log model evidence.

        The estimate is averaged over ``num_particles`` draws. For each draw,
        ``get_importance_trace`` runs ``self.guide`` and replays ``self.model``
        against it. The draw contributes:

        - the model's ``log_prob`` at every sample site whose name starts
          with ``'obs'`` or equals ``'locs'``, plus
        - minus the guide's ``log_prob`` at its ``'locs'`` sample site.

        All values are detached before accumulation. This is a Monte Carlo,
        ELBO-style estimate built from the named sites only.
        NOTE(review): any latent site other than ``'locs'`` is ignored — confirm
        that this is intended for the models used with this method.

        :param args: positional arguments forwarded to the model and guide.
        :param num_particles: number of trace draws to average over; must be >= 1.
        :param max_plate_nesting: forwarded to ``get_importance_trace``.
        :param kwargs: keyword arguments forwarded to the model and guide.
        :returns: the accumulator started as ``zeros(self.runs)``, divided by
            ``num_particles``. Assumes each site's ``log_prob`` broadcasts into
            that shape — TODO confirm.
        :raises ValueError: if ``num_particles`` is less than 1 (previously this
            silently returned NaN from a 0/0 division).
        """
        if num_particles < 1:
            raise ValueError(
                "num_particles must be a positive integer, got {}".format(num_particles))

        model = self.model
        guide = self.guide

        elbo = zeros(self.runs)
        for _ in range(num_particles):
            model_trace, guide_trace = get_importance_trace(
                'flat', max_plate_nesting, model, guide, args, kwargs)

            # Only sample sites carry 'log_prob'. Other node types (e.g. params)
            # would raise KeyError if their names happened to match.
            for site in model_trace.nodes.values():
                if site['type'] != 'sample':
                    continue
                name = site['name']
                if name.startswith('obs') or name == 'locs':
                    elbo += site['log_prob'].detach()

            for site in guide_trace.nodes.values():
                if site['type'] == 'sample' and site['name'] == 'locs':
                    elbo -= site['log_prob'].detach()

        return elbo / num_particles
コード例 #3
0
 def _get_trace(self, model, guide, args, kwargs):
     """
     Run the guide once and replay the model against it.

     Both traces come from ``get_importance_trace`` using the ``"flat"`` graph
     type and ``self.max_plate_nesting``. When validation is enabled, the
     guide trace is additionally passed to ``check_if_enumerated``.

     :returns: a ``(model_trace, guide_trace)`` pair.
     """
     plate_depth = self.max_plate_nesting
     m_trace, g_trace = get_importance_trace(
         "flat", plate_depth, model, guide, args, kwargs)
     if not is_validation_enabled():
         return m_trace, g_trace
     check_if_enumerated(g_trace)
     return m_trace, g_trace