Example #1
    def warnings(self, strace):
        # list.copy() is not available in python2
        warnings = self._warnings[:]

        # Generate a global warning for divergences
        n_divs = self._num_divs_sample
        if n_divs and self._samples_after_tune == n_divs:
            msg = ('The chain contains only diverging samples. The model is '
                   'probably misspecified.')
            warning = SamplerWarning(WarningType.DIVERGENCES, msg, 'error',
                                     None, None, None)
            warnings.append(warning)
        elif n_divs > 0:
            message = ('Divergences after tuning. Increase `target_accept` or '
                       'reparameterize.')
            warning = SamplerWarning(WarningType.DIVERGENCES, message, 'error',
                                     None, None, None)
            warnings.append(warning)

        # small trace
        if self._samples_after_tune == 0:
            msg = "Tuning was enabled throughout the whole trace."
            warning = SamplerWarning(WarningType.BAD_PARAMS, msg, 'error',
                                     None, None, None)
            warnings.append(warning)
        elif self._samples_after_tune < 500:
            msg = "Only %s samples in chain." % self._samples_after_tune
            warning = SamplerWarning(WarningType.BAD_PARAMS, msg, 'error',
                                     None, None, None)
            warnings.append(warning)

        warnings.extend(self.step_adapt.warnings())
        return warnings
Example #2
    def warnings(self):
        # list.copy() is not available in python2
        warnings = self._warnings[:]

        # Generate a global warning for divergences
        message = ""
        n_divs = self._num_divs_sample
        if n_divs and self._samples_after_tune == n_divs:
            message = (
                "The chain contains only diverging samples. The model "
                "is probably misspecified."
            )
        elif n_divs == 1:
            message = (
                "There was 1 divergence after tuning. Increase "
                "`target_accept` or reparameterize."
            )
        elif n_divs > 1:
            message = (
                "There were %s divergences after tuning. Increase "
                "`target_accept` or reparameterize." % n_divs
            )

        if message:
            warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
            warnings.append(warning)

        warnings.extend(self.step_adapt.warnings())
        return warnings
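
The branching above is easiest to see in isolation. Below is a minimal, hypothetical helper (not part of the PyMC source) that mirrors the same message selection:

def divergence_message(n_divs, samples_after_tune):
    # Mirrors the branching in Example #2 (illustrative only).
    if n_divs and samples_after_tune == n_divs:
        return ("The chain contains only diverging samples. The model "
                "is probably misspecified.")
    elif n_divs == 1:
        return ("There was 1 divergence after tuning. Increase "
                "`target_accept` or reparameterize.")
    elif n_divs > 1:
        return ("There were %s divergences after tuning. Increase "
                "`target_accept` or reparameterize." % n_divs)
    return ""

assert divergence_message(0, 500) == ""            # no divergences, no warning
assert "3 divergences" in divergence_message(3, 500)
assert "only diverging" in divergence_message(500, 500)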
Example #3
    def astep(self, q0):
        """Perform a single HMC iteration."""
        p0 = self.potential.random()
        start = self.integrator.compute_state(q0, p0)
        model = self._model

        if not np.isfinite(start.energy):
            check_test_point = model.check_test_point()
            error_logp = check_test_point.loc[
                (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
            ]
            self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
            logger.error(
                "Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:\n{}"
                .format(error_logp.to_string()))
            raise ValueError('Bad initial energy')

        adapt_step = self.tune and self.adapt_step_size
        step_size = self.step_adapt.current(adapt_step)
        self.step_size = step_size

        if self._step_rand is not None:
            step_size = self._step_rand(step_size)

        hmc_step = self._hamiltonian_step(start, p0, step_size)

        self.step_adapt.update(hmc_step.accept_stat, adapt_step)
        self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
        if hmc_step.divergence_info:
            info = hmc_step.divergence_info
            if self.tune:
                kind = WarningType.TUNING_DIVERGENCE
                point = None
            else:
                kind = WarningType.DIVERGENCE
                self._num_divs_sample += 1
                # We don't want to fill up all memory with divergence info
                if self._num_divs_sample < 100:
                    point = self._logp_dlogp_func.array_to_dict(info.state.q)
                else:
                    point = None
            warning = SamplerWarning(kind, info.message, 'debug',
                                     self.iter_count, info.exec_info, point)

            self._warnings.append(warning)

        self.iter_count += 1
        if not self.tune:
            self._samples_after_tune += 1

        stats = {
            'tune': self.tune,
            'diverging': bool(hmc_step.divergence_info),
        }

        stats.update(hmc_step.stats)
        stats.update(self.step_adapt.stats())

        return hmc_step.end.q, [stats]
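
The bad-energy branch above keeps only the per-variable test-point log probabilities that are non-finite or extreme in magnitude. A self-contained sketch of that filtering step, with a hypothetical check_test_point result:

import numpy as np
import pandas as pd

# Hypothetical per-variable log probabilities at the test point
check_test_point = pd.Series({"mu": -3.2, "sigma_log__": np.nan, "x": -1e21})
error_logp = check_test_point.loc[
    (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
]
print(error_logp.to_string())  # flags sigma_log__ (nan) and x (-1e21)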
Example #4
    def warnings(self, strace):
        warnings = super(NUTS, self).warnings(strace)

        if np.mean(self._reached_max_treedepth) > 0.05:
            msg = ('The chain reached the maximum tree depth. Increase '
                   'max_treedepth, increase target_accept or reparameterize.')
            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn', None,
                                  None, None)
            warnings.append(warn)
        return warnings
Example #5
    def warnings(self):
        warnings = super().warnings()
        n_samples = self._samples_after_tune
        n_treedepth = self._reached_max_treedepth

        if n_samples > 0 and n_treedepth / float(n_samples) > 0.05:
            msg = ("The chain reached the maximum tree depth. Increase "
                   "max_treedepth, increase target_accept or reparameterize.")
            warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn")
            warnings.append(warn)
        return warnings
Example #6
    def warnings(self, strace):
        # list.copy() is not available in python2
        warnings = self._warnings[:]

        # Generate a global warning for divergences
        n_divs = self._num_divs_sample
        if n_divs and self._samples_after_tune == n_divs:
            msg = ('The chain contains only diverging samples. The model is '
                   'probably misspecified.')
            warning = SamplerWarning(WarningType.DIVERGENCES, msg, 'error',
                                     None, None, None)
            warnings.append(warning)
        elif n_divs > 0:
            message = ('There were %s divergences after tuning. Increase '
                       '`target_accept` or reparameterize.' % n_divs)
            warning = SamplerWarning(WarningType.DIVERGENCES, message, 'error',
                                     None, None, None)
            warnings.append(warning)

        warnings.extend(self.step_adapt.warnings())
        return warnings
Example #7
    def warnings(self, strace):
        warnings = super(NUTS, self).warnings(strace)
        n_samples = self._samples_after_tune
        n_treedepth = self._reached_max_treedepth

        if n_samples > 0 and n_treedepth / float(n_samples) > 0.05:
            msg = ('The chain reached the maximum tree depth. Increase '
                   'max_treedepth, increase target_accept or reparameterize.')
            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn',
                                  None, None, None)
            warnings.append(warn)
        return warnings
Example #8
    def astep(self, q0):
        """Perform a single HMC iteration."""
        p0 = self.potential.random()
        start = self.integrator.compute_state(q0, p0)

        if not np.isfinite(start.energy):
            self.potential.raise_ok()
            raise ValueError('Bad initial energy: %s. The model '
                             'might be misspecified.' % start.energy)

        adapt_step = self.tune and self.adapt_step_size
        step_size = self.step_adapt.current(adapt_step)
        self.step_size = step_size

        if self._step_rand is not None:
            step_size = self._step_rand(step_size)

        hmc_step = self._hamiltonian_step(start, p0, step_size)

        self.step_adapt.update(hmc_step.accept_stat, adapt_step)
        self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
        if hmc_step.divergence_info:
            info = hmc_step.divergence_info
            if self.tune:
                kind = WarningType.TUNING_DIVERGENCE
                point = None
            else:
                kind = WarningType.DIVERGENCE
                self._num_divs_sample += 1
                # We don't want to fill up all memory with divergence info
                if self._num_divs_sample < 100:
                    point = self._logp_dlogp_func.array_to_dict(info.state.q)
                else:
                    point = None
            warning = SamplerWarning(kind, info.message, 'debug',
                                     self.iter_count, info.exec_info, point)

            self._warnings.append(warning)

        self.iter_count += 1
        if not self.tune:
            self._samples_after_tune += 1

        stats = {
            'tune': self.tune,
            'diverging': bool(hmc_step.divergence_info),
        }

        stats.update(hmc_step.stats)
        stats.update(self.step_adapt.stats())

        return hmc_step.end.q, [stats]
Example #9
    def warnings(self):
        accept = np.array(self._tuned_stats)
        mean_accept = np.mean(accept)
        target_accept = self._target
        # Try to find a reasonable interval for acceptable acceptance
        # probabilities. Finding this was mostly trial and error.
        n_bound = min(100, len(accept))
        n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
        lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
        if target_accept < lower or target_accept > upper:
            msg = ('The acceptance probability does not match the target. It '
                   'is %s, but should be close to %s. Try to increase the '
                   'number of tuning steps.' % (mean_accept, target_accept))
            info = {'target': target_accept, 'actual': mean_accept}
            warning = SamplerWarning(WarningType.BAD_ACCEPTANCE, msg, 'warn',
                                     None, None, info)
            return [warning]
        else:
            return []
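
The interval in Example #9 is the 95% credible interval of a Beta posterior (with a uniform prior) over the acceptance probability. The same computation as a self-contained sketch, with illustrative numbers:

import numpy as np
from scipy import stats

accept = np.array([0.7, 0.9, 0.85, 0.6, 0.95])  # illustrative acceptance stats
mean_accept = np.mean(accept)                   # 0.8
n_bound = min(100, len(accept))
n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
target_accept = 0.8
print(lower <= target_accept <= upper)  # True, so no warning is emitted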
Example #10
    def astep(self, q0):
        """Perform a single HMC iteration."""
        perf_start = time.perf_counter()
        process_start = time.process_time()

        p0 = self.potential.random()
        start = self.integrator.compute_state(q0, p0)

        if not np.isfinite(start.energy):
            model = self._model
            check_test_point = model.check_test_point()
            error_logp = check_test_point.loc[
                (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
            ]
            self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
            message_energy = (
                "Bad initial energy, check any log probabilities that "
                "are inf or -inf, nan or very small:\n{}".format(
                    error_logp.to_string()))
            warning = SamplerWarning(
                WarningType.BAD_ENERGY,
                message_energy,
                "critical",
                self.iter_count,
            )
            self._warnings.append(warning)
            raise SamplingError("Bad initial energy")

        adapt_step = self.tune and self.adapt_step_size
        step_size = self.step_adapt.current(adapt_step)
        self.step_size = step_size

        if self._step_rand is not None:
            step_size = self._step_rand(step_size)

        hmc_step = self._hamiltonian_step(start, p0, step_size)

        perf_end = time.perf_counter()
        process_end = time.process_time()

        self.step_adapt.update(hmc_step.accept_stat, adapt_step)
        self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
        if hmc_step.divergence_info:
            info = hmc_step.divergence_info
            point = None
            point_dest = None
            info_store = None
            if self.tune:
                kind = WarningType.TUNING_DIVERGENCE
            else:
                kind = WarningType.DIVERGENCE
                self._num_divs_sample += 1
                # We don't want to fill up all memory with divergence info
                if self._num_divs_sample < 100 and info.state is not None:
                    point = self._logp_dlogp_func.array_to_dict(info.state.q)
                if self._num_divs_sample < 100 and info.state_div is not None:
                    point_dest = self._logp_dlogp_func.array_to_dict(
                        info.state_div.q)
                if self._num_divs_sample < 100:
                    info_store = info
            warning = SamplerWarning(
                kind,
                info.message,
                "debug",
                self.iter_count,
                info.exec_info,
                divergence_point_source=point,
                divergence_point_dest=point_dest,
                divergence_info=info_store,
            )

            self._warnings.append(warning)

        self.iter_count += 1
        if not self.tune:
            self._samples_after_tune += 1

        stats = {
            "tune": self.tune,
            "diverging": bool(hmc_step.divergence_info),
            "perf_counter_diff": perf_end - perf_start,
            "process_time_diff": process_end - process_start,
            "perf_counter_start": perf_start,
        }

        stats.update(hmc_step.stats)
        stats.update(self.step_adapt.stats())

        return hmc_step.end.q, [stats]
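
Example #10 brackets the Hamiltonian step with time.perf_counter (wall clock) and time.process_time (CPU time, sleeps excluded) and reports the differences in the sampler stats. The same pattern in miniature:

import time

perf_start = time.perf_counter()
process_start = time.process_time()

sum(i * i for i in range(10**6))  # stand-in for the Hamiltonian step

perf_end = time.perf_counter()
process_end = time.process_time()

print({
    "perf_counter_diff": perf_end - perf_start,        # wall-clock seconds
    "process_time_diff": process_end - process_start,  # CPU seconds
    "perf_counter_start": perf_start,
})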