def _add_divergence(self, tuning, msg, error=None, point=None):
    """Record one divergence and react to it according to ``self._on_error``.

    The ``(msg, error, point)`` triple is appended to ``self._divs_tune``
    when ``tuning`` is true and to ``self._divs_after_tune`` otherwise.
    For divergences after tuning, ``self._on_error`` then decides what
    happens next: ``'raise'`` raises a ``SamplingError`` chained to
    ``error``, ``'warn'`` emits a warning, any other value does nothing.

    Parameters
    ----------
    tuning : bool
        Whether the divergence happened while tuning.
    msg : str
        Description of the divergence, interpolated into the error or
        warning text.
    error : optional
        Passed to ``six.raise_from`` as the cause of the raised error.
    point : optional
        Stored alongside the message; not inspected here.
    """
    # NOTE(review): the source formatting was flattened; the on-error
    # handling is placed under the post-tuning branch because its message
    # reads "Divergence after tuning" -- confirm against the original layout.
    record = (msg, error, point)
    if tuning:
        self._divs_tune.append(record)
        return

    self._divs_after_tune.append(record)
    policy = self._on_error
    if policy == 'raise':
        failure = SamplingError(
            'Divergence after tuning: %s Increase '
            'target_accept or reparameterize.' % msg
        )
        six.raise_from(failure, error)
    elif policy == 'warn':
        warnings.warn(
            'Divergence detected: %s Increase target_accept '
            'or reparameterize.' % msg
        )
def check_start_vals(start, model):
    """Check that MCMC starting values give a valid log probability.

    Each start point may only name variables present in
    ``model.named_vars``, and ``model.check_test_point`` evaluated at that
    point must return only finite values (no Inf or NaN).

    Parameters
    ----------
    start : dict, or array of dict
        Starting point in parameter space (or partial point). A single
        dict is treated as a collection with one element.
    model : Model object
        Must provide ``named_vars`` and ``check_test_point(test_point=...)``.

    Raises
    ------
    KeyError
        If the parameters provided by `start` do not agree with the
        parameters contained within `model`.
    pymc3.exceptions.SamplingError
        If the evaluation of the parameters in `start` leads to an
        invalid (i.e. non-finite) state.

    Returns
    -------
    None
    """
    start_points = [start] if isinstance(start, dict) else start

    for elem in start_points:
        unknown = set(elem.keys()) - set(model.named_vars.keys())
        if unknown:
            # Sort so the message is reproducible: set iteration order of
            # strings changes between interpreter runs.
            extra_keys = ", ".join(sorted(unknown))
            valid_keys = ", ".join(model.named_vars.keys())
            raise KeyError(
                "Some start parameters do not appear in the model!\n"
                "Valid keys are: {}, but {} was supplied".format(
                    valid_keys, extra_keys
                )
            )

        initial_eval = model.check_test_point(test_point=elem)

        if not np.all(np.isfinite(initial_eval)):
            raise SamplingError(
                "Initial evaluation of model at starting point failed!\n"
                "Starting values:\n{}\n\n"
                "Initial evaluation results:\n{}".format(
                    elem, str(initial_eval)
                )
            )
def astep(self, q0):
    """Perform a single HMC iteration.

    A momentum is drawn from ``self.potential`` and combined with ``q0``
    into the starting integrator state. ``self._hamiltonian_step`` is then
    run with the current (optionally randomised) step size, and its result
    is fed back into ``self.step_adapt`` and ``self.potential``. A reported
    divergence is recorded as a ``SamplerWarning`` in ``self._warnings``.

    Parameters
    ----------
    q0
        Current position, handed to ``self.integrator.compute_state``.

    Returns
    -------
    tuple
        ``(hmc_step.end.q, [stats])`` where ``stats`` is a dict holding
        the tuning flag, the divergence flag, timing information, and
        whatever ``hmc_step.stats`` and ``self.step_adapt.stats()`` add.

    Raises
    ------
    SamplingError
        If the energy of the starting state is not finite.
    """
    perf_start = time.perf_counter()
    process_start = time.process_time()

    p0 = self.potential.random()
    start = self.integrator.compute_state(q0, p0)

    if not np.isfinite(start.energy):
        # Collect the model terms whose log probability is NaN or huge
        # so the warning can point at the likely culprit.
        logp_table = self._model.check_test_point()
        suspicious = (np.abs(logp_table) >= 1e20) | np.isnan(logp_table)
        error_logp = logp_table.loc[suspicious]
        self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
        message_energy = (
            "Bad initial energy, check any log probabilities that "
            "are inf or -inf, nan or very small:\n{}".format(
                error_logp.to_string()
            )
        )
        self._warnings.append(
            SamplerWarning(
                WarningType.BAD_ENERGY,
                message_energy,
                "critical",
                self.iter_count,
            )
        )
        raise SamplingError("Bad initial energy")

    adapt_step = self.tune and self.adapt_step_size
    step_size = self.step_adapt.current(adapt_step)
    # The un-jittered step size is what gets exposed on the sampler.
    self.step_size = step_size

    if self._step_rand is not None:
        step_size = self._step_rand(step_size)

    hmc_step = self._hamiltonian_step(start, p0, step_size)

    # Timing stops here: adaptation and bookkeeping below are excluded.
    perf_end = time.perf_counter()
    process_end = time.process_time()

    self.step_adapt.update(hmc_step.accept_stat, adapt_step)
    self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)

    divergence = hmc_step.divergence_info
    if divergence:
        source_point = None
        dest_point = None
        stored_info = None
        # NOTE(review): the source formatting was flattened; capping of
        # divergence details is assumed to apply only outside tuning,
        # where the counter is incremented -- confirm against the
        # original layout.
        if not self.tune:
            kind = WarningType.DIVERGENCE
            self._num_divs_sample += 1
            # We don't want to fill up all memory with divergence info
            if self._num_divs_sample < 100:
                to_dict = self._logp_dlogp_func.array_to_dict
                if divergence.state is not None:
                    source_point = to_dict(divergence.state.q)
                if divergence.state_div is not None:
                    dest_point = to_dict(divergence.state_div.q)
                stored_info = divergence
        else:
            kind = WarningType.TUNING_DIVERGENCE

        self._warnings.append(
            SamplerWarning(
                kind,
                divergence.message,
                "debug",
                self.iter_count,
                divergence.exec_info,
                divergence_point_source=source_point,
                divergence_point_dest=dest_point,
                divergence_info=stored_info,
            )
        )

    self.iter_count += 1
    if not self.tune:
        self._samples_after_tune += 1

    stats = {
        "tune": self.tune,
        "diverging": bool(divergence),
        "perf_counter_diff": perf_end - perf_start,
        "process_time_diff": process_end - process_start,
        "perf_counter_start": perf_start,
    }
    for extra in (hmc_step.stats, self.step_adapt.stats()):
        stats.update(extra)

    return hmc_step.end.q, [stats]