Ejemplo n.º 1
0
def test_posterior_maximum(xp):
    """
    Test calculation of the posterior maximum from predicted quantiles.

    The same set of predicted quantiles is used for 1D, 2D and 3D inputs;
    in every case the expected posterior maximum is 4.45.
    """

    def make_quantiles():
        # Quantile fractions 0.1, 0.2, ..., 0.9.
        return arange(xp, 0.1, 0.91, 0.1)

    def make_predictions():
        # Predicted quantiles that are densely packed around 4.4 - 4.5.
        parts = [
            arange(xp, 2.0, 4.1, 1.0),
            arange(xp, 4.4, 4.51, 0.1),
            arange(xp, 5.0, 8.1, 1.0),
        ]
        return concatenate(xp, parts, 0)

    # 1D predictions
    quantiles = make_quantiles()
    predictions = make_predictions()
    assert np.isclose(posterior_maximum(predictions, quantiles), 4.45)

    # 2D predictions (quantiles along axis 1)
    quantiles = make_quantiles()
    predictions_2d = eo.repeat(predictions, 'q -> h q', h=10)
    result = posterior_maximum(predictions_2d, quantiles)
    assert np.isclose(result[0], 4.45)

    # 3D predictions with quantiles along axis 1
    quantiles = make_quantiles()
    predictions_3d = eo.repeat(make_predictions(), 'q -> h q w', h=10, w=10)
    result = posterior_maximum(predictions_3d, quantiles)
    assert np.isclose(result[0, 0], 4.45)

    # 3D predictions with quantiles along the last axis
    quantiles = make_quantiles()
    predictions_3d = eo.repeat(make_predictions(), 'q -> h w q', h=10, w=10)
    result = posterior_maximum(predictions_3d, quantiles, quantile_axis=-1)
    assert np.isclose(result[0, 0], 4.45)
Ejemplo n.º 2
0
def posterior_quantiles(y_pred, bins, quantiles, bin_axis=1):
    """
    Calculate posterior quantiles from binned predictions without
    interpolation.

    For each requested fraction ``q``, the result is ``bins[0]`` plus the
    sum of the bin widths at all bin boundaries where the posterior CDF
    (from ``posterior_cdf``) is ``<= q``.

    Args:
        y_pred: Tensor containing the binned predictions along ``bin_axis``.
        bins: The bin boundaries corresponding to the predictions.
        quantiles: Iterable of quantile fractions to compute.
        bin_axis: The axis of ``y_pred`` containing the bins. Forced to 0
            for rank-1 input.

    Returns:
        Tensor with one entry per element of ``quantiles`` along
        ``bin_axis``.
    """
    if len(y_pred.shape) == 1:
        bin_axis = 0
    _check_dimensions(y_pred.shape[bin_axis], len(bins))
    xp = get_array_module(y_pred)

    y_cdf = posterior_cdf(y_pred, bins, bin_axis=bin_axis)

    # Bin widths with a leading zero so that there is one width per bin
    # boundary, reshaped so that they broadcast along ``bin_axis``.
    widths = pad_zeros_left(xp, bins[1:] - bins[:-1], 1, 0)
    broadcast_shape = [1] * len(y_pred.shape)
    broadcast_shape[bin_axis] = numel(bins)
    widths = reshape(xp, widths, broadcast_shape)

    def _quantile(q):
        below = as_type(xp, y_cdf <= q, y_cdf)
        value = bins[0] + xp.sum(below * widths, bin_axis)
        return expand_dims(xp, value, bin_axis)

    return concatenate(xp, [_quantile(q) for q in quantiles], bin_axis)
Ejemplo n.º 3
0
def test_add(xp):
    """
    Ensure that the distribution of the sum of two random variables, given
    as binned PDFs, is computed correctly by ``add``.
    """

    #
    # 1D predictions
    #

    bins = arange(xp, 0.0, 10.001, 1.0)
    bins_out = arange(xp, 0.5, 10.001, 1.0)

    # Density 0.5 on the first two bins, zero elsewhere.
    y_pred = [
        0.5 * xp.ones(2),
        xp.zeros(8)
    ]
    y_pred = concatenate(xp, y_pred, 0)
    y_s = add(y_pred, bins, y_pred, bins, bins_out)

    assert np.isclose(y_s[0], 0.25)
    assert np.isclose(y_s[1], 0.5)
    assert np.isclose(y_s[2], 0.25)

    #
    # 2D predictions
    #

    y_pred = eo.repeat(y_pred, 'q -> h q', h=10)

    y_s = add(y_pred, bins, y_pred, bins, bins_out, bin_axis=1)
    assert np.isclose(y_s[0, 0], 0.25)
    assert np.isclose(y_s[0, 1], 0.5)
    assert np.isclose(y_s[0, 2], 0.25)
Ejemplo n.º 4
0
def test_concatenate(backend):
    """
    Ensure that concatenating a (10, 1) and a (10, 2) array along axis 1
    yields an array with the expected number of elements.
    """
    parts = [backend.ones((10, n_cols)) for n_cols in (1, 2)]
    combined = concatenate(backend, parts, 1)
    assert numel(combined) == 30
Ejemplo n.º 5
0
def posterior_quantiles(y_pred, quantiles, new_quantiles, quantile_axis=1):
    r"""
    Interpolate predicted quantiles to a new set of quantile fractions.

    Each fraction in ``new_quantiles`` that lies within
    ``[quantiles[0], quantiles[-1])`` is computed by linear interpolation
    between the two neighboring predicted quantiles. Fractions below
    ``quantiles[0]`` are clamped to the first predicted quantile; fractions
    at or above ``quantiles[-1]`` are clamped to the last one.

    Args:
        y_pred: A rank-k tensor of predicted quantiles with the quantiles
             located along the axis given by ``quantile_axis``.
        quantiles: The quantile fractions corresponding to the quantiles
             located along the quantile axis. Assumed to be sorted in
             ascending order.
        new_quantiles: Iterable of quantile fractions for which to compute
             the quantiles.
        quantile_axis: The axis along which the quantiles are located.
             Forced to 0 for rank-1 input.

    Returns:
        Rank-k tensor with ``len(new_quantiles)`` elements along
        ``quantile_axis`` containing the interpolated quantiles.
    """
    if len(y_pred.shape) == 1:
        quantile_axis = 0
    xp = get_array_module(y_pred)

    n = len(y_pred.shape)
    indices = arange(xp, 0, len(quantiles), 1.0)

    def _select(index):
        # Index tuple that picks ``index`` along the quantile axis.
        selection = [slice(0, None)] * n
        selection[quantile_axis] = index
        return tuple(selection)

    y_qs = []
    for q in new_quantiles:
        mask = (quantiles[1:] > q) * (quantiles[:-1] <= q)
        index = indices[:-1][mask]
        if len(index) == 0:
            # q is outside [quantiles[0], quantiles[-1]): clamp to the
            # nearest predicted quantile. The result is still appended so
            # that every requested fraction yields an output entry.
            edge = 0 if q < quantiles[0] else -1
            y_q = y_pred[_select(edge)]
        else:
            index = int(index[0])
            d = quantiles[index + 1] - quantiles[index]
            w_l = (quantiles[index + 1] - q) / d
            w_r = (q - quantiles[index]) / d
            y_q = (w_l * y_pred[_select(index)] +
                   w_r * y_pred[_select(index + 1)])
        y_qs.append(expand_dims(xp, y_q, quantile_axis))

    return concatenate(xp, y_qs, quantile_axis)
Ejemplo n.º 6
0
    def __call__(self, x, dist_axis=1):
        """
        Evaluate the a priori at the given points.

        The a priori is defined by the nodes ``self.x`` and values
        ``self.y`` and is evaluated by piece-wise linear interpolation. A
        point contributes only if it lies in a half-open interval
        ``(self.x[i], self.x[i + 1]]``; points outside all such intervals
        evaluate to 0. Assumes ``self.x`` is sorted ascending (TODO confirm).

        Args:
            x: Tensor containing the values at which to evaluate the a priori.
            dist_axis: The axis along which the tensor x is sorted. Forced
                to 0 for rank-1 input.

        Returns:
            Tensor with the same size as 'x' containing the values of the a
            priori at 'x' obtained by linear interpolation.
        """
        if len(x.shape) == 1:
            dist_axis = 0
        xp = get_array_module(x)
        rank = len(x.shape)

        def _along_axis(index):
            sel = [slice(0, None)] * rank
            sel[dist_axis] = index
            return tuple(sel)

        left = _along_axis(slice(0, -1))
        right = _along_axis(slice(1, None))

        # Lay the interpolation nodes out along dist_axis so they broadcast
        # against slices of x.
        grid_shape = [1] * rank
        grid_shape[dist_axis] = -1
        grid_x = self.x.reshape(grid_shape)
        grid_y = self.y.reshape(grid_shape)

        x_lo, x_hi = grid_x[left], grid_x[right]
        y_lo, y_hi = grid_y[left], grid_y[right]

        values = []
        for i in range(x.shape[dist_axis]):
            x_i = x[_along_axis(slice(i, i + 1))]

            inside = as_type(xp, (x_lo < x_i) * (x_hi >= x_i), x_i)
            r = y_lo * (x_hi - x_i) * inside
            r += y_hi * (x_i - x_lo) * inside
            # Divide by the interval width only where x_i is inside it;
            # elsewhere divide by 1 to avoid division by zero.
            r /= inside * (x_hi - x_lo) + (1.0 - inside)
            values.append(expand_dims(xp, r.sum(dist_axis), dist_axis))

        return concatenate(xp, values, dist_axis)
Ejemplo n.º 7
0
def posterior_quantiles(y_pdf, bins, quantiles, bin_axis=1):
    """
    Calculate posterior quantiles from predicted PDFs.

    The posterior CDF at the bin boundaries is obtained from
    ``posterior_cdf``. Each quantile is found by linear interpolation of the
    bin boundaries within the bin whose CDF values bracket the requested
    fraction (``cdf_l <= q < cdf_r``). If no CDF value exceeds ``q``, the
    result is ``bins[-1]``; if no CDF value is ``<= q``, it is ``bins[0]``.

    Args:
        y_pdf: Tensor containing the predicted PDFs.
        bins: The bin-boundaries corresponding to the predictions.
        quantiles: List containing the quantile fractions of the quantiles
             to compute.
        bin_axis: The index of the tensor axis which contains the predictions
            for each bin. Forced to 0 for rank-1 input.

    Return:
        Tensor with the same rank as ``y_pdf`` but with the values along
        ``bin_axis`` replaced with the ``len(quantiles)`` quantiles of the
        predicted distributions.
    """
    if len(y_pdf.shape) == 1:
        bin_axis = 0
    n_y = y_pdf.shape[bin_axis]
    n_b = len(bins)
    n_dims = len(y_pdf.shape)

    _check_dimensions(n_y, n_b)
    xp = get_array_module(y_pdf)

    y_cdf = posterior_cdf(y_pdf, bins, bin_axis=bin_axis)

    selection = [slice(0, None)] * n_dims
    selection[bin_axis] = slice(0, -1)
    selection_l = tuple(selection)
    selection[bin_axis] = slice(1, None)
    selection_r = tuple(selection)
    cdf_l = y_cdf[selection_l]
    cdf_r = y_cdf[selection_r]
    d_cdf = cdf_r - cdf_l

    # Left and right boundaries of each bin, laid out along bin_axis.
    shape = [1] * n_dims
    shape[bin_axis] = -1
    bins_nd = bins.reshape(shape)
    bins_l = bins_nd[selection_l]
    bins_r = bins_nd[selection_r]

    y_qs = []
    for q in quantiles:
        mask_l = as_type(xp, cdf_l <= q, cdf_l)
        mask_r = as_type(xp, cdf_r > q, cdf_l)
        # 1 only for the bin(s) whose CDF interval brackets q.
        mask = mask_l * mask_r

        d_q = q - expand_dims(xp, (cdf_l * mask).sum(bin_axis), bin_axis)

        result = (d_q * bins_r + (d_cdf - d_q) * bins_l) * mask
        # Where mask == 1, cdf_l <= q < cdf_r, so d_cdf > 0 however small
        # it is and dividing by it is exact. Where mask == 0 the numerator
        # is zero, so any non-zero divisor will do. (A threshold-based
        # guard would wrongly add 1 to tiny but valid increments.)
        result = result / (d_cdf + (1.0 - mask))
        result = result.sum(bin_axis)
        result = result + bins[-1] * (1.0 - as_type(xp, mask_r.sum(bin_axis) > 0, mask))
        result = result + bins[0] * (1.0 - as_type(xp, mask_l.sum(bin_axis) > 0, mask))
        result = expand_dims(xp, result, bin_axis)

        y_qs.append(result)

    y_q = concatenate(xp, y_qs, bin_axis)
    return y_q
Ejemplo n.º 8
0
def test_posterior_maximum(xp):
    """
    Ensure that the maximum of the posterior is computed correctly.

    The predicted PDF values are 1 everywhere except for a value of 2 at
    index 4; the expected posterior maximum is 4.5 for 1D, 2D and 3D
    inputs with the bins along different axes.
    """

    def make_bins():
        return arange(xp, 0.0, 10.001, 1.0)

    def make_pdf():
        parts = [xp.ones(4), to_array(xp, [2.0]), xp.ones(4)]
        return concatenate(xp, parts, 0)

    # 1D predictions
    bins = make_bins()
    pdf_1d = make_pdf()
    assert np.isclose(posterior_maximum(pdf_1d, bins), 4.5)

    # 2D predictions (bins along axis 1)
    bins = make_bins()
    pdf_2d = eo.repeat(make_pdf(), 'q -> h q', h=10)
    result = posterior_maximum(pdf_2d, bins)
    assert np.isclose(result[0], 4.5)

    # 3D predictions with bins along axis 1
    bins = make_bins()
    pdf_3d = eo.repeat(make_pdf(), 'q -> h q w', h=10, w=10)
    result = posterior_maximum(pdf_3d, bins)
    assert np.isclose(result[0, 0], 4.5)

    # 3D predictions with bins along the last axis
    bins = make_bins()
    pdf_3d = eo.repeat(make_pdf(), 'q -> h w q', h=10, w=10)
    result = posterior_maximum(pdf_3d, bins, -1)
    assert np.isclose(result[0, 0], 4.5)
Ejemplo n.º 9
0
def pdf(y_pred, quantiles, quantile_axis=1):
    """
    Calculate probability density function (PDF) of the posterior distribution
    defined by predicted quantiles.

    The PDF is approximated by computing the derivative of the cumulative
    distribution function (CDF), which is obtained by fitting a piece-wise
    linear function to the predicted quantiles and corresponding quantile
    fractions (see ``cdf``).

    Args:
        y_pred: Tensor containing the predicted quantiles along the quantile
            axis.
        quantiles: The quantile fractions corresponding to the predicted
            quantiles in y_pred.
        quantile_axis: The axis of y_pred along which the predicted
            quantiles are located. Forced to 0 for rank-1 input.

    Returns:
        Tuple ``(x_pdf, y_pdf)`` consisting of two arrays. ``x_pdf``
        corresponds to the x-values of the PDF. ``y_pdf`` corresponds
        to the y-values of the PDF.
    """
    if len(y_pred.shape) == 1:
        quantile_axis = 0

    xp = get_array_module(y_pred)

    x_cdf, y_cdf = cdf(y_pred, quantiles, quantile_axis=quantile_axis)
    n_dims = len(x_cdf.shape)

    def _select(index):
        # Index tuple that picks ``index`` along the quantile axis.
        selection = [slice(0, None)] * n_dims
        selection[quantile_axis] = index
        return tuple(selection)

    selection_l = _select(slice(0, -1))
    selection_r = _select(slice(1, None))

    #
    # Assemble x-tensor: first CDF node, midpoints of all CDF segments,
    # last CDF node.
    #

    x_pdf = 0.5 * (x_cdf[selection_l] + x_cdf[selection_r])
    x_pdf_l = expand_dims(xp, x_cdf[_select(0)], quantile_axis)
    x_pdf_r = expand_dims(xp, x_cdf[_select(-1)], quantile_axis)
    x_pdf = concatenate(xp, [x_pdf_l, x_pdf, x_pdf_r], quantile_axis)

    #
    # Assemble y-tensor: slope of the CDF on each segment.
    #

    # y_cdf is rank 1. Its increments must be laid out along quantile_axis;
    # left as-is, broadcasting would align them with the LAST axis, which
    # is wrong (or fails) when quantile_axis is not the last axis.
    dy_shape = [1] * n_dims
    dy_shape[quantile_axis] = -1
    dy_cdf = (y_cdf[1:] - y_cdf[:-1]).reshape(dy_shape)

    y_pdf = 1.0 / (x_cdf[selection_r] - x_cdf[selection_l])
    y_pdf = y_pdf * dy_cdf
    # Presumably pads one zero on each end so that y_pdf lines up with the
    # two extra outer nodes in x_pdf -- TODO confirm pad_zeros semantics.
    y_pdf = pad_zeros(xp, y_pdf, 1, quantile_axis)

    return x_pdf, y_pdf
Ejemplo n.º 10
0
def cdf(y_pred, quantiles, quantile_axis=1):
    """
    Calculates the cumulative distribution function (CDF) from predicted
    quantiles.

    The CDF is represented by its nodes: the predicted quantiles, with CDF
    values given by ``quantiles``, plus two outer nodes with CDF values 0
    and 1. The outer nodes are placed by extrapolating the slope between the
    two outermost predicted quantiles on each side over twice the remaining
    probability mass.

    Args:
        y_pred: Array containing a range of predicted quantiles. The array
            is expected to contain the quantiles along the axis given by
            ``quantile_axis``.
        quantiles: Array containing quantile fractions corresponding to the
            predicted quantiles. Presumably needs to be an array of the same
            backend as ``y_pred``, since it is concatenated with backend
            arrays -- TODO confirm.
        quantile_axis: The index of the axis of the ``y_pred`` array along
            which the quantiles are found. Forced to 0 for rank-1 input.

    Returns:
        Tuple ``(x_cdf, y_cdf)`` of x and corresponding y-values of the CDF
        corresponding to quantiles given by ``y_pred``. ``x_cdf`` has two
        more elements than ``y_pred`` along ``quantile_axis``; ``y_cdf`` is
        rank 1.

    Raises:

        InvalidArrayTypeException: When the data is provided neither as
             numpy array nor as torch tensor.

        InvalidDimensionException: When the provided predicted quantiles do
             not match the provided number of quantiles.
    """
    if len(y_pred.shape) == 1:
        quantile_axis = 0
    if y_pred.shape[quantile_axis] != len(quantiles):
        raise InvalidDimensionException(
            "Dimensions of the provided array 'y_pred' do not match the "
            "provided number of quantiles.")

    xp = get_array_module(y_pred)
    n_dims = len(y_pred.shape)

    def _select(index):
        # Index tuple that picks ``index`` along the quantile axis.
        selection = [slice(0, None)] * n_dims
        selection[quantile_axis] = index
        return tuple(selection)

    y_cdf = concatenate(xp, [xp.zeros(1), quantiles, xp.ones(1)], 0)

    # Left outer node: x_0 - 2 * q_0 * dx/dq, with dx/dq estimated from the
    # first two predicted quantiles.
    slope = ((y_pred[_select(1)] - y_pred[_select(0)]) /
             (quantiles[1] - quantiles[0]))
    x_cdf_l = y_pred[_select(0)] - 2.0 * quantiles[0] * slope
    x_cdf_l = expand_dims(xp, x_cdf_l, quantile_axis)

    # Right outer node: x_n + 2 * (1 - q_n) * dx/dq, with dx/dq estimated
    # from the last two predicted quantiles.
    slope = ((y_pred[_select(-1)] - y_pred[_select(-2)]) /
             (quantiles[-1] - quantiles[-2]))
    x_cdf_r = y_pred[_select(-1)] + 2.0 * (1.0 - quantiles[-1]) * slope
    x_cdf_r = expand_dims(xp, x_cdf_r, quantile_axis)

    x_cdf = concatenate(xp, [x_cdf_l, y_pred, x_cdf_r], quantile_axis)

    return x_cdf, y_cdf
Ejemplo n.º 11
0
def pdf_binned(y_pred, quantiles, bins, quantile_axis=1):
    """
    Calculate binned representation of the posterior probability density
    function (PDF).

    The piece-wise linear CDF computed by ``cdf`` is evaluated at every bin
    boundary by linear interpolation. The binned PDF is the difference of
    the CDF between consecutive boundaries divided by the bin width.

    Args:
        y_pred: Rank-k Tensor containing the predicted quantiles along the
            quantile axis.
        quantiles: The quantile fractions corresponding to the predicted
            quantiles in y_pred.
        bins: Rank-1 tensor containing the ``n_bins`` boundaries for the bins
            to use to bin the PDF.
        quantile_axis: The axis of y_pred along which the predicted
            quantiles are located. Forced to 0 for rank-1 input.

    Returns:
        Rank-k tensor with ``n_bins - 1`` elements along ``quantile_axis``
        containing the mean probability density in each bin, i.e. the
        probability of the result falling between the corresponding bin
        edges divided by the bin width.
    """
    if len(y_pred.shape) == 1:
        quantile_axis = 0

    xp = get_array_module(y_pred)
    n = len(y_pred.shape)
    x_cdf, y_cdf = cdf(y_pred, quantiles, quantile_axis=quantile_axis)

    # Lay the rank-1 CDF values out along quantile_axis.
    y_cdf_shape = [1] * n
    y_cdf_shape[quantile_axis] = -1
    y_cdf = reshape(xp, y_cdf, y_cdf_shape)

    def _select(index):
        # Index tuple that picks ``index`` along the quantile axis.
        selection = [slice(0, None)] * n
        selection[quantile_axis] = index
        return tuple(selection)

    selection_l = _select(slice(0, -1))
    selection_r = _select(slice(1, None))

    x_l = x_cdf[selection_l]
    x_r = x_cdf[selection_r]
    y_l = y_cdf[selection_l]
    y_r = y_cdf[selection_r]
    y_first = y_cdf[_select(0)]
    y_last = y_cdf[_select(-1)]

    def _cdf_at(b):
        # Linearly interpolate the CDF at boundary ``b`` along quantile_axis.
        mask_r = as_type(xp, x_r >= b, y_cdf)
        mask_l = as_type(xp, x_l < b, y_cdf)
        # 1 only on the CDF segment with x_l < b <= x_r.
        mask = mask_l * mask_r

        # b beyond all nodes -> last CDF value; b at or before the first
        # node -> first CDF value.
        mask_xr = as_type(xp, xp.sum(mask_r, quantile_axis) == 0.0, mask_r)
        mask_xl = as_type(xp, xp.sum(mask_l, quantile_axis) == 0.0, mask_l)

        seg_l = xp.sum(x_l * mask, quantile_axis)
        seg_r = xp.sum(x_r * mask, quantile_axis)
        # Add 1 where no segment is selected to avoid division by zero.
        d = (seg_r - seg_l) + (1.0 - xp.sum(mask, quantile_axis))
        w_l = (seg_r - b) / d
        w_r = (b - seg_l) / d

        return (
            xp.sum(y_l * mask, quantile_axis) * w_l +
            xp.sum(y_r * mask, quantile_axis) * w_r +
            mask_xl * y_first + mask_xr * y_last)

    y_pdf_binned = []
    b_l = bins[0]
    y_cdf_l = _cdf_at(b_l)
    for i in range(len(bins) - 1):
        b_r = bins[i + 1]
        y_cdf_r = _cdf_at(b_r)
        dy_cdf = expand_dims(xp, y_cdf_r - y_cdf_l, quantile_axis)
        y_pdf_binned.append(dy_cdf / (b_r - b_l))
        y_cdf_l = y_cdf_r
        b_l = b_r

    return concatenate(xp, y_pdf_binned, quantile_axis)
Ejemplo n.º 12
0
def correct_a_priori(y_pred, quantiles, r, quantile_axis=1):
    """
    Correct predicted quantiles for a priori.

    The PDF defined by the predicted quantiles is multiplied by ``r``
    evaluated at the PDF's x-values. The result is integrated with
    ``cumtrapz`` over the CDF nodes and normalized so that its last value is
    1. The corrected quantiles are then obtained by linearly interpolating
    this new CDF at the original quantile fractions.

    Args:
        y_pred: Rank-k tensor containing the predicted quantiles along
            the axis given by 'quantile_axis'.
        quantiles: Rank-1 tensor containing the quantile fractions that
            correspond to the predicted quantiles.
        r: A priori density ratio to use to correct the observations.
            Called as ``r(x, dist_axis=quantile_axis)``.
        quantile_axis: The axis along which the quantiles are oriented
            in 'y_pred'. Forced to 0 for rank-1 input.

    Returns:
        Rank-k tensor with ``len(quantiles)`` elements along
        ``quantile_axis`` containing the corrected quantiles.
    """
    if len(y_pred.shape) == 1:
        quantile_axis = 0
    xp = get_array_module(y_pred)
    n_dims = len(y_pred.shape)

    x_pdf, y_pdf = pdf(y_pred, quantiles, quantile_axis=quantile_axis)
    # The CDF nodes that the PDF above was derived from.
    x_cdf, _ = cdf(y_pred, quantiles, quantile_axis=quantile_axis)

    def _select(index):
        # Index tuple that picks ``index`` along the quantile axis.
        selection = [slice(0, None)] * n_dims
        selection[quantile_axis] = index
        return tuple(selection)

    selection_l = _select(slice(0, -1))
    selection_r = _select(slice(1, None))

    y_pdf_new = r(x_pdf, dist_axis=quantile_axis) * y_pdf

    # Drop the outermost PDF entries before integrating over the CDF nodes.
    y_cdf_new = cumtrapz(xp, y_pdf_new[_select(slice(1, -1))], x_cdf, quantile_axis)
    # Normalize so that the corrected CDF ends at 1.
    y_cdf_new = y_cdf_new / y_cdf_new[_select(slice(-1, None))]

    x_cdf_l = x_cdf[selection_l]
    x_cdf_r = x_cdf[selection_r]
    y_cdf_new_l = y_cdf_new[selection_l]
    y_cdf_new_r = y_cdf_new[selection_r]

    y_pred_new = []
    for q in quantiles:
        # Invert the corrected CDF: pick the segment with
        # y_cdf_new_l < q <= y_cdf_new_r and interpolate linearly.
        mask = as_type(xp, (y_cdf_new_l < q) * (y_cdf_new_r >= q), x_cdf_l)
        y_new = x_cdf_l * (y_cdf_new_r - q) * mask
        y_new += x_cdf_r * (q - y_cdf_new_l) * mask
        # Divide by 1 outside the selected segment to avoid division by zero.
        y_new /= mask * (y_cdf_new_r - y_cdf_new_l) + (1.0 - mask)
        y_new = expand_dims(xp, y_new.sum(quantile_axis), quantile_axis)

        y_pred_new.append(y_new)

    y_pred_new = concatenate(xp, y_pred_new, quantile_axis)
    return y_pred_new