def warnings(self):
    """Return the sampler's warnings plus an aggregate divergence summary.

    The result is a new list built from a shallow copy of
    ``self._warnings`` (the stored list is not modified), followed by at
    most one ``WarningType.DIVERGENCES`` warning summarising
    ``self._num_divs_sample``, followed by everything returned by
    ``self.step_adapt.warnings()``.
    """
    collected = list(self._warnings)

    # Choose the global divergence summary; an empty string means there
    # is nothing to report.
    n_divs = self._num_divs_sample
    if n_divs and n_divs == self._samples_after_tune:
        summary = (
            "The chain contains only diverging samples. The model "
            "is probably misspecified."
        )
    elif n_divs == 1:
        summary = (
            "There was 1 divergence after tuning. Increase "
            "`target_accept` or reparameterize."
        )
    elif n_divs > 1:
        summary = (
            f"There were {n_divs} divergences after tuning. Increase "
            "`target_accept` or reparameterize."
        )
    else:
        summary = ""

    if summary:
        collected.append(
            SamplerWarning(WarningType.DIVERGENCES, summary, "error")
        )

    collected.extend(self.step_adapt.warnings())
    return collected
def warnings(self):
    """Extend the inherited warnings with a maximum-tree-depth warning.

    A ``WarningType.TREEDEPTH`` warning is appended when
    ``self._samples_after_tune`` is positive and the ratio of
    ``self._reached_max_treedepth`` to it exceeds 0.05.
    """
    collected = super().warnings()

    total = self._samples_after_tune
    hits = self._reached_max_treedepth

    # Short-circuit on ``total > 0`` so the division can never hit zero.
    too_frequent = total > 0 and hits / float(total) > 0.05
    if too_frequent:
        text = (
            "The chain reached the maximum tree depth. Increase "
            "max_treedepth, increase target_accept or reparameterize."
        )
        collected.append(SamplerWarning(WarningType.TREEDEPTH, text, "warn"))

    return collected
def warnings(self):
    """Check whether the recorded acceptance rate is compatible with the target.

    The mean of ``self._tuned_stats`` is compared against ``self._target``
    using the central 95% interval of a ``Beta(n_good + 1, n_bad + 1)``
    distribution, where the effective sample count is capped at 100.

    Returns
    -------
    list
        ``[]`` when the target lies inside the interval, or when no
        acceptance statistics have been recorded; otherwise a one-element
        list holding a ``WarningType.BAD_ACCEPTANCE`` warning whose
        ``extra`` carries the ``"target"`` and ``"actual"`` values.
    """
    accept = np.asarray(self._tuned_stats)
    if accept.size == 0:
        # Nothing was recorded, so there is nothing to compare.  Bail out
        # before np.mean() emits "Mean of empty slice" and feeds NaN into
        # the beta distribution below.
        return []

    mean_accept = np.mean(accept)
    target_accept = self._target

    # Try to find a reasonable interval for acceptable acceptance
    # probabilities. Finding this was mostly trial and error.
    n_bound = min(100, len(accept))
    n_good = mean_accept * n_bound
    n_bad = (1 - mean_accept) * n_bound
    lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)

    if lower <= target_accept <= upper:
        return []

    msg = (
        f"The acceptance probability does not match the target. "
        f"It is {mean_accept:0.4g}, but should be close to {target_accept:0.4g}. "
        f"Try to increase the number of tuning steps."
    )
    info = {"target": target_accept, "actual": mean_accept}
    return [SamplerWarning(WarningType.BAD_ACCEPTANCE, msg, "warn", extra=info)]
def astep(self, q0):
    """Perform a single HMC iteration.

    Draws a momentum from ``self.potential``, computes the starting state,
    runs ``self._hamiltonian_step`` with the current step size, feeds the
    outcome to the step-size adaptation and the potential, records a
    warning if the step diverged, and returns the new position together
    with per-step statistics.

    Parameters
    ----------
    q0
        Current position; must expose ``point_map_info``
        (presumably a ``RaveledVars`` - confirm against callers).

    Returns
    -------
    tuple
        ``(hmc_step.end.q, [stats])`` where ``stats`` is a dict holding
        ``"tune"``, ``"diverging"``, timing entries, and whatever
        ``hmc_step.stats`` and ``self.step_adapt.stats()`` contribute.

    Raises
    ------
    SamplingError
        If the energy of the starting state is not finite.  A
        ``BAD_ENERGY`` warning is appended to ``self._warnings`` first.
    """
    perf_start = time.perf_counter()
    process_start = time.process_time()

    p0 = self.potential.random()
    p0 = RaveledVars(p0, q0.point_map_info)

    start = self.integrator.compute_state(q0, p0)

    if not np.isfinite(start.energy):
        model = self._model
        check_test_point = model.point_logps()
        # Keep only the entries that look responsible: huge magnitude or NaN.
        error_logp = check_test_point.loc[
            (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
        ]
        # Called before the generic error below is built, so whatever
        # raise_ok raises takes precedence (NOTE(review): its exact
        # semantics are not visible from here).
        self.potential.raise_ok(q0.point_map_info)
        message_energy = (
            "Bad initial energy, check any log probabilities that "
            f"are inf or -inf, nan or very small:\n{error_logp.to_string()}"
        )
        warning = SamplerWarning(
            WarningType.BAD_ENERGY,
            message_energy,
            "critical",
            self.iter_count,
        )
        self._warnings.append(warning)
        raise SamplingError("Bad initial energy")

    adapt_step = self.tune and self.adapt_step_size
    step_size = self.step_adapt.current(adapt_step)
    # Store the un-jittered step size; only the local copy is randomised.
    self.step_size = step_size

    if self._step_rand is not None:
        step_size = self._step_rand(step_size)

    hmc_step = self._hamiltonian_step(start, p0.data, step_size)

    # Timing stops here: adaptation and bookkeeping are not measured.
    perf_end = time.perf_counter()
    process_end = time.process_time()

    self.step_adapt.update(hmc_step.accept_stat, adapt_step)
    self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)

    if hmc_step.divergence_info:
        info = hmc_step.divergence_info
        point = None
        point_dest = None
        info_store = None
        if self.tune:
            kind = WarningType.TUNING_DIVERGENCE
        else:
            kind = WarningType.DIVERGENCE
            self._num_divs_sample += 1
            # We don't want to fill up all memory with divergence info,
            # so detailed data is kept only while the count is below 100.
            if self._num_divs_sample < 100:
                if info.state is not None:
                    point = DictToArrayBijection.rmap(info.state.q)
                if info.state_div is not None:
                    # Goes to the destination slot; writing it to `point`
                    # would clobber the source and leave the destination
                    # permanently None.
                    point_dest = DictToArrayBijection.rmap(info.state_div.q)
                info_store = info
        warning = SamplerWarning(
            kind,
            info.message,
            "debug",
            self.iter_count,
            info.exec_info,
            divergence_point_source=point,
            divergence_point_dest=point_dest,
            divergence_info=info_store,
        )
        self._warnings.append(warning)

    self.iter_count += 1
    if not self.tune:
        self._samples_after_tune += 1

    stats = {
        "tune": self.tune,
        "diverging": bool(hmc_step.divergence_info),
        "perf_counter_diff": perf_end - perf_start,
        "process_time_diff": process_end - process_start,
        "perf_counter_start": perf_start,
    }
    stats.update(hmc_step.stats)
    stats.update(self.step_adapt.stats())

    return hmc_step.end.q, [stats]