Example No. 1
 def _pyro_param(self, msg):
     """
     Overrides the `pyro.param` call with samples sampled from the
     distribution specified in the prior. The prior can be a
     pyro.distributions object or a dict of distributions keyed
     on the param names. If the param name does not match the
     name the keys in the prior, that param name is unchanged.
     """
     name = msg["name"]
     param_name = params.user_param_name(name)
     if isinstance(self.prior, dict):
         # prior is a dict of distributions
         if param_name in self.prior:
             msg["fn"] = self.prior[param_name]
             msg["args"] = msg["args"][1:]
             if isinstance(msg["fn"], Distribution):
                 msg["args"] = ()
                 msg["kwargs"] = {}
                 msg["infer"] = {}
             if is_validation_enabled():
                 self._param_hits.add(param_name)
         else:
             if is_validation_enabled():
                 self._param_misses.add(param_name)
             return None
     elif isinstance(self.prior, Distribution):
         # prior is a distribution
         msg["fn"] = self.prior
         msg["args"] = ()
         msg["kwargs"] = {}
         msg["infer"] = {}
     elif callable(self.prior):
         if not isinstance(self.prior, Distribution):
             # prior is a stochastic function; block this sample site
             msg["stop"] = True
         msg["fn"] = self.prior
         msg["args"] = msg["args"][1:]
     else:
         # otherwise leave as is
         return None
     msg["type"] = "sample"
     if name in self._samples_cache:
         # Multiple pyro.param statements with the same
         # name. Block the site and fix the value.
         msg["value"] = self._samples_cache[name]["value"]
         msg["is_observed"] = True
         msg["stop"] = True
     else:
         self._samples_cache[name] = msg
         msg["is_observed"] = False
     return self._pyro_sample(msg)
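
This method matches Pyro's LiftMessenger._pyro_param. A minimal usage sketch via the poutine.lift handler, assuming a standard Pyro install; the model and the param name "loc" below are illustrative:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def model():
        loc = pyro.param("loc", torch.tensor(0.0))
        return pyro.sample("obs", dist.Normal(loc, 1.0))

    # Lift the "loc" param to a latent variable sampled from the prior;
    # param names absent from the dict are left as ordinary params.
    lifted_model = poutine.lift(model, prior={"loc": dist.Normal(0.0, 10.0)})
    lifted_model()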
Example No. 2
 def compute_score_parts(self):
     """
     Compute the batched local score parts at each site of the trace.
     Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
     Each ``log_prob_sum`` is a scalar.
     All computations are memoized.
     """
     for name, site in self.nodes.items():
         if site["type"] == "sample" and "score_parts" not in site:
             # Note that ScoreParts overloads the multiplication operator
             # to correctly scale each of its three parts.
             try:
                 value = site["fn"].score_parts(site["value"],
                                                *site["args"],
                                                **site["kwargs"])
             except ValueError:
                 _, exc_value, traceback = sys.exc_info()
                 shapes = self.format_shapes(last_site=site["name"])
                 raise ValueError(
                     "Error while computing score_parts at site '{}':\n{}\n{}"
                     .format(name, exc_value,
                             shapes)).with_traceback(traceback)
             site["unscaled_log_prob"] = value.log_prob
             value = value.scale_and_mask(site["scale"], site["mask"])
             site["score_parts"] = value
             site["log_prob"] = value.log_prob
             site["log_prob_sum"] = value.log_prob.sum()
             if is_validation_enabled():
                 warn_if_nan(site["log_prob_sum"],
                             "log_prob_sum at site '{}'".format(name))
                 warn_if_inf(site["log_prob_sum"],
                             "log_prob_sum at site '{}'".format(name),
                             allow_neginf=True)
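
A hedged usage sketch: score parts are what gradient estimators such as TraceGraph_ELBO consume. The guide below is illustrative.

    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def guide():
        pyro.sample("z", dist.Normal(0.0, 1.0))

    tr = poutine.trace(guide).get_trace()
    tr.compute_score_parts()
    # Each sample site now carries a memoized ScoreParts value
    # alongside log_prob and log_prob_sum.
    print(tr.nodes["z"]["score_parts"])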
Example No. 3
 def compute_log_prob(self, site_filter=lambda name, site: True):
     """
     Compute the site-wise log probabilities of the trace.
     Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
     Each ``log_prob_sum`` is a scalar.
     Both computations are memoized.
     """
     for name, site in self.nodes.items():
         if site["type"] == "sample" and site_filter(name, site):
             if "log_prob" not in site:
                 try:
                     log_p = site["fn"].log_prob(site["value"],
                                                 *site["args"],
                                                 **site["kwargs"])
                 except ValueError:
                     _, exc_value, traceback = sys.exc_info()
                     shapes = self.format_shapes(last_site=site["name"])
                     raise ValueError(
                         "Error while computing log_prob at site '{}':\n{}\n{}"
                         .format(name, exc_value,
                                 shapes)).with_traceback(traceback)
                 site["unscaled_log_prob"] = log_p
                 log_p = scale_and_mask(log_p, site["scale"], site["mask"])
                 site["log_prob"] = log_p
                 site["log_prob_sum"] = log_p.sum()
                 if is_validation_enabled():
                     warn_if_nan(site["log_prob_sum"],
                                 "log_prob_sum at site '{}'".format(name))
                     warn_if_inf(site["log_prob_sum"],
                                 "log_prob_sum at site '{}'".format(name),
                                 allow_neginf=True)
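
A short sketch of calling this on a recorded trace, with an illustrative site_filter that keeps only observed sites:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def model():
        z = pyro.sample("z", dist.Normal(0.0, 1.0))
        pyro.sample("x", dist.Normal(z, 1.0), obs=torch.tensor(0.5))

    tr = poutine.trace(model).get_trace()
    # Restrict the computation to observed sites via the filter callback.
    tr.compute_log_prob(site_filter=lambda name, site: site["is_observed"])
    print(tr.nodes["x"]["log_prob_sum"])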
Example No. 4
    def log_prob_sum(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        The computation of ``log_prob_sum`` is memoized.

        :returns: total log probability.
        :rtype: torch.Tensor
        """
        result = 0.0
        for name, site in self.nodes.items():
            if site["type"] == "sample" and site_filter(name, site):
                if "log_prob_sum" in site:
                    log_p = site["log_prob_sum"]
                else:
                    try:
                        log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
                    except ValueError:
                        _, exc_value, traceback = sys.exc_info()
                        shapes = self.format_shapes(last_site=site["name"])
                        raise ValueError("Error while computing log_prob_sum at site '{}':\n{}\n"
                                         .format(name, exc_value, shapes)).with_traceback(traceback)
                    log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
                    site["log_prob_sum"] = log_p
                    if is_validation_enabled():
                        warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name))
                        warn_if_inf(log_p, "log_prob_sum at site '{}'".format(name), allow_neginf=True)
                result = result + log_p
        return result
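
A minimal sketch of the typical call, in the same style as above; the scalar it returns is the joint log density over all (filtered) sample sites:

    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def model():
        pyro.sample("z", dist.Normal(0.0, 1.0))

    tr = poutine.trace(model).get_trace()
    total = tr.log_prob_sum()  # scalar torch.Tensor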
Example No. 5
    def _postprocess_message(self, msg):
        if msg["type"] in ("param", "subsample") and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                assert event_dim >= 0
                dim = self.dim - event_dim
                shape = msg["value"].shape
                if len(shape) >= -dim and shape[dim] != 1:
                    if is_validation_enabled() and shape[dim] != self.size:
                        if msg["type"] == "param":
                            statement = "pyro.param({}, ..., event_dim={})".format(
                                msg["name"], event_dim)
                        else:
                            statement = "pyro.subsample(..., event_dim={})".format(
                                event_dim)
                        raise ValueError(
                            "Inside pyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                            .format(self.name, self.size, self.dim, statement,
                                    shape))
                    # Subsample parameters with known batch semantics.
                    if self.subsample_size < self.size:
                        value = msg["value"]
                        new_value = value.index_select(dim, self._indices)
                        if msg["type"] == "param":
                            if hasattr(value, "_pyro_unconstrained_param"):
                                param = value._pyro_unconstrained_param
                            else:
                                param = value.unconstrained()

                            if not hasattr(param, "_pyro_subsample"):
                                param._pyro_subsample = {}

                            param._pyro_subsample[dim] = self._indices
                            new_value._pyro_unconstrained_param = param
                        msg["value"] = new_value
Example No. 6
 def __init__(self, scale):
     if isinstance(scale, torch.Tensor):
         if is_validation_enabled() and not (scale > 0).all():
             raise ValueError("Expected scale > 0 but got {}. ".format(scale) +
                              "Consider using poutine.mask() instead of poutine.scale().")
     elif not (scale > 0):
         raise ValueError("Expected scale > 0 but got {}".format(scale))
     super().__init__()
     self.scale = scale
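
This looks like ScaleMessenger.__init__; a brief usage sketch via the poutine.scale handler (the factor 10.0 is illustrative):

    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    def model():
        # log_prob at this site is multiplied by 10 during inference,
        # e.g. to compensate for data subsampling.
        with poutine.scale(scale=10.0):
            pyro.sample("x", dist.Normal(0.0, 1.0))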
Example No. 7
 def __exit__(self, *args, **kwargs):
     self._samples_cache = {}
     if is_validation_enabled() and isinstance(self.prior, dict):
         extra = set(self.prior) - self._param_hits
         if extra:
             warnings.warn("pyro.module prior did not find params ['{}']. "
                           "Did you instead mean one of ['{}']?".format(
                               "', '".join(extra),
                               "', '".join(self._param_misses)))
     return super(LiftMessenger, self).__exit__(*args, **kwargs)
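
A sketch of the scenario this exit hook guards against, assuming validation is enabled; the misspelled prior key "lco" is deliberate:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine

    pyro.enable_validation(True)

    def model():
        loc = pyro.param("loc", torch.tensor(0.0))
        return pyro.sample("obs", dist.Normal(loc, 1.0))

    # Typo in the prior key: warns on exit that 'lco' was never matched
    # and suggests 'loc', which was encountered but missed.
    poutine.lift(model, prior={"lco": dist.Normal(0.0, 1.0)})()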
Example No. 8
 def _postprocess_message(self, msg):
     if msg["type"] == "param" and self.dim is not None:
         event_dim = msg["kwargs"].get("event_dim")
         if event_dim is not None:
             assert event_dim >= 0
             dim = self.dim - event_dim
             shape = msg["value"].shape
             if len(shape) >= -dim and shape[dim] != 1:
                 if is_validation_enabled() and shape[dim] != self.size:
                     raise ValueError(
                         "Inside pyro.plate({}, {}, dim={}) "
                         "invalid shape of pyro.param({}, ..., event_dim={}): {}"
                         .format(self.name, self.size, self.dim,
                                 msg["name"], event_dim, shape))
                 # Subsample parameters with known batch semantics.
                 if self.subsample_size < self.size:
                     msg["value"] = msg["value"].index_select(
                         dim, self._indices)
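
A small sketch of the validation branch firing, assuming validation is enabled: a param whose size along the plate dim is neither 1 nor the plate size is rejected.

    import torch
    import pyro

    pyro.enable_validation(True)
    try:
        with pyro.plate("data", size=100, dim=-1):
            # shape[-1] is 50: neither 1 (broadcastable) nor the plate size 100.
            pyro.param("bad", torch.zeros(50), event_dim=0)
    except ValueError as e:
        print(e)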
Example No. 9
 def __enter__(self):
     self._samples_cache = {}
     if is_validation_enabled() and isinstance(self.prior, dict):
         self._param_hits = set()
         self._param_misses = set()
     return super(LiftMessenger, self).__enter__()
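
Like all Pyro messengers, LiftMessenger can also be used directly as a context manager, which is what triggers this __enter__; a hedged sketch, assuming the pyro.poutine.lift_messenger module path:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.poutine.lift_messenger import LiftMessenger

    def model():
        loc = pyro.param("loc", torch.tensor(0.0))
        return pyro.sample("obs", dist.Normal(loc, 1.0))

    with LiftMessenger(prior={"loc": dist.Normal(0.0, 1.0)}):
        model()  # "loc" is drawn from the prior inside this block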