コード例 #1
0
ファイル: trace_struct.py プロジェクト: zippeurfou/pyro
def _warn_if_nan(name, value):
    """Emit a warning when a site's ``log_prob_sum`` is NaN or ``+inf``.

    :param name: site name, interpolated into the warning message.
    :param value: either a tensor, which is reduced with ``.item()`` (so it
        must hold a single element), or a plain number.
    """
    scalar = value.item() if torch.is_tensor(value) else value
    if torch_isnan(scalar):
        warnings.warn("Encountered NAN log_prob_sum at site '{}'".format(name))
    # Only a positive infinity is reported; -inf passes silently here.
    is_pos_inf = torch_isinf(scalar) and scalar > 0
    if is_pos_inf:
        warnings.warn("Encountered +inf log_prob_sum at site '{}'".format(name))
コード例 #2
0
ファイル: eig.py プロジェクト: pyro-ppl/pyro
 def __call__(self, inputs, s, dim=0, keepdim=False):
     """Update the running moving average and return the log-transformed inputs.

     The original contract states the result is :code:`inputs.log()`. The
     value is produced by ``_ewma_log_fn(inputs, ewma)``, which is defined
     outside this method. It presumably also uses the updated average for
     gradients (TODO confirm against its definition).

     :param inputs: new values blended into the moving average.
     :param s: shift for this step. The stored average is multiplied by
         ``exp(self.s - s)`` before blending, and ``s`` is stored (detached)
         for the next call.
     :param dim: not used by this body.
     :param keepdim: not used by this body.
     """
     self.n += 1
     previous = self.ewma
     # A NaN/inf running average cannot be blended, so restart from the inputs.
     if not (torch_isnan(previous) or torch_isinf(previous)):
         alpha = self.alpha
         decayed = alpha**self.n
         # Bias-correction denominator shared by both terms.
         normaliser = 1 - decayed
         fresh_term = inputs * (1.0 - alpha) / normaliser
         carried_term = (torch.exp(self.s - s) * previous *
                         (alpha - decayed) / normaliser)
         ewma = fresh_term + carried_term
     else:
         ewma = inputs
     self.ewma = ewma.detach()
     self.s = s.detach()
     return _ewma_log_fn(inputs, ewma)
コード例 #3
0
 def initial_trace(self):
     """
     Find a valid trace to initiate the MCMC sampler. This is also used as a
     prototype trace to inter-convert between Pyro's trace object and dict
     object used by the integrator.

     The model is traced with ``*self._args, **self._kwargs`` up to
     ``self._max_tries_initial_trace`` times. The first trace whose log
     probability (from ``self._compute_trace_log_prob``) is neither NaN nor
     infinite is cached on ``self._initial_trace`` and returned.

     :returns: the cached trace, or a newly found valid trace.
     :raises ValueError: if no valid trace is found within the allotted tries.
     """
     # Reuse a previously found trace (original truthiness check preserved).
     if self._initial_trace:
         return self._initial_trace
     for _ in range(self._max_tries_initial_trace):
         # Unpack the stored arguments on every attempt. Passing the tuple and
         # dict themselves would call the model with the wrong signature.
         trace = poutine.trace(self.model).get_trace(*self._args,
                                                     **self._kwargs)
         trace_log_prob_sum = self._compute_trace_log_prob(trace)
         if not (torch_isnan(trace_log_prob_sum)
                 or torch_isinf(trace_log_prob_sum)):
             self._initial_trace = trace
             return trace
     raise ValueError(
         "Model specification seems incorrect - cannot find a valid trace.")
コード例 #4
0
ファイル: hmc.py プロジェクト: zippeurfou/pyro
 def _validate_trace(self, trace):
     """Reject a trace whose summed log-probability is NaN or infinite.

     :param trace: object providing ``log_prob_sum()``.
     :raises ValueError: if ``torch_isnan`` or ``torch_isinf`` flags the sum.
     """
     total_log_prob = trace.log_prob_sum()
     invalid = torch_isnan(total_log_prob) or torch_isinf(total_log_prob)
     if invalid:
         raise ValueError(
             "Model specification incorrect - trace log pdf is NaN or Inf.")
コード例 #5
0
ファイル: hmc.py プロジェクト: lewisKit/pyro
 def _validate_trace(self, trace):
     """Check that ``trace.log_prob_sum()`` is finite and not NaN.

     :param trace: object providing ``log_prob_sum()``.
     :raises ValueError: if ``torch_isnan`` or ``torch_isinf`` flags the sum.
     """
     total = trace.log_prob_sum()
     # Guard clause: a finite, non-NaN sum means the trace is acceptable.
     if not torch_isnan(total) and not torch_isinf(total):
         return
     raise ValueError("Model specification incorrect - trace log pdf is NaN or Inf.")