Exemple #1
0
 def _prepare_site(self, msg):
     """
     Flag ``param`` sites that the prior is going to override.

     A site counts as overridden when ``self.prior`` is a dict containing an
     entry for the site's user-facing param name, or when ``self.prior`` is
     callable. For such sites of type ``"param"``, ``msg["done"]`` is set to
     True so the original param statement is not re-executed in the stack
     and not added to the param store. The same ``msg`` is returned.
     """
     user_name = params.user_param_name(msg["name"])
     prior = self.prior
     overridden = (isinstance(prior, dict) and user_name in prior.keys()) or callable(prior)
     if overridden and msg["type"] == "param":
         msg["done"] = True
     return msg
Exemple #2
0
 def _pyro_param(self, msg):
     """
     Turn a ``pyro.param`` site into a ``sample`` site drawn from the prior.

     ``self.prior`` may be a single Distribution, a dict of priors keyed on
     user param names, or another callable (treated as a stochastic
     function). If the prior is a dict without an entry for this param's
     name, or the prior is neither a dict nor callable, the site is left
     unchanged and None is returned. When the same site name is lifted more
     than once, later occurrences reuse the cached value of the first one and
     are marked observed and stopped.
     """
     site_name = msg["name"]
     user_name = params.user_param_name(site_name)
     prior = self.prior

     if isinstance(prior, dict):
         if user_name not in prior.keys():
             # No prior registered under this name: record the miss and bail out.
             if is_validation_enabled():
                 self._param_misses.add(user_name)
             return None
         chosen = prior[user_name]
         msg["fn"] = chosen
         # Drop the first positional argument (presumably the param name --
         # NOTE(review): confirm against the param message schema).
         msg["args"] = msg["args"][1:]
         if isinstance(chosen, Distribution):
             msg["args"] = ()
             msg["kwargs"] = {}
             msg["infer"] = {}
         if is_validation_enabled():
             self._param_hits.add(user_name)
     elif isinstance(prior, Distribution):
         msg["fn"] = prior
         msg["args"] = ()
         msg["kwargs"] = {}
         msg["infer"] = {}
     elif callable(prior):
         # Distributions were caught by the previous branch, so this is a
         # stochastic function: stop (block) its inner sample sites.
         msg["stop"] = True
         msg["fn"] = prior
         msg["args"] = msg["args"][1:]
     else:
         # Prior is neither a dict nor callable: leave the site as is.
         return None

     msg["type"] = "sample"
     cache = self._samples_cache
     if site_name not in cache:
         # First occurrence: cache the message object itself (its "value" is
         # presumably filled in later by _pyro_sample -- confirm).
         cache[site_name] = msg
         msg["is_observed"] = False
     else:
         # Repeated pyro.param with this name: pin the earlier value and
         # block the site.
         msg["value"] = cache[site_name]["value"]
         msg["is_observed"] = True
         msg["stop"] = True
     return self._pyro_sample(msg)
Exemple #3
0
    def _get_optim_args(self, param):
        """
        Return the optimizer keyword arguments to use for ``param``.

        If ``self.pt_optim_args`` is callable, it is invoked once for this
        parameter with ``(module_name, stripped_param_name)``, for example
        ``('mymodule', 'bias')``, and must return a dict of optimizer
        defaults. Otherwise ``self.pt_optim_args`` is returned unchanged.

        :param param: the parameter being optimized; its name is looked up via
            ``pyro.get_param_store().param_name``.
        :returns: the dict of optimizer arguments for ``param``.
        :raises TypeError: if the user-provided callable does not return a dict.
        """
        if not callable(self.pt_optim_args):
            return self.pt_optim_args

        # Resolve the param's full name in the param store, then split it into
        # its module name and the user-facing param name.
        param_name = pyro.get_param_store().param_name(param)
        module_name = module_from_param_with_module_name(param_name)
        stripped_param_name = user_param_name(param_name)

        # Invoke the user-provided callable.
        opt_dict = self.pt_optim_args(module_name, stripped_param_name)

        # Validate explicitly instead of using ``assert``, which is stripped
        # when Python runs with -O and would let a non-dict reach the optimizer.
        if not isinstance(opt_dict, dict):
            raise TypeError(
                "per-param optim arg must return defaults dictionary, got {}".format(
                    type(opt_dict).__name__))
        return opt_dict
Exemple #4
0
 def _pyro_param(self, msg):
     """
     Lift a ``pyro.param`` site into a ``sample`` site drawn from the prior.

     ``self.prior`` may be a single Distribution, a dict of priors keyed on
     user param names, or another callable (treated as a stochastic
     function). If the prior is a dict without an entry for this param's
     name, or the prior is neither a dict nor callable, the site is handed
     unchanged to the parent class's ``_pyro_param``.
     """
     user_name = params.user_param_name(msg["name"])
     prior = self.prior

     if isinstance(prior, dict):
         if user_name not in prior.keys():
             # No prior registered for this name: treat it as a normal param.
             return super(LiftPoutine, self)._pyro_param(msg)
         chosen = prior[user_name]
         msg["fn"] = chosen
         if isinstance(chosen, Distribution):
             msg["args"] = ()
             msg["kwargs"] = {}
             msg["baseline"] = {}
     elif isinstance(prior, Distribution):
         msg["fn"] = prior
         msg["args"] = ()
         msg["kwargs"] = {}
         msg["baseline"] = {}
     elif callable(prior):
         # Distributions were caught by the previous branch, so this is a
         # stochastic function: stop (block) its inner sample sites.
         msg["stop"] = True
         msg["fn"] = prior
     else:
         # Prior is neither a dict nor callable: defer to the parent handler.
         return super(LiftPoutine, self)._pyro_param(msg)

     # Re-type the site as an unobserved sample that still needs to run.
     msg["type"] = "sample"
     msg["done"] = False
     msg["is_observed"] = False
     return self._pyro_sample(msg)