Example 1
    def _sample(self, n, limits: ZfitSpace):

        pdf = self.pdfs[0]
        # TODO: use real limits, currently not supported in binned sample
        sample = pdf.sample(n=n)

        edges = sample.space.binning.edges
        ndim = len(edges)
        edges = [znp.array(edge) for edge in edges]
        edges_flat = [znp.reshape(edge, [-1]) for edge in edges]
        lowers = [edge[:-1] for edge in edges_flat]
        uppers = [edge[1:] for edge in edges_flat]
        lowers_meshed = znp.meshgrid(*lowers, indexing="ij")
        uppers_meshed = znp.meshgrid(*uppers, indexing="ij")
        lowers_meshed_flat = [
            znp.reshape(lower_mesh, [-1]) for lower_mesh in lowers_meshed
        ]
        uppers_meshed_flat = [
            znp.reshape(upper_mesh, [-1]) for upper_mesh in uppers_meshed
        ]
        lower_flat = znp.stack(lowers_meshed_flat, axis=-1)
        upper_flat = znp.stack(uppers_meshed_flat, axis=-1)

        counts_flat = znp.reshape(sample.values(), (-1, ))
        counts_flat = tf.cast(counts_flat,
                              znp.int32)  # TODO: what if we have fractions?
        lower_flat_repeated = tf.repeat(lower_flat, counts_flat, axis=0)
        upper_flat_repeated = tf.repeat(upper_flat, counts_flat, axis=0)
        sample_unbinned = tf.random.uniform(
            (znp.sum(counts_flat), ndim),
            minval=lower_flat_repeated,
            maxval=upper_flat_repeated,
            dtype=self.dtype,
        )
        return sample_unbinned
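
A minimal sketch of the same repeat-and-uniform idea in plain TensorFlow, on a hypothetical 1D histogram (all names and values are illustrative):

import tensorflow as tf

# Hypothetical 1D histogram: 3 bins with counts [2, 0, 3].
edges = tf.constant([0.0, 1.0, 2.0, 4.0])
counts = tf.constant([2, 0, 3])
lo = tf.repeat(edges[:-1], counts)  # lower bin edge, repeated once per event
hi = tf.repeat(edges[1:], counts)   # upper bin edge, repeated once per event
u = tf.random.uniform(tf.shape(lo), dtype=edges.dtype)
sample = lo + (hi - lo) * u         # one uniform draw inside each event's bin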
Example 2
    def _rel_counts(self, x, norm):
        pdf = self.pdfs[0]
        edges = [znp.array(edge) for edge in self.axes.edges]
        edges_flat = [znp.reshape(edge, [-1]) for edge in edges]
        lowers = [edge[:-1] for edge in edges_flat]
        uppers = [edge[1:] for edge in edges_flat]
        lowers_meshed = znp.meshgrid(*lowers, indexing="ij")
        uppers_meshed = znp.meshgrid(*uppers, indexing="ij")
        shape = tf.shape(lowers_meshed[0])
        lowers_meshed_flat = [
            znp.reshape(lower_mesh, [-1]) for lower_mesh in lowers_meshed
        ]
        uppers_meshed_flat = [
            znp.reshape(upper_mesh, [-1]) for upper_mesh in uppers_meshed
        ]
        lower_flat = znp.stack(lowers_meshed_flat, axis=-1)
        upper_flat = znp.stack(uppers_meshed_flat, axis=-1)
        options = {"type": "bins"}

        @z.function
        def integrate_one(limits):
            l, u = tf.unstack(limits)
            limits_space = zfit.Space(obs=self.obs, limits=[l, u])
            return pdf.integrate(limits_space, norm=False, options=options)

        limits = znp.stack([lower_flat, upper_flat], axis=1)
        values = tf.vectorized_map(integrate_one, limits)
        values = znp.reshape(values, shape)
        if norm:
            values /= pdf.normalization(norm)
        return values
Example 3
    def to_unbinned(self):
        meshed_center = znp.meshgrid(*self.axes.centers, indexing="ij")
        flat_centers = [
            znp.reshape(center, (-1, )) for center in meshed_center
        ]
        centers = znp.stack(flat_centers, axis=-1)
        flat_weights = znp.reshape(self.values(), (-1, ))  # TODO: flow?
        space = self.space.copy(binning=None)
        from zfit import Data

        return Data.from_tensor(obs=space,
                                tensor=centers,
                                weights=flat_weights)
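
The meshgrid-flatten-stack pattern above turns per-axis bin centers into a single (n_bins_total, n_dim) array of points; a small NumPy sketch with made-up centers:

import numpy as np

centers_x = np.array([0.5, 1.5])          # 2 bins on axis 0
centers_y = np.array([10.0, 20.0, 30.0])  # 3 bins on axis 1
mesh = np.meshgrid(centers_x, centers_y, indexing="ij")
centers = np.stack([m.reshape(-1) for m in mesh], axis=-1)
# centers.shape == (6, 2): one row per bin, in row-major ("ij") order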
Example 4
    def _counts(self, x, norm):
        pdf = self.pdfs[0]
        edges = [znp.array(edge) for edge in self.axes.edges]
        edges_flat = [znp.reshape(edge, [-1]) for edge in edges]
        lowers = [edge[:-1] for edge in edges_flat]
        uppers = [edge[1:] for edge in edges_flat]
        lowers_meshed = znp.meshgrid(*lowers, indexing="ij")
        uppers_meshed = znp.meshgrid(*uppers, indexing="ij")
        shape = tf.shape(lowers_meshed[0])
        lowers_meshed_flat = [
            znp.reshape(lower_mesh, [-1]) for lower_mesh in lowers_meshed
        ]
        uppers_meshed_flat = [
            znp.reshape(upper_mesh, [-1]) for upper_mesh in uppers_meshed
        ]
        lower_flat = znp.stack(lowers_meshed_flat, axis=-1)
        upper_flat = znp.stack(uppers_meshed_flat, axis=-1)
        options = {"type": "bins"}

        if pdf.is_extended:

            @z.function
            def integrate_one(limits):
                l, u = tf.unstack(limits)
                limits_space = zfit.Space(obs=self.obs, limits=[l, u])
                return pdf.ext_integrate(limits_space,
                                         norm=False,
                                         options=options)

            missing_yield = False
        else:

            @z.function
            def integrate_one(limits):
                l, u = tf.unstack(limits)
                limits_space = zfit.Space(obs=self.obs, limits=[l, u])
                return pdf.integrate(limits_space, norm=False, options=options)

            missing_yield = True

        limits = znp.stack([lower_flat, upper_flat], axis=1)
        try:
            values = tf.vectorized_map(integrate_one, limits)[:, 0]
        except ValueError:
            values = tf.map_fn(integrate_one, limits)
        values = znp.reshape(values, shape)
        if missing_yield:
            values *= self.get_yield()
        if norm:
            values /= pdf.normalization(norm)
        return values
Example 5
def spline_interpolator(alpha, alphas, densities):
    alphas = alphas[None, :, None]
    shape = tf.shape(densities[0])
    densities_flat = [znp.reshape(density, [-1]) for density in densities]
    densities_flat = znp.stack(densities_flat, axis=0)
    alpha_shaped = znp.reshape(alpha, [1, -1, 1])
    y_flat = tfa.image.interpolate_spline(
        train_points=alphas,
        train_values=densities_flat[None, ...],
        query_points=alpha_shaped,
        order=2,
    )
    y_flat = y_flat[0, 0]
    y = tf.reshape(y_flat, shape)
    return y
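
For reference, a self-contained sketch of the tfa.image.interpolate_spline call with the same batched shapes (hypothetical numbers; requires tensorflow_addons):

import tensorflow as tf
import tensorflow_addons as tfa

alphas = tf.constant([-1.0, 0.0, 1.0])[None, :, None]  # [1, n_alphas, 1]
densities_flat = tf.random.uniform((3, 4))             # n_alphas templates of 4 bins
y = tfa.image.interpolate_spline(
    train_points=alphas,
    train_values=densities_flat[None, ...],            # [1, n_alphas, n_bins]
    query_points=tf.constant([[[0.3]]]),               # [1, n_queries, 1]
    order=2,
)  # -> [1, n_queries, n_bins]: the interpolated template at alpha = 0.3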
Example 6
def func_integral_hermite(limits, norm, params, model):
    lower, upper = limits.limit1d
    lower_rescaled = model._polynomials_rescale(lower)
    upper_rescaled = model._polynomials_rescale(upper)

    lower = z.convert_to_tensor(lower_rescaled)
    upper = z.convert_to_tensor(upper_rescaled)

    # the integral of H_n is H_{n+1} / (2 * (n + 1)); shift each coefficient to n+1 accordingly.
    coeffs = {"c_0": z.constant(0.0, dtype=model.dtype)}

    for name, coeff in params.items():
        ip1_coeff = int(name.split("_", 1)[-1]) + 1
        coeffs[f"c_{ip1_coeff}"] = coeff / z.convert_to_tensor(
            ip1_coeff * 2.0, dtype=model.dtype)
    coeffs = convert_coeffs_dict_to_list(coeffs)

    def indefinite_integral(limits):
        return hermite_shape(x=limits, coeffs=coeffs)

    integral = indefinite_integral(upper) - indefinite_integral(lower)
    integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width

    return integral
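
The coefficient shift implements the antiderivative of the (physicists') Hermite polynomials,

\int H_n(x)\,dx = \frac{H_{n+1}(x)}{2(n + 1)} + C,

so a coefficient c_n on H_n becomes c_n / (2(n + 1)) on H_{n+1}, matching the division by ip1_coeff * 2.0 above.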
Example 7
 def _check_init_values(self, space, values, variances):
     value_shape = tf.shape(values)
     edges_shape = znp.array([
         tf.shape(znp.reshape(edge, (-1, )))[0]
         for edge in space.binning.edges
     ])
     values_rank = value_shape.shape[0]
     if variances is not None:
         variances_shape = tf.shape(variances)
         variances_rank = variances_shape.shape[0]
         if values_rank != variances_rank:
             raise ShapeIncompatibleError(
                 f"Values {values} and variances {variances} differ in rank: {values_rank} vs {variances_rank}"
             )
         tf.assert_equal(
             variances_shape,
             value_shape,
             message=f"Variances and values do not have the same shape:"
             f" {variances_shape} vs {value_shape}",
         )
     binning_rank = len(space.binning.edges)
     if binning_rank != values_rank:
         raise ShapeIncompatibleError(
             f"Values and binning  differ in rank: {values_rank} vs {binning_rank}"
         )
     tf.assert_equal(
         edges_shape - 1,
         value_shape,
         message=f"Edges (minus one) and values do not have the same shape:"
         f" {edges_shape} vs {value_shape}",
     )
Example 8
def func_integral_chebyshev2(limits, norm, params, model):
    lower, upper = limits.limit1d
    lower_rescaled = model._polynomials_rescale(lower)
    upper_rescaled = model._polynomials_rescale(upper)

    lower = z.convert_to_tensor(lower_rescaled)
    upper = z.convert_to_tensor(upper_rescaled)

    # the integral of U_n (Chebyshev 2nd kind) is T_{n+1} / (n + 1) (1st kind). We divide each coefficient
    # by (n + 1) and shift it to n+1; the cheby1 shape makes the sum for us.
    coeffs_cheby1 = {"c_0": z.constant(0.0, dtype=model.dtype)}

    for name, coeff in params.items():
        n_plus1 = int(name.split("_", 1)[-1]) + 1
        coeffs_cheby1[f"c_{n_plus1}"] = coeff / z.convert_to_tensor(
            n_plus1, dtype=model.dtype)
    coeffs_cheby1 = convert_coeffs_dict_to_list(coeffs_cheby1)

    def indefinite_integral(limits):
        return chebyshev_shape(x=limits, coeffs=coeffs_cheby1)

    integral = indefinite_integral(upper) - indefinite_integral(lower)
    integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width

    return integral
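
The shift uses the antiderivative of the Chebyshev polynomials of the second kind in terms of the first kind,

\int U_n(x)\,dx = \frac{T_{n+1}(x)}{n + 1} + C,

so each coefficient is divided by (n + 1) and moved from U_n to T_{n+1}; chebyshev_shape then sums the series.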
Example 9
 def _counts_with_modifiers(self, x, norm):
     values = self.pdfs[0].counts(x, norm=norm)
     modifiers = list(self._binwise_modifiers.values())
     if modifiers:
         sysshape_flat = tf.stack(modifiers)
         modifiers = znp.reshape(sysshape_flat, values.shape)
         values = values * modifiers
     return values
Example 10
 def sumfunc(params):
     values = self.pdfs[0].counts(obs)
     sysshape = list(params.values())
     if sysshape:
         sysshape_flat = tf.stack(sysshape)
         sysshape = znp.reshape(sysshape_flat, values.shape)
         values = values * sysshape
     return znp.sum(values)
Example 11
 def _ext_pdf(self, x, norm):
     if not self._automatically_extended:
         raise SpecificFunctionNotImplemented
     pdf = self.pdfs[0]
     density = pdf.ext_pdf(x.space, norm=norm)
     density_flat = znp.reshape(density, (-1, ))
     centers_list = znp.meshgrid(*pdf.space.binning.centers, indexing="ij")
     centers_list_flat = [
         znp.reshape(cent, (-1, )) for cent in centers_list
     ]
     centers = znp.stack(centers_list_flat, axis=-1)
     # [None, :, None]  # TODO: only 1 dim now
     probs = tfa.image.interpolate_spline(
         train_points=centers[None, ...],
         train_values=density_flat[None, :, None],
         query_points=x.value()[None, ...],
         order=self.order,
     )
     return probs[0, ..., 0]
Example 12
def unbinned_to_binindex(data, space, flow=False):
    if flow:
        warnings.warn(
            "Flow currently not fully supported. Values outside the edges are all 0."
        )
    values = [znp.reshape(data.value(ob), (-1, )) for ob in space.obs]
    edges = [znp.reshape(edge, (-1, )) for edge in space.binning.edges]
    bins = [
        tfp.stats.find_bins(x=val, edges=edge)
        for val, edge in zip(values, edges)
    ]
    stacked_bins = znp.stack(bins, axis=-1)
    if flow:
        stacked_bins += 1
        bin_is_nan = tf.math.is_nan(stacked_bins)
        zeros = znp.zeros_like(stacked_bins)
        binindices = znp.where(bin_is_nan, zeros, stacked_bins)
        stacked_bins = znp.asarray(binindices, dtype=znp.int32)
    return stacked_bins
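
The NaN handling relies on tfp.stats.find_bins returning NaN for values outside the edges; a minimal illustration with made-up numbers:

import tensorflow as tf
import tensorflow_probability as tfp

edges = tf.constant([0.0, 1.0, 2.0])
x = tf.constant([0.5, 1.5, -3.0])
bins = tfp.stats.find_bins(x=x, edges=edges)
# -> [0., 1., nan]; the nan marks the out-of-range value, which the flow branch
# above maps to bin 0 after shifting all indices by one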
Example 13
def legendre_integral(
    limits: ztyping.SpaceType,
    norm: ztyping.SpaceType,
    params: list[zfit.Parameter],
    model: RecursivePolynomial,
):
    """Recursive integral of Legendre polynomials."""
    lower, upper = limits.limit1d
    lower_rescaled = model._polynomials_rescale(lower)
    upper_rescaled = model._polynomials_rescale(upper)
    # if np.allclose((lower_rescaled, upper_rescaled), (-1, 1)):
    #     return z.constant(2.)  #

    lower = z.convert_to_tensor(lower_rescaled)
    upper = z.convert_to_tensor(upper_rescaled)

    integral_0 = model.params["c_0"] * (upper - lower)  # if polynomial 0 is 1
    if model.degree == 0:
        integral = integral_0
    else:

        def indefinite_integral(limits):
            max_degree = (
                model.degree + 1
            )  # needed +1 for integral, max poly in term for n is n+1
            polys = do_recurrence(
                x=limits,
                polys=legendre_polys,
                degree=max_degree,
                recurrence=legendre_recurrence,
            )
            one_limit_integrals = []
            for degree in range(1, max_degree):
                coeff = model.params[f"c_{degree}"]
                one_limit_integrals.append(
                    coeff * (polys[degree + 1] - polys[degree - 1]) /
                    (2.0 * (z.convert_to_tensor(degree)) + 1))
            return z.reduce_sum(one_limit_integrals, axis=0)

        integral = indefinite_integral(upper) - indefinite_integral(
            lower) + integral_0
        integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width

    return integral
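
The term inside the loop is the standard Legendre antiderivative,

\int P_n(x)\,dx = \frac{P_{n+1}(x) - P_{n-1}(x)}{2n + 1} + C,

which is why each coefficient multiplies the difference of the neighbouring polynomials divided by 2n + 1.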
Example 14
def test_unbinned_data2D():
    n = 751
    gauss, gauss_binned, obs, obs_binned = create_gauss2d_binned(n, 50)

    data = znp.random.uniform([-5, 50], [10, 600], size=(1000, 2))
    y_binned = gauss_binned.pdf(data)
    y_true = gauss.pdf(data)
    max_error = np.max(y_true) / 10
    np.testing.assert_allclose(y_true, y_binned, atol=max_error)

    centers = obs_binned.binning.centers
    X, Y = znp.meshgrid(*centers, indexing="ij")
    centers = znp.stack([znp.reshape(t, (-1,)) for t in (X, Y)], axis=-1)
    ycenter_binned = gauss_binned.pdf(centers)
    ycenter_true = gauss.pdf(centers)
    np.testing.assert_allclose(ycenter_binned, ycenter_true, atol=max_error / 10)

    # for the extended case
    y_binned_ext = gauss_binned.ext_pdf(data)
    y_true_ext = gauss.ext_pdf(data)
    max_error_ext = np.max(y_true_ext) / 10
    np.testing.assert_allclose(y_true_ext, y_binned_ext, atol=max_error_ext)

    ycenter_binned_ext = gauss_binned.ext_pdf(centers)
    ycenter_true_ext = gauss.ext_pdf(centers)
    np.testing.assert_allclose(
        ycenter_binned_ext, ycenter_true_ext, atol=max_error_ext / 10
    )

    x_outside = znp.array([[-7.0, 55], [3.0, 13], [2, 150], [12, 30], [14, 1000]])
    y_outside = gauss_binned.pdf(x_outside)
    assert y_outside[0] == 0
    assert y_outside[1] == 0
    assert y_outside[2] > 0
    assert y_outside[3] == 0
    assert y_outside[4] == 0

    y_outside_ext = gauss_binned.ext_pdf(x_outside)
    assert y_outside_ext[0] == 0
    assert y_outside_ext[1] == 0
    assert y_outside_ext[2] > 0
    assert y_outside_ext[3] == 0
    assert y_outside_ext[4] == 0
Example 15
        def create_covariance(mu, sigma):
            mu = z.convert_to_tensor(mu)
            sigma = z.convert_to_tensor(sigma)  # TODO (Mayou36): fix as above?
            params_tensor = z.convert_to_tensor(params)

            if sigma.shape.ndims > 1:
                covariance = sigma  # TODO: square as well?
            elif sigma.shape.ndims == 1:
                covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
            else:
                sigma = znp.reshape(sigma, [1])
                covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))

            if (not params_tensor.shape[0] == mu.shape[0] ==
                    covariance.shape[0] == covariance.shape[1]):
                raise ShapeIncompatibleError(
                    f"params_tensor, observation and uncertainty have to have the"
                    " same length. Currently"
                    f"param: {params_tensor.shape[0]}, mu: {mu.shape[0]}, "
                    f"covariance (from uncertainty): {covariance.shape[0:2]}")
            return covariance
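
A minimal sketch of the rank-1 branch in plain TensorFlow, with made-up uncertainties:

import tensorflow as tf

sigma = tf.constant([0.1, 0.2])  # one uncertainty per parameter
covariance = tf.linalg.tensor_diag(tf.square(sigma))
# -> [[0.01, 0.00],
#     [0.00, 0.04]]: uncorrelated parameters, variances on the diagonal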
Example 16
def func_integral_chebyshev1(limits, norm, params, model):
    lower, upper = limits.rect_limits
    lower_rescaled = model._polynomials_rescale(lower)
    upper_rescaled = model._polynomials_rescale(upper)

    lower = z.convert_to_tensor(lower_rescaled)
    upper = z.convert_to_tensor(upper_rescaled)

    integral = model.params["c_0"] * (
        upper - lower)  # if polynomial 0 is defined as T_0 = 1
    if model.degree >= 1:
        integral += (model.params["c_1"] * 0.5 * (upper**2 - lower**2)
                     )  # if polynomial 1 is defined as T_1 = x
    if model.degree >= 2:

        def indefinite_integral(limits):
            max_degree = model.degree + 1
            polys = do_recurrence(
                x=limits,
                polys=chebyshev_polys,
                degree=max_degree,
                recurrence=chebyshev_recurrence,
            )
            one_limit_integrals = []
            for degree in range(2, max_degree):
                coeff = model.params[f"c_{degree}"]
                n_float = z.convert_to_tensor(degree)
                integral = n_float * polys[degree + 1] / (z.square(
                    n_float) - 1) - limits * polys[degree] / (n_float - 1)
                one_limit_integrals.append(coeff * integral)
            return z.reduce_sum(one_limit_integrals, axis=0)

        integral += indefinite_integral(upper) - indefinite_integral(lower)
        integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    integral = tf.gather(integral, indices=0, axis=-1)
    return integral
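
For n >= 2 the inner loop uses the Chebyshev (first kind) antiderivative,

\int T_n(x)\,dx = \frac{n\,T_{n+1}(x)}{n^2 - 1} - \frac{x\,T_n(x)}{n - 1} + C,

which is exactly the expression built from polys[degree + 1] and polys[degree] above.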
Example 17
def func_integral_laguerre(limits, norm, params: dict, model):
    """The integral of the simple laguerre polynomials.

    Defined as :math:`\\int L_{n} = (-1) L_{n+1}^{(-1)}` with :math:`L^{(\\alpha)}` the generalized Laguerre polynomial.

    Args:
        limits:
        norm:
        params:
        model:

    Returns:
    """
    lower, upper = limits.limit1d
    lower_rescaled = model._polynomials_rescale(lower)
    upper_rescaled = model._polynomials_rescale(upper)

    lower = z.convert_to_tensor(lower_rescaled)
    upper = z.convert_to_tensor(upper_rescaled)

    # The Laguerre shape makes the sum for us. Set the 0th coeff to 0, since no n = -1 term exists.
    coeffs_laguerre_nup = {
        f'c_{int(n.split("_", 1)[-1]) + 1}': c
        for n, c in params.items()
    }  # increase n -> n+1 of naming
    coeffs_laguerre_nup["c_0"] = tf.constant(0.0, dtype=model.dtype)
    coeffs_laguerre_nup = convert_coeffs_dict_to_list(coeffs_laguerre_nup)

    def indefinite_integral(limits):
        return -1 * laguerre_shape_alpha_minusone(x=limits,
                                                  coeffs=coeffs_laguerre_nup)

    integral = indefinite_integral(upper) - indefinite_integral(lower)
    integral = znp.reshape(integral, newshape=())
    integral *= 0.5 * model.space.area()  # rescale back to whole width
    return integral
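
The docstring formula follows from the derivative identity of the generalized Laguerre polynomials,

\frac{d}{dx} L_{n+1}^{(-1)}(x) = -L_n^{(0)}(x) = -L_n(x),

so the antiderivative of L_n is -L_{n+1}^{(-1)}, implemented by shifting the coefficients to n + 1 and negating the shape.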
Example 18
    def _unnormalized_pdf(self, x):

        lower_func, upper_func = self._conv_limits["func"]
        nbins_func = self._conv_limits["nbins_func"]
        x_funcs = tf.linspace(lower_func, upper_func,
                              tf.cast(nbins_func, tf.int32))

        lower_kernel, upper_kernel = self._conv_limits["kernel"]
        nbins_kernel = self._conv_limits["nbins_kernel"]
        x_kernels = tf.linspace(lower_kernel, upper_kernel,
                                tf.cast(nbins_kernel, tf.int32))

        x_func = tf.meshgrid(*tf.unstack(x_funcs, axis=-1), indexing="ij")
        x_func = znp.transpose(x_func)
        x_func_flatish = znp.reshape(x_func, (-1, self.n_obs))
        data_func = Data.from_tensor(tensor=x_func_flatish, obs=self.obs)

        x_kernel = tf.meshgrid(*tf.unstack(x_kernels, axis=-1), indexing="ij")
        x_kernel = znp.transpose(x_kernel)
        x_kernel_flatish = znp.reshape(x_kernel, (-1, self.n_obs))
        data_kernel = Data.from_tensor(tensor=x_kernel_flatish, obs=self.obs)

        y_func = self.pdfs[0].pdf(data_func, norm=False)
        y_kernel = self.pdfs[1].pdf(data_kernel, norm=False)

        func_dims = [nbins_func] * self.n_obs
        kernel_dims = [nbins_kernel] * self.n_obs

        y_func = znp.reshape(y_func, func_dims)
        y_kernel = znp.reshape(y_kernel, kernel_dims)

        # Flip the kernel to use the cross-correlation (called `convolution`) from TF:
        # convolution = cross-correlation with a flipped kernel.
        # This introduces a shift that has to be corrected when interpolating x_func,
        # because the convolution has to be independent of the kernel's **limits**.
        # We assume the limits are symmetric when doing the FFT, so shift them back.
        y_kernel = tf.reverse(y_kernel, axis=list(range(self.n_obs)))
        kernel_shift = (upper_kernel + lower_kernel) / 2
        x_func += kernel_shift
        lower_func += kernel_shift
        upper_func += kernel_shift

        # make rectangular grid
        y_func_rect = znp.reshape(y_func, func_dims)
        y_kernel_rect = znp.reshape(y_kernel, kernel_dims)

        # needed for multi dims?
        # if self.n_obs == 2:
        #     y_kernel_rect = tf.linalg.adjoint(y_kernel_rect)

        # get correct shape for tf.nn.convolution
        y_func_rect_conv = znp.reshape(y_func_rect, (1, *func_dims, 1))
        y_kernel_rect_conv = znp.reshape(y_kernel_rect, (*kernel_dims, 1, 1))

        conv = tf.nn.convolution(
            input=y_func_rect_conv,
            filters=y_kernel_rect_conv,
            strides=1,
            padding="SAME",
        )

        # needed for multidims?
        # if self.n_obs == 2:
        #     conv = tf.linalg.adjoint(conv[0, ..., 0])[None, ..., None]
        # conv = scipy.signal.convolve(
        #     y_func_rect,
        #     y_kernel_rect,
        #     mode='same'
        # )[None, ..., None]
        train_points = znp.expand_dims(x_func, axis=0)
        query_points = znp.expand_dims(x.value(), axis=0)
        if self.conv_interpolation == "spline":
            conv_points = znp.reshape(conv, (1, -1, 1))
            prob = tfa.image.interpolate_spline(
                train_points=train_points,
                train_values=conv_points,
                query_points=query_points,
                order=self._conv_spline_order,
            )
            prob = prob[0, ..., 0]
        elif self.conv_interpolation == "linear":
            prob = tfp.math.batch_interp_regular_nd_grid(
                x=query_points[0],
                x_ref_min=lower_func,
                x_ref_max=upper_func,
                y_ref=conv[0, ..., 0],
                # y_ref=tf.reverse(conv[0, ..., 0], axis=[0]),
                axis=-self.n_obs,
            )
            prob = prob[0]

        return prob
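
A self-contained sketch of the tf.nn.convolution shape convention used above, for a hypothetical 1D case:

import tensorflow as tf

y_func = tf.random.uniform((1, 100, 1))   # (batch, nbins_func, channels)
y_kernel = tf.random.uniform((31, 1, 1))  # (kernel width, in channels, out channels)
conv = tf.nn.convolution(input=y_func, filters=y_kernel,
                         strides=1, padding="SAME")
# conv.shape == (1, 100, 1): the output stays on the y_func grid thanks to "SAME"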
Example 19
def cut_edges_and_bins(
    edges: Iterable[znp.array], limits: ZfitSpace, axis=None, unscaled=None
) -> tuple[list[znp.array], tuple[znp.array, znp.array], list | None]:
    """Cut the *edges* according to *limits* and calculate the bins inside.

    The edges within limits are calculated and returned together with the corresponding bin indices. The indices
    mark the lowest and the highest index of the edges that are returned. Additionally, the unscaled edges are returned.

    If a limit falls between two edges, the limit itself is treated as the new edge. If the limits are outside the
    edges, all edges in this direction will be returned (but not extended to the limit). For example:

    [0, 0.5, 1., 1.5, 2.] and the limits (0.8, 3.) will return [0.8, 1., 1.5, 2.], ([1], [4])

    .. code-block::

        cut_edges_and_bins([[0., 0.5, 1., 1.5, 2.]], ([[0.8]], [[3]]))



    Args:
        edges: Iterable of tensor-like objects that describe the edges of a histogram. Every object should have rank n
            (where n is the length of *edges*) but only have the dimension i filled out. These are
            tensors that are ready to be broadcasted together.
        limits: The limits that will be used to confine the edges


    Returns:
        edges, (lower bins, upper bins), unscaled_edges: The edges and the bins are returned.
            The upper bin number corresponds to
            the highest bin that was still (partially) inside the limits **plus one** (so it is the index of the
            edge just outside). The unscaled edges are like *edges*, except that the last edge is the original
            edge lying just outside the limits, i.e. the actual edge of the last returned bin number.
            This can be used to determine the fraction cut away.
    """
    if axis is not None:
        axis = convert_to_container(axis)
    if unscaled is None:
        unscaled = False
    if unscaled:
        cut_unscaled_edges = []
    else:
        cut_unscaled_edges = None
    cut_scaled_edges = []

    all_lower_bins = []
    all_upper_bins = []
    if isinstance(limits, ZfitSpace):
        lower, upper = limits.limits
    else:
        lower, upper = limits
        lower = znp.asarray(lower)
        upper = znp.asarray(upper)
    lower_all = lower[0]
    upper_all = upper[0]
    rank = len(edges)
    current_axis = 0
    for i, edge in enumerate(edges):
        edge = znp.asarray(edge)
        edge = znp.reshape(edge, (-1,))
        if axis is None or i in axis:

            lower_i = lower_all[current_axis, None]
            edge_minimum = edge[0]
            # edge_minimum = tf.gather(edge, indices=0, axis=i)
            lower_i = znp.maximum(lower_i, edge_minimum)
            upper_i = upper_all[current_axis, None]
            edge_maximum = edge[-1]
            # edge_maximum = tf.gather(edge, indices=tf.shape(edge)[i] - 1, axis=i)
            upper_i = znp.minimum(upper_i, edge_maximum)
            # We first get the bins that are just one too far, then update the edge tensor with the actual limits.
            # The bin index is the index of the edge below the value.
            lower_bin_float = tfp.stats.find_bins(
                lower_i, edge, extend_lower_interval=True, extend_upper_interval=True
            )
            lower_bin = tf.reshape(tf.cast(lower_bin_float, dtype=znp.int32), [-1])
            # lower_bins = tf.tensor_scatter_nd_update(zero_bins, [[i]], lower_bin)
            # +1 below because the outer bin is searched, meaning the one that is higher than the value

            upper_bin_float = tfp.stats.find_bins(
                upper_i, edge, extend_lower_interval=True, extend_upper_interval=True
            )
            upper_bin = tf.reshape(tf.cast(upper_bin_float, dtype=znp.int32), [-1]) + 1
            size = upper_bin - lower_bin
            new_edge = tf.slice(
                edge, lower_bin, size + 1
            )  # +1 because stop is exclusive
            new_edge = tf.tensor_scatter_nd_update(
                new_edge, [tf.constant([0]), size], [lower_i[0], upper_i[0]]
            )

            if unscaled:
                new_edge_unscaled = tf.slice(
                    edge, lower_bin, size + 1
                )  # +1 because stop is exclusive

            current_axis += 1
        else:
            lower_bin = [0]
            upper_bin = znp.asarray([edge.shape[0] - 1], dtype=znp.int32)
            new_edge = edge
            if unscaled:
                new_edge_unscaled = edge
        new_shape = [1] * rank
        new_shape[i] = -1
        new_edge = znp.reshape(new_edge, new_shape)
        all_lower_bins.append(lower_bin)
        all_upper_bins.append(upper_bin)
        cut_scaled_edges.append(new_edge)
        if unscaled:
            new_edge_unscaled = znp.reshape(new_edge_unscaled, new_shape)
            cut_unscaled_edges.append(new_edge_unscaled)

    # partial = axis is not None and len(axis) < rank
    #
    # if partial:
    #     scaled_edges_full = list(edges)
    #     for edge, ax in zip(cut_scaled_edges, axis):
    #         scaled_edges_full[ax] = edge
    #     scaled_edges = scaled_edges_full
    #     indices = tf.convert_to_tensor(axis)[:, None]
    #     lower_bins = tf.scatter_nd(indices, lower_bins, shape=(ndims,))
    #     upper_bins = tf.tensor_scatter_nd_update(tf.convert_to_tensor(values.shape),
    #                                              indices, upper_bins)
    # lower_bins_indices = tf.stack([lower_bins, dims], axis=-1)
    # upper_bins_indices = tf.stack([upper_bins, dims], axis=-1)
    # all_lower_bins = tf.cast(znp.sum(all_lower_bins, axis=0), dtype=znp.int32)
    all_lower_bins = tf.concat(all_lower_bins, axis=0)
    all_upper_bins = tf.concat(all_upper_bins, axis=0)
    return cut_scaled_edges, (all_lower_bins, all_upper_bins), cut_unscaled_edges
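
A usage sketch matching the docstring example above (return values as documented there):

edges, (lower_bins, upper_bins), _ = cut_edges_and_bins(
    edges=[[0.0, 0.5, 1.0, 1.5, 2.0]], limits=([[0.8]], [[3.0]])
)
# edges[0]   -> [0.8, 1., 1.5, 2.]  (first edge clipped to the lower limit)
# lower_bins -> [1], upper_bins -> [4]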