def test_precision(self):
        """Check that the interpolator refuses to answer until nearby points are dense enough.

        The coarse x**2 samples must raise InsufficientPrecisionError at 0.5 for a
        requested precision of 0.01; after adding samples around 0.5 the result must
        equal the mean of the two samples bracketing 0.5.
        """
        interpolator = WaryInterpolator(points=(0, 1, 2), values=(0, 1, 4), precision=0.01, domain=(0, 5))
        with self.assertRaises(InsufficientPrecisionError):
            interpolator(0.5)

        # Add closely spaced samples of x**2 around the query point
        fine_points = np.array([0.485, 0.495, 0.505, 0.515])
        interpolator.add_points(fine_points, fine_points**2)

        # 0.5 lies exactly halfway between the samples at 0.495 and 0.505
        expected = (0.495**2 + 0.505**2)/2
        self.assertAlmostEqual(interpolator(0.5), expected)
    def test_interpolate(self):
        """Check interpolation, extrapolation above the points, and rejection below the domain."""
        interpolator = WaryInterpolator(if_higher='extrapolate', if_lower='raise', domain=(0, 5))

        # No points have been added yet, so no value can be produced
        with self.assertRaises(InsufficientPrecisionError):
            interpolator(0)

        interpolator.add_points([0, 2, 4], [10, 20, 30])

        # Exactly on a point
        self.assertEqual(interpolator(0), 10)
        # Between two points
        self.assertAlmostEqual(interpolator(1), 15)
        # Above the highest point, but inside the domain
        self.assertAlmostEqual(interpolator(5), 35)
        # Below the domain
        with self.assertRaises(OutsideDomainError):
            interpolator(-5)
    def test_add_points(self):
        """Check that points added later give the same state as points passed to the constructor."""
        incremental = WaryInterpolator(domain=[0, 5])
        incremental.add_points([0, 4, 2], [10, 30, 20])

        prefilled = WaryInterpolator(points=(0, 2, 4), values=(10, 20, 30), domain=[0, 5])

        def assert_same_state():
            # Both interpolators must hold identical point and value arrays
            np.testing.assert_array_equal(incremental.points, prefilled.points)
            np.testing.assert_array_equal(incremental.values, prefilled.values)

        # Unsorted input must end up stored like the sorted constructor input
        assert_same_state()

        # Adding a point at an x that is already present must act the same on both
        incremental.add_point(4, 40)
        prefilled.add_point(4, 40)

        assert_same_state()
示例#4
0
    def __init__(self, statistic, **kwargs):
        """Store the statistic, apply keyword overrides and set up the limit interpolators.

        :param statistic: object stored as self.statistic.
        :param kwargs: every key/value pair is set as an attribute on the instance.
        """
        self.statistic = statistic
        for option_name in kwargs:
            setattr(self, option_name, kwargs[option_name])
        self.cl = self.confidence_level

        self.log = logging.getLogger(type(self).__name__)

        if self.wrap_interpolator:
            self.log.debug("Initializing interpolators")
            interpolator_precision = 10**(-self.precision_digits)
            # A limit that is fixed never needs to be interpolated
            if self.fixed_lower_limit is None:
                self.low_limit_interpolator = WaryInterpolator(precision=interpolator_precision,
                                                               domain=self.interpolator_log_domain)
            if self.fixed_upper_limit is None:
                self.high_limit_interpolator = WaryInterpolator(precision=interpolator_precision,
                                                                domain=self.interpolator_log_domain)
            # "Joints" of the interpolator must have better precision than required of the interpolator results
            self.precision_digits = self.precision_digits + 1

        # Dictionary holding "horizontal" intervals: interval on statistic for each precision and hypothesis.
        self.cached_intervals = dict()
示例#5
0
class IntervalChoice(object):
    """Base interval choice method class

    Computes confidence intervals on an event rate by a Neyman construction over
    self.statistic. Subclasses must implement score_stat_values.
    NOTE(review): the 'threshold' method calls self.get_threshold(), which is not defined
    in the part of the class visible here -- presumably subclasses provide it; verify.
    """
    method = 'rank'    # 'rank' or 'threshold'
    threshold = float('inf')
    # Default precision used by __call__; also sets the interpolator precision (10**-precision_digits)
    precision_digits = 2
    # If True, get_interval_on_statistic caches its results per (hypothesis, precision_digits)
    use_interval_cache = True
    # If True, limits are looked up in / stored into WaryInterpolators keyed on log10(statistic value)
    wrap_interpolator = True
    # Added to the signal rate mu to obtain the hypothesis passed to get_interval_on_statistic
    background = 0
    # Copied to self.cl on construction
    confidence_level = 0.9
    max_hypothesis = 1e6
    # Domain passed to the interpolators; compared against log10 of the statistic value
    interpolator_log_domain = (-1, 3)
    # If not None, these are returned as-is instead of being searched for / interpolated
    fixed_upper_limit = None
    fixed_lower_limit = None
    # Use only for testing:
    forbid_exact_computation = False

    def __init__(self, statistic, **kwargs):
        """Store the statistic, apply keyword overrides and set up the limit interpolators.

        :param statistic: object stored as self.statistic.
        :param kwargs: every key/value pair is set as an attribute on the instance,
                       overriding the class-level default of the same name.
        """
        self.statistic = statistic
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.cl = self.confidence_level

        self.log = logging.getLogger(self.__class__.__name__)

        if self.wrap_interpolator:
            self.log.debug("Initializing interpolators")
            # A limit that is fixed never needs an interpolator
            if self.fixed_lower_limit is None:
                self.low_limit_interpolator = WaryInterpolator(precision=10**(-self.precision_digits),
                                                               domain=self.interpolator_log_domain)
            if self.fixed_upper_limit is None:
                self.high_limit_interpolator = WaryInterpolator(precision=10**(-self.precision_digits),
                                                                domain=self.interpolator_log_domain)
            # "Joints" of the interpolator must have better precision than required of the interpolator results
            self.precision_digits += 1

        # Dictionary holding "horizontal" intervals: interval on statistic for each precision and hypothesis.
        self.cached_intervals = {}

    def get_interval_on_statistic(self, hypothesis, precision_digits):
        """Returns the self.cl confidence level interval on self.statistic for the event rate hypothesis
        The event rate here includes signal as well as identically distributed background.
        Intervals are inclusive = closed.

        :param hypothesis: event rate passed to self.statistic.get_values_and_likelihoods.
        :param precision_digits: precision passed to self.statistic.get_values_and_likelihoods.
        :returns: (low_lim, high_lim). low_lim is 0 / high_lim is inf when the interval extends
                  to the lowest / highest statistic value that was considered.
        """
        if self.use_interval_cache and (hypothesis, precision_digits) in self.cached_intervals:
            return self.cached_intervals[(hypothesis, precision_digits)]

        stat_values, likelihoods = self.statistic.get_values_and_likelihoods(hypothesis,
                                                                             precision_digits=precision_digits)
        # Normalize, so the cumulative sum below can be compared to the confidence level
        likelihoods = likelihoods / np.sum(likelihoods)

        # Score each statistic value (method-dependent)
        stat_value_scores = self.score_stat_values(statistic_values=stat_values,
                                                   likelihoods=likelihoods,
                                                   hypothesis=hypothesis)
        if self.method == 'threshold':
            # Include all statistic values that score higher than some threshold
            values_in_interval = stat_values[stat_value_scores > self.get_threshold()]

        else:
            # Include the values with the LOWEST score first (argsort is ascending),
            # until we reach the desired confidence level
            # TODO: wouldn't HIGHEST score first be more user-friendly?
            ranks = np.argsort(stat_value_scores)
            train_values_sorted = stat_values[ranks]
            likelihoods_sorted = likelihoods[ranks]

            # Find the last value to include
            # (= first value that takes the included probability over the required confidence level)
            sum_lhoods = np.cumsum(likelihoods_sorted)
            indices_over_cl = np.where(sum_lhoods > self.cl)[0]
            if len(indices_over_cl):
                last_index = indices_over_cl[0]
            else:
                # The total probability never strictly exceeds cl (cl >= 1, or rounding errors):
                # every value is needed, so include them all.
                last_index = len(sum_lhoods) - 1
            values_in_interval = train_values_sorted[:last_index + 1]

        # Limits = extreme values in the interval.
        # This means we will be conservative if values_in_interval is not continuous.
        low_lim, high_lim = values_in_interval.min(), values_in_interval.max()

        # If we included all values given up until a boundary, don't set that boundary as a limit
        if low_lim == np.min(stat_values):
            low_lim = 0
        if high_lim == np.max(stat_values):
            high_lim = float('inf')

        # Cache and return upper and lower limit on the statistic
        if self.use_interval_cache:
            self.cached_intervals[(hypothesis, precision_digits)] = low_lim, high_lim
        return low_lim, high_lim

    def get_confidence_interval(self, value, precision_digits, search_region, debug=False):
        """Performs the Neyman construction to get confidence interval on event rate (mu),
        if the statistic is observed to have value

        :param value: observed value of the statistic.
        :param precision_digits: precision passed to the interval computation and the searches.
        :param search_region: positional arguments passed on to search_true_instance.
        :param debug: not used in this method.
        :returns: (low_limit, high_limit) on mu.
        :raises InsufficientPrecisionError: re-raised from the interpolators when log10(value) lies
                above the upper end of interpolator_log_domain.
        :raises RuntimeError: if an exact computation is needed while forbid_exact_computation is set.
        """
        log_value = np.log10(value)
        if self.wrap_interpolator:
            # Try to interpolate the limit from limits computed earlier
            self.log.debug("Trying to get values from interpolators")
            try:
                if self.fixed_lower_limit is None:
                    low_limit = 10**(self.low_limit_interpolator(log_value))
                else:
                    low_limit = self.fixed_lower_limit
                if self.fixed_upper_limit is None:
                    high_limit = 10**(self.high_limit_interpolator(log_value))
                else:
                    high_limit = self.fixed_upper_limit
                return low_limit, high_limit
            except InsufficientPrecisionError:
                self.log.debug("Insuffienct precision achieved by interpolators")
                if log_value > self.interpolator_log_domain[1]:
                    self.log.debug("Too high value to dare to start Neyman construction... raising exception")
                    # It is not safe to do the Neyman construction: too high statistics
                    raise
                self.log.debug("Log value %s is below interpolator log domain max %s "
                               "=> starting Neyman construction" % (log_value, self.interpolator_log_domain[1]))
            except OutsideDomainError:
                # The value is below the interpolator  domain (e.g. 0 while the domain ends at 10**0 = 1)
                pass

        if self.forbid_exact_computation:
            raise RuntimeError("Exact computation triggered")

        def is_value_in(mu):
            # True if the observed value lies in the interval on the statistic for signal rate mu
            low_lim, high_lim = self.get_interval_on_statistic(mu + self.background,
                                                               precision_digits=precision_digits)
            return low_lim <= value <= high_lim

        # We first need one value in the interval to bound the limit searches
        try:
            true_point, low_search_bound, high_search_bound = search_true_instance(is_value_in,
                                                                                   *search_region,
                                                                                   precision_digits=precision_digits)
        except SearchFailedException as e:
            self.log.debug("Exploratory search could not find a single value in the interval! "
                           "This is probably a problem with search region, or simply a very extreme case."
                           "Original exception: %s" % str(e))
            if is_value_in(0):
                self.log.debug("Oh, ok, only zero is in the interval... Returning (0, 0)")
                return 0, 0
            return 0, float('inf')

        self.log.debug(">>> Exploratory search completed: %s is in interval, "
                       "search for boundaries in [%s, %s]" % (true_point, low_search_bound, high_search_bound))

        if self.fixed_lower_limit is not None:
            low_limit = self.fixed_lower_limit
        elif is_value_in(low_search_bound):
            # If mu=0 can't be excluded, we're apparently only setting an upper limit (mu <= ..)
            low_limit = 0
        else:
            low_limit = bisect_search(is_value_in, low_search_bound, true_point, precision_digits=precision_digits)
        self.log.debug(">>> Low limit found at %s" % low_limit)

        if self.fixed_upper_limit is not None:
            # The fixed upper limit belongs in high_limit; it must not overwrite the lower limit
            high_limit = self.fixed_upper_limit
        elif is_value_in(high_search_bound):
            # If max_mu can't be excluded, we're apparently only setting a lower limit (mu >= ..)
            high_limit = float('inf')
        else:
            high_limit = bisect_search(is_value_in, true_point, high_search_bound, precision_digits=precision_digits)
        self.log.debug(">>> High limit found at %s" % high_limit)

        if self.wrap_interpolator:
            # Add the values to the interpolator, if they are within the domain
            # TODO: Think about dealing with inf
            if self.interpolator_log_domain[0] <= log_value <= self.interpolator_log_domain[1]:
                if self.fixed_lower_limit is None:
                    self.low_limit_interpolator.add_point(log_value, np.log10(low_limit))
                if self.fixed_upper_limit is None:
                    self.high_limit_interpolator.add_point(log_value, np.log10(high_limit))

        return low_limit, high_limit

    def score_stat_values(self, **kwargs):
        """Return a score for each statistic value; must be implemented by subclasses.

        Called with keyword arguments statistic_values, likelihoods and hypothesis.
        With the 'rank' method, values with the lowest score are included in the interval first;
        with the 'threshold' method, values scoring above the threshold are included.
        """
        raise NotImplementedError()

    def __call__(self, observation, precision_digits=None, search_region=None):
        """Perform Neyman construction to get confidence interval on event rate for observation.

        :param observation: passed to self.statistic to obtain the statistic value; its length
                            sets the default search region.
        :param precision_digits: defaults to self.precision_digits.
        :param search_region: defaults to [0, round_to_digits(10 + 3 * len(observation), precision_digits)].
        :returns: (low_limit, high_limit), as returned by get_confidence_interval.
        """
        if precision_digits is None:
            precision_digits = self.precision_digits
        if search_region is None:
            search_region = [0, round_to_digits(10 + 3 * len(observation), precision_digits)]
        if self.statistic.mu_dependent:
            value = self.statistic(observation, self.statistic.mus)
        else:
            value = self.statistic(observation, None)
        self.log.debug("Statistic evaluates to %s" % value)
        return self.get_confidence_interval(value, precision_digits=precision_digits, search_region=search_region)