Example #1
    def _add_divergence(self, tuning, msg, error=None, point=None):
        if tuning:
            self._divs_tune.append((msg, error, point))
        else:
            self._divs_after_tune.append((msg, error, point))
        if self._on_error == 'raise':
            err = SamplingError('Divergence after tuning: %s Increase '
                                'target_accept or reparameterize.' % msg)
            six.raise_from(err, error)
        elif self._on_error == 'warn':
            warnings.warn('Divergence detected: %s Increase target_accept '
                          'or reparameterize.' % msg)
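
The snippet above raises SamplingError when self._on_error is set to 'raise' and only warns when it is 'warn'. Below is a minimal sketch of catching that error from user code; the model and the sampler arguments are illustrative assumptions, not taken from the snippet itself.

import pymc3 as pm
from pymc3.exceptions import SamplingError

with pm.Model() as model:
    # A funnel-like prior tends to produce divergences with default settings.
    sigma = pm.HalfNormal("sigma", 1.0)
    x = pm.Normal("x", mu=0.0, sigma=sigma, shape=5)
    try:
        trace = pm.sample(draws=500, tune=500, target_accept=0.8)
    except SamplingError as err:
        # Raised when the sampler treats divergences (or a bad initial
        # energy) as fatal rather than merely warning about them.
        print("Sampling failed:", err)
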
Example #2
def check_start_vals(start, model):
    r"""Check that the starting values for MCMC do not cause the relevant log probability
    to evaluate to something invalid (e.g. Inf or NaN)

    Parameters
    ----------
    start : dict, or array of dict
        Starting point in parameter space (or partial point)
        Defaults to ``trace.point(-1)`` if there is a trace provided and model.test_point if not
        (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
        overwrite the default.
    model : Model object

    Raises
    ------
    KeyError if the parameters provided by `start` do not agree with the parameters contained
        within `model`
    pymc3.exceptions.SamplingError if the evaluation of the parameters in `start` leads to an
        invalid (i.e. non-finite) state

    Returns
    -------
    None
    """
    start_points = [start] if isinstance(start, dict) else start
    for elem in start_points:
        if not set(elem.keys()).issubset(model.named_vars.keys()):
            extra_keys = ", ".join(
                set(elem.keys()) - set(model.named_vars.keys()))
            valid_keys = ", ".join(model.named_vars.keys())
            raise KeyError(
                "Some start parameters do not appear in the model!\n"
                "Valid keys are: {}, but {} was supplied".format(
                    valid_keys, extra_keys))

        initial_eval = model.check_test_point(test_point=elem)

        if not np.all(np.isfinite(initial_eval)):
            raise SamplingError(
                "Initial evaluation of model at starting point failed!\n"
                "Starting values:\n{}\n\n"
                "Initial evaluation results:\n{}".format(
                    elem, str(initial_eval)))
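
The docstring above spells out when KeyError and SamplingError are raised. A short usage sketch follows, assuming check_start_vals is importable from pymc3.util as in recent PyMC3 3.x releases; the one-variable model is illustrative only.

import numpy as np
import pymc3 as pm
from pymc3.exceptions import SamplingError
from pymc3.util import check_start_vals

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)

# A finite starting point passes silently.
check_start_vals({"x": 0.0}, model)

# A non-finite starting value triggers the SamplingError branch above.
try:
    check_start_vals({"x": np.nan}, model)
except SamplingError as err:
    print(err)
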
Example #3
    def astep(self, q0):
        """Perform a single HMC iteration."""
        perf_start = time.perf_counter()
        process_start = time.process_time()

        p0 = self.potential.random()
        start = self.integrator.compute_state(q0, p0)

        if not np.isfinite(start.energy):
            model = self._model
            check_test_point = model.check_test_point()
            error_logp = check_test_point.loc[
                (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
            ]
            self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
            message_energy = (
                "Bad initial energy, check any log probabilities that "
                "are inf or -inf, nan or very small:\n{}".format(
                    error_logp.to_string()))
            warning = SamplerWarning(
                WarningType.BAD_ENERGY,
                message_energy,
                "critical",
                self.iter_count,
            )
            self._warnings.append(warning)
            raise SamplingError("Bad initial energy")

        adapt_step = self.tune and self.adapt_step_size
        step_size = self.step_adapt.current(adapt_step)
        self.step_size = step_size

        if self._step_rand is not None:
            step_size = self._step_rand(step_size)

        hmc_step = self._hamiltonian_step(start, p0, step_size)

        perf_end = time.perf_counter()
        process_end = time.process_time()

        self.step_adapt.update(hmc_step.accept_stat, adapt_step)
        self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
        if hmc_step.divergence_info:
            info = hmc_step.divergence_info
            point = None
            point_dest = None
            info_store = None
            if self.tune:
                kind = WarningType.TUNING_DIVERGENCE
            else:
                kind = WarningType.DIVERGENCE
                self._num_divs_sample += 1
                # We don't want to fill up all memory with divergence info
                if self._num_divs_sample < 100 and info.state is not None:
                    point = self._logp_dlogp_func.array_to_dict(info.state.q)
                if self._num_divs_sample < 100 and info.state_div is not None:
                    point_dest = self._logp_dlogp_func.array_to_dict(
                        info.state_div.q)
                if self._num_divs_sample < 100:
                    info_store = info
            warning = SamplerWarning(
                kind,
                info.message,
                "debug",
                self.iter_count,
                info.exec_info,
                divergence_point_source=point,
                divergence_point_dest=point_dest,
                divergence_info=info_store,
            )

            self._warnings.append(warning)

        self.iter_count += 1
        if not self.tune:
            self._samples_after_tune += 1

        stats = {
            "tune": self.tune,
            "diverging": bool(hmc_step.divergence_info),
            "perf_counter_diff": perf_end - perf_start,
            "process_time_diff": process_end - process_start,
            "perf_counter_start": perf_start,
        }

        stats.update(hmc_step.stats)
        stats.update(self.step_adapt.stats())

        return hmc_step.end.q, [stats]
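
The stats dictionary assembled at the end of astep() is returned as per-draw sampler statistics. Here is an illustrative sketch of reading those statistics back from a finished trace; the model is an assumption, and return_inferencedata=False (PyMC3 3.9+) is passed so that a MultiTrace supporting get_sampler_stats is returned.

import pymc3 as pm

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", 1.0)
    x = pm.Normal("x", mu=0.0, sigma=sigma, shape=3)
    trace = pm.sample(draws=300, tune=300, target_accept=0.9,
                      return_inferencedata=False)

# "diverging" and "tune" are among the keys built in astep() above.
n_divergent = int(trace.get_sampler_stats("diverging").sum())
print("number of divergent transitions:", n_divergent)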