def _warn_if_nan(name, value):
    """
    Emit a warning when a site's ``log_prob_sum`` is NaN or positive infinity.

    :param name: site name, interpolated into the warning message.
    :param value: the value to inspect. A tensor is reduced to a Python
        scalar with ``.item()`` before the checks.
    """
    # Work on a plain Python number so the comparisons below are scalar.
    scalar = value.item() if torch.is_tensor(value) else value
    if torch_isnan(scalar):
        warnings.warn("Encountered NAN log_prob_sum at site '{}'".format(name))
    # Only +inf is reported here; -inf does not trigger a warning.
    if torch_isinf(scalar) and scalar > 0:
        warnings.warn("Encountered +inf log_prob_sum at site '{}'".format(name))
def __call__(self, inputs, s, dim=0, keepdim=False):
    """
    Updates the moving average, and returns :code:`inputs.log()`.

    :param inputs: the new value folded into the moving average.
    :param s: offset stored on ``self.s`` (detached) for the next call; the
        previous average is rescaled by ``exp(self.s - s)`` before mixing.
        NOTE(review): presumably a log-scale shift -- confirm with callers.
    :param dim: not used by this method body.
    :param keepdim: not used by this method body.
    :return: the result of ``_ewma_log_fn(inputs, ewma)``.
    """
    self.n += 1
    previous = self.ewma
    if torch_isnan(previous) or torch_isinf(previous):
        # No usable history: restart the average from the current inputs.
        ewma = inputs
    else:
        alpha_pow = self.alpha ** self.n
        # Normaliser shared by both terms of the update.
        norm = 1 - alpha_pow
        fresh = inputs * (1.0 - self.alpha) / norm
        carried = torch.exp(self.s - s) * previous * (self.alpha - alpha_pow) / norm
        ewma = fresh + carried
    # Stored state is detached; the non-detached ``ewma`` is what gets
    # passed on to ``_ewma_log_fn``.
    self.ewma = ewma.detach()
    self.s = s.detach()
    return _ewma_log_fn(inputs, ewma)
def initial_trace(self):
    """
    Find a valid trace to initiate the MCMC sampler. This is also used as a
    prototype trace to inter-convert between Pyro's trace object and dict
    object used by the integrator.

    A previously found trace cached on ``self._initial_trace`` is returned
    directly. Otherwise the model is traced and re-traced, up to
    ``self._max_tries_initial_trace`` checks, until the value returned by
    ``self._compute_trace_log_prob`` is neither NaN nor infinite; that trace
    is cached and returned.

    :return: a trace whose log probability is finite.
    :raises ValueError: if no valid trace is found within the allowed tries.
    """
    if self._initial_trace:
        return self._initial_trace
    trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
    for _ in range(self._max_tries_initial_trace):
        trace_log_prob_sum = self._compute_trace_log_prob(trace)
        if not torch_isnan(trace_log_prob_sum) and not torch_isinf(trace_log_prob_sum):
            self._initial_trace = trace
            return trace
        # Retry with the arguments unpacked exactly as in the first attempt;
        # passing the tuple/dict themselves would call the model with a
        # different signature.
        trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
    raise ValueError(
        "Model specification seems incorrect - cannot find a valid trace.")
def _validate_trace(self, trace):
    """
    Check that ``trace.log_prob_sum()`` is neither NaN nor infinite.

    :param trace: object exposing a ``log_prob_sum()`` method.
    :raises ValueError: if the summed log probability is NaN or Inf.
    """
    total = trace.log_prob_sum()
    is_finite = not (torch_isnan(total) or torch_isinf(total))
    if is_finite:
        return
    raise ValueError(
        "Model specification incorrect - trace log pdf is NaN or Inf.")
def _validate_trace(self, trace):
    """
    Reject a trace whose summed log probability is NaN or infinite.

    :param trace: object exposing a ``log_prob_sum()`` method.
    :raises ValueError: if ``trace.log_prob_sum()`` is NaN or Inf.
    """
    log_pdf = trace.log_prob_sum()
    # Guard clause: nothing to do when the value is an ordinary number.
    if not torch_isnan(log_pdf) and not torch_isinf(log_pdf):
        return
    raise ValueError("Model specification incorrect - trace log pdf is NaN or Inf.")