Esempio n. 1
0
def _remove_outliers_from_hist(
        hist: Hist, outliers_start_index: int,
        outliers_removal_axis: OutliersRemovalAxis) -> None:
    """Remove outliers from a given histogram.

    Every cell of the histogram (iterated over all global bins, including under- and
    overflow cells) whose bin index along ``outliers_removal_axis`` is greater than or
    equal to ``outliers_start_index`` has its content and error set to 0. Cells whose
    content is smaller than their error are reported with a warning.

    Args:
        hist: Histogram to check for outliers (TH1, TH2, or TH3).
        outliers_start_index: Index in the truth axis where outliers begin. If it is not
            positive, nothing is removed.
        outliers_removal_axis: Axis along which outliers removal will be performed. Usually
            the particle level axis.
    Returns:
        None. The histogram is modified in place.
    Raises:
        KeyError: If ``outliers_removal_axis`` does not match one of the
            ``projectors.TH1AxisType`` x/y/z members.
    """
    # Use on TH1, TH2, and TH3 since we don't start removing immediately, but instead only after the limit
    if outliers_start_index <= 0:
        logger.info(f"Hist {hist.GetName()} did not have any outliers to cut")
        return

    # GetBinXYZ writes the per-axis bin indices of a global bin into these ints.
    x = ctypes.c_int(0)
    y = ctypes.c_int(0)
    z = ctypes.c_int(0)
    # Map each axis to the ctypes int that receives its bin index, and select the one for the
    # removal axis once: the choice doesn't change inside the loop below.
    axis_to_bin_index: Dict[OutliersRemovalAxis, ctypes.c_int] = {
        projectors.TH1AxisType.x_axis: x,
        projectors.TH1AxisType.y_axis: y,
        projectors.TH1AxisType.z_axis: z,
    }
    removal_axis_bin_index = axis_to_bin_index[outliers_removal_axis]

    for index in range(hist.GetNcells()):
        # Translate the global bin into per-axis bin indices (written into x, y, z).
        hist.GetBinXYZ(index, x, y, z)
        bin_content = hist.GetBinContent(index)
        bin_error = hist.GetBinError(index)
        # Watch out for any problems
        if bin_content < bin_error:
            logger.warning(
                f"Bin content < error. Name: {hist.GetName()}, Bin content: {bin_content}, Bin error: {bin_error}, index: {index}, ({x.value}, {y.value})"
            )
        # Cut every cell at or beyond the start index along the removal axis.
        if removal_axis_bin_index.value >= outliers_start_index:
            hist.SetBinContent(index, 0)
            hist.SetBinError(index, 0)
Esempio n. 2
0
    def _from_th1(
        hist: Hist
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]:
        """ Extract the bin edges, values, squared errors, and metadata from a TH1.

        Note:
            Underflow and overflow bins are excluded!

        Note:
            If ``hist`` does not already store sumw2 (``GetSumw2N() == 0``), ``hist.Sumw2(True)``
            is called, so the input histogram is modified.

        Args:
            hist (ROOT.TH1): Input histogram.
        Returns:
            tuple: (bin_edges, y, errors_squared, metadata).
                ``bin_edges`` is the value of ``get_bin_edges_from_axis`` on the x axis.
                ``y`` holds the bin contents.
                ``errors_squared`` holds the squared bin errors: the sumw2 values, or the squared
                ``GetBinError`` values for a TProfile.
                ``metadata`` is the stats dict built from ``GetStats``; it is empty for a TProfile.
        Raises:
            ValueError: If the sumw2 value of the first bin doesn't match its squared bin error.
        """
        # Enable sumw2 if it's not already calculated
        if hist.GetSumw2N() == 0:
            hist.Sumw2(True)

        # Regular bins are 1..n_bins; bin 0 is underflow and n_bins + 1 is overflow.
        n_bins = hist.GetXaxis().GetNbins()
        regular_bins = range(1, n_bins + 1)

        # Don't include overflow
        bin_edges = get_bin_edges_from_axis(hist.GetXaxis())
        # NOTE: The y value and bin error are stored with the hist, not the axis.
        y = np.array([hist.GetBinContent(i) for i in regular_bins])
        metadata: Dict[str, Any] = {}

        # Check for a TProfile (identified by having a ``BuildOptions`` method).
        # In that case we need to retrieve the errors manually because the Sumw2() errors are
        # not the anticipated errors.
        if hasattr(hist, "BuildOptions"):
            # We expect errors squared
            errors = np.array([hist.GetBinError(i) for i in regular_bins]) ** 2
        else:
            # The sumw2 array includes the under/overflow entries, so drop the first and last.
            errors = np.array(hist.GetSumw2())[1:-1]

            # Sanity check. If they don't match, something odd has almost certainly occurred.
            if not np.isclose(errors[0], hist.GetBinError(1)**2):
                raise ValueError(
                    "Sumw2 errors don't seem to represent bin errors!")

            # Retrieve the stats and store them in the metadata.
            # They are useful for calculating histogram properties (mean, variance, etc).
            stats = np.array([0, 0, 0, 0], dtype=np.float64)
            # as_ctypes views the same memory, so GetStats fills ``stats`` in place.
            hist.GetStats(np.ctypeslib.as_ctypes(stats))
            # Return values are (each one is a single float):
            # [1], [2], [3], [4]
            # [1]: total_sum_w: Sum of weights (equal to np.sum(y) if unscaled)
            # [2]: total_sum_w2: Sum of weights squared (equal to np.sum(errors_squared) if unscaled)
            # [3]: total_sum_wx: Sum of w*x
            # [4]: total_sum_wx2: Sum of w*x*x
            metadata.update(_create_stats_dict_from_values(*stats))

        return (bin_edges, y, errors, metadata)