Example #1
0
    def test_searchsorted(self):
        """Check that points inside each bin are assigned that bin's index.

        Ten evenly spaced locations on [0, 1] delimit nine bins. The lower
        edge of each bin, a point 0.1 above it, and a point 0.05 above it
        must all be reported as belonging to bins 0..8 respectively.
        """
        bin_locations = torch.linspace(0, 1, 10)  # 10 knots delimit 9 bins

        lower_edges = bin_locations[:-1]
        candidates = (lower_edges, lower_edges + 0.1, lower_edges + 0.05)

        for inputs in candidates:
            with self.subTest(inputs=inputs):
                # A leading batch axis of size one is added to the knots.
                found = torchutils.searchsorted(bin_locations[None, :], inputs)
                self.assertEqual(found, torch.arange(0, 9))
Example #2
0
 def test_searchsorted_arbitrary_shape(self):
     """Check that the returned indices have the same shape as the inputs.

     The same ten knots are tiled over a (2, 3, 4) batch and queried with
     one uniform random value per batch element.
     """
     batch_dims = (2, 3, 4)
     knots = torch.linspace(0, 1, 10).repeat(*batch_dims, 1)
     inputs = torch.rand(*batch_dims)
     self.assertEqual(
         torchutils.searchsorted(knots, inputs).shape, inputs.shape)
Example #3
0
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tail_bound=None,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Evaluate a monotonic rational-quadratic spline, or its inverse.

    Bin widths and heights are obtained from a softmax over the last
    dimension of the unnormalized parameters, floored at ``min_bin_width`` /
    ``min_bin_height``, and accumulated into knot positions spanning
    ``[left, right]`` horizontally and ``[bottom, top]`` vertically. Knot
    derivatives are ``min_derivative + softplus(unnormalized_derivatives)``.

    Args:
        inputs: Tensor of values to transform.
        unnormalized_widths: Tensor whose last dimension holds one entry per
            bin.
        unnormalized_heights: Tensor whose last dimension holds one entry per
            bin.
        unnormalized_derivatives: Tensor whose last dimension is indexed at
            ``bin_idx`` and ``bin_idx + 1`` — presumably one entry per knot
            (``num_bins + 1``); confirm against the caller.
        inverse: If True, evaluate the inverse transform.
        tail_bound: If truthy, a ``PointwiseAffineTransform(1 / 2,
            1 / (2 * tail_bound))`` is applied to the inputs before the
            forward spline, and its inverse to the outputs after the inverse
            spline. Presumably this maps ``[-tail_bound, tail_bound]`` onto
            ``[0, 1]`` — TODO confirm against ``PointwiseAffineTransform``.
        left, right: Horizontal extent of the spline.
        bottom, top: Vertical extent of the spline.
        min_bin_width, min_bin_height: Lower bounds for the normalized bin
            sizes (before scaling to the spline extent).
        min_derivative: Lower bound for the knot derivatives.

    Returns:
        Tuple ``(outputs, logabsdet)``: the transformed values and the log
        absolute derivative of the applied transform, plus element ``0`` of
        the log-det contribution returned by the affine shift (if any).

    Raises:
        InputOutsideDomain: If any input lies outside ``[left, right]``.
        ValueError: If the minimal bin width or height cannot fit
            ``num_bins`` times into the unit interval.
        AssertionError: In the inverse direction, if the discriminant of the
            quadratic is negative anywhere.
    """
    # Log-det contribution of the optional affine shift; zero when no shift
    # is applied. Only element 0 is read at the return statements.
    log_det_contr = [0]
    if tail_bound:
        shift_op = PointwiseAffineTransform(1 / 2, 1 / (2 * tail_bound))
        if not inverse:
            inputs, log_det_contr = shift_op.forward(inputs)

    # NOTE(review): the same [left, right] check is used for the inverse
    # direction, whose inputs are searched for in the height knots spanning
    # [bottom, top]. Harmless when both intervals coincide — confirm intent.
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise InputOutsideDomain()

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Normalized widths that sum to one, each at least min_bin_width.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    # Horizontal knot positions: prepend 0, scale to [left, right], then pin
    # the end knots exactly to remove accumulated rounding error.
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    # Recompute widths so they are consistent with the pinned knots.
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Vertical knot positions, built the same way as the horizontal ones.
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate the bin of every input: by height knots for the inverse, by
    # width knots for the forward direction.
    if inverse:
        bin_idx = torchutils.searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = torchutils.searchsorted(cumwidths, inputs)[..., None]

    # Gather the parameters of the bin each input falls into.
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    # Average slope of each bin.
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    # Derivatives at the left and right knot of the selected bin.
    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[...,
                                                                          0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Coefficients of the quadratic a * root**2 + b * root + c = 0 whose
        # solution is the relative position of the result inside its bin.
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 *
            input_delta) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta)
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        # Root written as 2c / (-b - sqrt(D)) instead of the textbook form
        # kept in the comment below; this form does not divide by `a`.
        root = (2 * c) / (-b - torch.sqrt(discriminant))
        # root = (- b + torch.sqrt(discriminant)) / (2 * a)
        outputs = root * input_bin_widths + input_cumwidths

        # Derivative of the forward spline evaluated at the recovered point;
        # its log is negated at the return to give the inverse's log-det.
        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta)
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2) +
            2 * input_delta * theta_one_minus_theta + input_derivatives *
            (1 - root).pow(2))
        logabsdet = torch.log(
            derivative_numerator) - 2 * torch.log(denominator)

        if tail_bound:
            outputs, log_det_contr = shift_op.inverse(outputs)

        # NOTE(review): only element 0 of the shift's log-det is added —
        # assumes it is identical for every element; confirm.
        return outputs, -logabsdet + log_det_contr[0]
    else:

        # Relative position of each input inside its bin.
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) +
                                     input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta)
        outputs = input_cumheights + numerator / denominator

        # Log of the spline derivative: numerator / denominator**2.
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2) +
            2 * input_delta * theta_one_minus_theta + input_derivatives *
            (1 - theta).pow(2))
        logabsdet = torch.log(
            derivative_numerator) - 2 * torch.log(denominator)

        # NOTE(review): only element 0 of the shift's log-det is added —
        # assumes it is identical for every element; confirm.
        return outputs, logabsdet + log_det_contr[0]
Example #4
0
def rational_quadratic_spline2(inputs,
                               unnormalized_widths,
                               unnormalized_heights,
                               unnormalized_derivatives,
                               inverse=False,
                               left=0.0,
                               right=1.0,
                               bottom=0.0,
                               top=1.0,
                               min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                               min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                               min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Evaluate a rational-quadratic spline with tensor-valued bounds.

    Variant of the rational-quadratic spline in which ``left`` and ``right``
    are tensors: they are expanded with ``unsqueeze(0).unsqueeze(2)`` and
    broadcast against the knot tensors. Presumably they are 1-D per-feature
    bounds and ``inputs`` is ``(batch, features)`` — TODO confirm with the
    caller.

    Args:
        inputs: Tensor of values to transform.
        unnormalized_widths: Tensor whose last dimension holds one entry per
            bin.
        unnormalized_heights: Tensor whose last dimension holds one entry per
            bin.
        unnormalized_derivatives: Tensor whose last dimension is indexed at
            ``bin_idx`` and ``bin_idx + 1`` — presumably one entry per knot.
        inverse: If True, evaluate the inverse transform.
        left, right: Tensor bounds used for BOTH the horizontal and the
            vertical knot positions. The float defaults cannot be used as
            they have no ``unsqueeze`` method.
        bottom, top: Accepted but never read in this function.
        min_bin_width, min_bin_height: Lower bounds for the normalized bin
            sizes.
        min_derivative: Lower bound for the knot derivatives.

    Returns:
        Tuple ``(outputs, logabsdet)``: the transformed values and the log
        absolute derivative of the applied transform.

    Raises:
        ValueError: If the minimal bin width or height cannot fit
            ``num_bins`` times into the unit interval.
        AssertionError: In the inverse direction, if the discriminant of the
            quadratic is negative anywhere.
    """

    # Per-feature domain check, currently disabled: inputs outside
    # [left, right] are NOT rejected by this function.
    # for (itr,(l,r)) in enumerate(zip(left,right)):

    #     if torch.min(inputs[:,itr]) < l or torch.max(inputs[:,itr]) > r:

    #         if torch.min(inputs[:,itr]) < l:
    #              print("lower bound of feature "+str(itr)+ " violated")
    #         else:
    #              print("upper bound of feature "+str(itr)+ " violated")
    #         raise InputOutsideDomain()

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Normalized widths that sum to one, each at least min_bin_width.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    # Horizontal knot positions: prepend 0, scale to [left, right], then pin
    # the end knots exactly.
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left).unsqueeze(0).unsqueeze(
        2) * cumwidths + left.unsqueeze(0).unsqueeze(2)
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    # NOTE(review): the vertical knots are scaled with left/right rather than
    # bottom/top, so the output range equals the input range — confirm this
    # is intended before relying on `bottom` / `top`.
    cumheights = (right - left).unsqueeze(0).unsqueeze(
        2) * cumheights + left.unsqueeze(0).unsqueeze(2)
    cumheights[..., 0] = left
    cumheights[..., -1] = right
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate the bin of every input: by height knots for the inverse, by
    # width knots for the forward direction.
    if inverse:
        bin_idx = torchutils.searchsorted(cumheights, inputs)[..., None]

    else:
        bin_idx = torchutils.searchsorted(cumwidths, inputs)[..., None]

    # Gather the parameters of the bin each input falls into.
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    # Average slope of each bin.
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    # Derivatives at the left and right knot of the selected bin.
    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[...,
                                                                          0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Coefficients of the quadratic a * root**2 + b * root + c = 0 whose
        # solution is the relative position of the result inside its bin.
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 *
            input_delta) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta)
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        # Root written as 2c / (-b - sqrt(D)) instead of the textbook form
        # kept in the comment below; this form does not divide by `a`.
        root = (2 * c) / (-b - torch.sqrt(discriminant))
        # root = (- b + torch.sqrt(discriminant)) / (2 * a)
        outputs = root * input_bin_widths + input_cumwidths

        # Derivative of the forward spline at the recovered point; its log
        # is negated at the return to give the inverse's log-det.
        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta)
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2) +
            2 * input_delta * theta_one_minus_theta + input_derivatives *
            (1 - root).pow(2))
        logabsdet = torch.log(
            derivative_numerator) - 2 * torch.log(denominator)

        return outputs, -logabsdet
    else:
        # Relative position of each input inside its bin.
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) +
                                     input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta)

        outputs = input_cumheights + numerator / denominator

        # Log of the spline derivative: numerator / denominator**2.
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2) +
            2 * input_delta * theta_one_minus_theta + input_derivatives *
            (1 - theta).pow(2))
        logabsdet = torch.log(
            derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
Example #5
0
def quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
):
    """Evaluate a monotonic piecewise-quadratic spline, or its inverse.

    On the unit interval the spline is the integral of a piecewise-linear
    function defined by positive knot heights and bin widths; the result is
    rescaled so that the forward direction maps ``[left, right]`` onto
    ``[bottom, top]``.

    Args:
        inputs: Tensor of values to transform. Forward inputs must lie in
            ``[left, right]``, inverse inputs in ``[bottom, top]``.
        unnormalized_widths: Tensor whose last dimension holds one entry per
            bin (``num_bins``).
        unnormalized_heights: Tensor whose last dimension holds either
            ``num_bins + 1`` entries (one per knot) or ``num_bins - 1``
            entries (interior knots only; the two boundary heights are then
            filled in with a shared constant).
        inverse: If True, evaluate the inverse transform.
        left, right: Input interval of the forward transform.
        bottom, top: Output interval of the forward transform.
        min_bin_width: Lower bound for each normalized bin width.
        min_bin_height: Lower bound applied to the normalized knot heights.

    Returns:
        Tuple ``(outputs, logabsdet)``: the transformed values and the log
        absolute derivative of the applied transform, including the factor
        ``(top - bottom) / (right - left)`` introduced by the rescaling.

    Raises:
        InputOutsideDomain: If any input lies outside the interval that
            applies to the requested direction.
        ValueError: If the minimal bin width or height cannot fit
            ``num_bins`` times into the unit interval.
    """
    # The inverse direction consumes values from the output interval, so it
    # has to be validated and normalized against [bottom, top].
    lower, upper = (bottom, top) if inverse else (left, right)
    if torch.min(inputs) < lower or torch.max(inputs) > upper:
        raise InputOutsideDomain()

    inputs = (inputs - lower) / (upper - lower)

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Normalized widths that sum to one, each at least min_bin_width.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths

    # Strictly positive knot heights.
    unnorm_heights_exp = F.softplus(unnormalized_heights) + 1e-3

    if unnorm_heights_exp.shape[-1] == num_bins - 1:
        # Set boundary heights s.t. after normalization they are exactly 1.
        first_widths = 0.5 * widths[..., 0]
        last_widths = 0.5 * widths[..., -1]
        numerator = (
            0.5 * first_widths * unnorm_heights_exp[..., 0] +
            0.5 * last_widths * unnorm_heights_exp[..., -1] + torch.sum(
                ((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) /
                 2) * widths[..., 1:-1],
                dim=-1,
            ))
        constant = numerator / (1 - 0.5 * first_widths - 0.5 * last_widths)
        constant = constant[..., None]
        unnorm_heights_exp = torch.cat(
            [constant, unnorm_heights_exp, constant], dim=-1)

    # Trapezoid-rule area under the piecewise-linear heights; dividing by it
    # makes the heights integrate to one.
    unnormalized_area = torch.sum(
        ((unnorm_heights_exp[..., :-1] + unnorm_heights_exp[..., 1:]) / 2) *
        widths,
        dim=-1,
    )[..., None]
    heights = unnorm_heights_exp / unnormalized_area
    heights = min_bin_height + (1 - min_bin_height) * heights

    # Value of the spline at each knot; the last knot is pinned to 1 and a
    # leading 0 is prepended.
    bin_left_cdf = torch.cumsum(
        ((heights[..., :-1] + heights[..., 1:]) / 2) * widths, dim=-1)
    bin_left_cdf[..., -1] = 1.0
    bin_left_cdf = F.pad(bin_left_cdf, pad=(1, 0), mode="constant", value=0.0)

    # Horizontal knot positions, built the same way.
    bin_locations = torch.cumsum(widths, dim=-1)
    bin_locations[..., -1] = 1.0
    bin_locations = F.pad(bin_locations,
                          pad=(1, 0),
                          mode="constant",
                          value=0.0)

    # Locate the bin of every input: by knot values for the inverse, by knot
    # positions for the forward direction.
    if inverse:
        bin_idx = torchutils.searchsorted(bin_left_cdf, inputs)[..., None]
    else:
        bin_idx = torchutils.searchsorted(bin_locations, inputs)[..., None]

    input_bin_locations = bin_locations.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_left_cdf = bin_left_cdf.gather(-1, bin_idx)[..., 0]

    input_left_heights = heights.gather(-1, bin_idx)[..., 0]
    input_right_heights = heights.gather(-1, bin_idx + 1)[..., 0]

    # Inside a bin the spline is a * alpha**2 + b * alpha + c, where alpha in
    # [0, 1] is the relative position within the bin.
    a = 0.5 * (input_right_heights - input_left_heights) * input_bin_widths
    b = input_left_heights * input_bin_widths
    c = input_left_cdf

    if inverse:
        c_ = c - inputs
        discriminant = b.pow(2) - 4 * a * c_
        # Same root as (-b + sqrt(D)) / (2a), written so that it does not
        # divide by `a`: the textbook form is 0/0 whenever the two knot
        # heights of a bin are equal (a == 0). The denominator here is
        # strictly negative because b > 0.
        alpha = (2 * c_) / (-b - torch.sqrt(discriminant))
        outputs = alpha * input_bin_widths + input_bin_locations
        outputs = torch.clamp(outputs, 0, 1)
        logabsdet = -torch.log(
            (alpha *
             (input_right_heights - input_left_heights) + input_left_heights))
    else:
        alpha = (inputs - input_bin_locations) / input_bin_widths
        outputs = a * alpha.pow(2) + b * alpha + c
        outputs = torch.clamp(outputs, 0, 1)
        logabsdet = torch.log(
            (alpha * (input_right_heights - input_left_heights) +
             input_left_heights))

    # Chain rule for the affine maps into and out of the unit interval; this
    # is log(1) == 0 whenever both intervals have the same length.
    log_scale = torch.log(
        torch.as_tensor((top - bottom) / (right - left),
                        dtype=logabsdet.dtype,
                        device=logabsdet.device))

    if inverse:
        outputs = outputs * (right - left) + left
        logabsdet = logabsdet - log_scale
    else:
        outputs = outputs * (top - bottom) + bottom
        logabsdet = logabsdet + log_scale

    return outputs, logabsdet
Example #6
0
def linear_spline(inputs,
                  unnormalized_pdf,
                  inverse=False,
                  left=0.0,
                  right=1.0,
                  bottom=0.0,
                  top=1.0):
    """Evaluate a monotonic piecewise-linear spline, or its inverse.

    The unit interval is split into ``unnormalized_pdf.size(-1)`` bins of
    equal width. A softmax over the last dimension gives the mass of each
    bin, and the transform is the piecewise-linear cumulative sum of those
    masses, rescaled so that the forward direction maps ``[left, right]``
    onto ``[bottom, top]``.

    Reference:
    > Müller et al., Neural Importance Sampling, arXiv:1808.03856, 2018.

    Args:
        inputs: Tensor of values to transform. Forward inputs must lie in
            ``[left, right]``, inverse inputs in ``[bottom, top]``.
        unnormalized_pdf: Tensor with the dimensions of ``inputs`` followed
            by one trailing dimension holding one logit per bin.
        inverse: If True, evaluate the inverse transform.
        left, right: Input interval of the forward transform.
        bottom, top: Output interval of the forward transform.

    Returns:
        Tuple ``(outputs, logabsdet)``: the transformed values and the log
        absolute derivative of the applied transform, including the factor
        ``(top - bottom) / (right - left)`` introduced by the rescaling.

    Raises:
        InputOutsideDomain: If any input lies outside the interval that
            applies to the requested direction.
    """
    # The inverse direction consumes values from the output interval, so it
    # has to be validated and normalized against [bottom, top].
    lower, upper = (bottom, top) if inverse else (left, right)
    if torch.min(inputs) < lower or torch.max(inputs) > upper:
        raise InputOutsideDomain()

    inputs = (inputs - lower) / (upper - lower)

    num_bins = unnormalized_pdf.size(-1)

    pdf = F.softmax(unnormalized_pdf, dim=-1)

    # Spline value at each bin edge; the last edge is pinned to exactly 1 and
    # a leading 0 is prepended.
    cdf = torch.cumsum(pdf, dim=-1)
    cdf[..., -1] = 1.0
    cdf = F.pad(cdf, pad=(1, 0), mode="constant", value=0.0)

    if inverse:
        inv_bin_idx = torchutils.searchsorted(cdf, inputs)

        # Built on the same device and dtype as `cdf` so the arithmetic
        # below also works for non-default devices and dtypes.
        bin_boundaries = (torch.linspace(
            0, 1, num_bins + 1, dtype=cdf.dtype,
            device=cdf.device).view([1] * inputs.dim() + [-1]).expand(
                *inputs.shape, -1))

        # Each linear piece is y = slope * x + offset.
        slopes = (cdf[..., 1:] - cdf[..., :-1]) / (bin_boundaries[..., 1:] -
                                                   bin_boundaries[..., :-1])
        offsets = cdf[..., 1:] - slopes * bin_boundaries[..., 1:]

        inv_bin_idx = inv_bin_idx.unsqueeze(-1)
        input_slopes = slopes.gather(-1, inv_bin_idx)[..., 0]
        input_offsets = offsets.gather(-1, inv_bin_idx)[..., 0]

        outputs = (inputs - input_offsets) / input_slopes
        outputs = torch.clamp(outputs, 0, 1)

        logabsdet = -torch.log(input_slopes)
    else:
        # Fractional bin position: integer part selects the bin, remainder
        # is the relative position inside it.
        bin_pos = inputs * num_bins

        bin_idx = torch.floor(bin_pos).long()
        # An input exactly at the upper bound belongs to the last bin.
        bin_idx[bin_idx >= num_bins] = num_bins - 1

        alpha = bin_pos - bin_idx.float()

        input_pdfs = pdf.gather(-1, bin_idx[..., None])[..., 0]

        outputs = cdf.gather(-1, bin_idx[..., None])[..., 0]
        outputs += alpha * input_pdfs
        outputs = torch.clamp(outputs, 0, 1)

        # Slope of a piece is its mass divided by the bin width.
        bin_width = 1.0 / num_bins
        logabsdet = torch.log(input_pdfs) - np.log(bin_width)

    # Chain rule for the affine maps into and out of the unit interval; this
    # is log(1) == 0 whenever both intervals have the same length.
    log_scale = torch.log(
        torch.as_tensor((top - bottom) / (right - left),
                        dtype=logabsdet.dtype,
                        device=logabsdet.device))

    if inverse:
        outputs = outputs * (right - left) + left
        logabsdet = logabsdet - log_scale
    else:
        outputs = outputs * (top - bottom) + bottom
        logabsdet = logabsdet + log_scale

    return outputs, logabsdet
Example #7
0
def cubic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnorm_derivatives_left,
    unnorm_derivatives_right,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    eps=DEFAULT_EPS,
    quadratic_threshold=DEFAULT_QUADRATIC_THRESHOLD,
):
    """Evaluate a piecewise-cubic spline, or its inverse.

    On the unit square, each bin carries a cubic
    ``a * s**3 + b * s**2 + c * s + d`` in the offset ``s`` from the bin's
    left knot. The result is rescaled so that the forward direction maps
    ``[left, right]`` onto ``[bottom, top]``.

    References:
    > Blinn, J. F. (2007). How to solve a cubic equation, part 5: Back to numerics. IEEE Computer
    Graphics and Applications, 27(3):78–89.

    Args:
        inputs: Tensor of values to transform. Forward inputs must lie in
            ``[left, right]``, inverse inputs in ``[bottom, top]``.
        unnormalized_widths: Tensor whose last dimension holds one entry per
            bin.
        unnormalized_heights: Tensor whose last dimension holds one entry per
            bin.
        unnorm_derivatives_left: Tensor with a trailing dimension of size
            one; its sigmoid times three times the first bin's slope is the
            derivative at the first knot.
        unnorm_derivatives_right: Same, for the last knot and last bin.
        inverse: If True, evaluate the inverse transform.
        left, right: Input interval of the forward transform.
        bottom, top: Output interval of the forward transform.
        min_bin_width, min_bin_height: Lower bounds for the normalized bin
            sizes.
        eps: Tolerance used when testing whether a candidate root of the
            inverse lies inside its bin.
        quadratic_threshold: Bins whose cubic coefficient is smaller than
            this in absolute value are inverted as quadratics.

    Returns:
        Tuple ``(outputs, logabsdet)``: the transformed values and the log
        of the derivative of the applied transform, including the factor
        ``(top - bottom) / (right - left)`` introduced by the rescaling.

    Raises:
        InputOutsideDomain: If any input lies outside the interval that
            applies to the requested direction.
        ValueError: If the minimal bin width or height cannot fit
            ``num_bins`` times into the unit interval.
    """
    # The inverse direction consumes values from the output interval, so it
    # has to be validated and normalized against [bottom, top].
    lower, upper = (bottom, top) if inverse else (left, right)
    if torch.min(inputs) < lower or torch.max(inputs) > upper:
        raise InputOutsideDomain()

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    inputs = (inputs - lower) / (upper - lower)

    # Normalized widths that sum to one, each at least min_bin_width.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths

    # Horizontal knot positions; the last knot is pinned to exactly 1.
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths[..., -1] = 1
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)

    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights

    # Vertical knot positions, built the same way.
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights[..., -1] = 1
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)

    # Interior knot derivatives are limited by the smaller neighbouring bin
    # slope and by half the width-weighted combination of both slopes.
    slopes = heights / widths
    min_something_1 = torch.min(torch.abs(slopes[..., :-1]),
                                torch.abs(slopes[..., 1:]))
    min_something_2 = (0.5 * (widths[..., 1:] * slopes[..., :-1] +
                              widths[..., :-1] * slopes[..., 1:]) /
                       (widths[..., :-1] + widths[..., 1:]))
    min_something = torch.min(min_something_1, min_something_2)

    # Boundary derivatives lie in (0, 3 * slope of the adjacent bin).
    derivatives_left = (torch.sigmoid(unnorm_derivatives_left) * 3 *
                        slopes[..., 0][..., None])
    derivatives_right = (torch.sigmoid(unnorm_derivatives_right) * 3 *
                         slopes[..., -1][..., None])

    derivatives = min_something * (torch.sign(slopes[..., :-1]) +
                                   torch.sign(slopes[..., 1:]))
    derivatives = torch.cat([derivatives_left, derivatives, derivatives_right],
                            dim=-1)

    # Per-bin cubic coefficients in the offset from the bin's left knot.
    a = (derivatives[..., :-1] + derivatives[..., 1:] -
         2 * slopes) / widths.pow(2)
    b = (3 * slopes - 2 * derivatives[..., :-1] -
         derivatives[..., 1:]) / widths
    c = derivatives[..., :-1]
    d = cumheights[..., :-1]

    # Locate the bin of every input: by height knots for the inverse, by
    # width knots for the forward direction.
    if inverse:
        bin_idx = torchutils.searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = torchutils.searchsorted(cumwidths, inputs)[..., None]

    inputs_a = a.gather(-1, bin_idx)[..., 0]
    inputs_b = b.gather(-1, bin_idx)[..., 0]
    inputs_c = c.gather(-1, bin_idx)[..., 0]
    inputs_d = d.gather(-1, bin_idx)[..., 0]

    input_left_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_right_cumwidths = cumwidths.gather(-1, bin_idx + 1)[..., 0]

    if inverse:
        # Modified coefficients for solving the cubic.
        inputs_b_ = (inputs_b / inputs_a) / 3.0
        inputs_c_ = (inputs_c / inputs_a) / 3.0
        inputs_d_ = (inputs_d - inputs) / inputs_a

        delta_1 = -inputs_b_.pow(2) + inputs_c_
        delta_2 = -inputs_c_ * inputs_b_ + inputs_d_
        delta_3 = inputs_b_ * inputs_d_ - inputs_c_.pow(2)

        discriminant = 4.0 * delta_1 * delta_3 - delta_2.pow(2)

        depressed_1 = -2.0 * inputs_b_ * delta_1 + delta_2
        depressed_2 = delta_1

        three_roots_mask = (
            discriminant >= 0
        )  # Discriminant == 0 might be a problem in practice.
        one_root_mask = discriminant < 0

        outputs = torch.zeros_like(inputs)

        # Deal with one root cases.

        p = torchutils.cbrt((-depressed_1[one_root_mask] +
                             torch.sqrt(-discriminant[one_root_mask])) / 2.0)
        q = torchutils.cbrt((-depressed_1[one_root_mask] -
                             torch.sqrt(-discriminant[one_root_mask])) / 2.0)

        outputs[one_root_mask] = ((p + q) - inputs_b_[one_root_mask] +
                                  input_left_cumwidths[one_root_mask])

        # Deal with three root cases.

        theta = torch.atan2(torch.sqrt(discriminant[three_roots_mask]),
                            -depressed_1[three_roots_mask])
        theta /= 3.0

        cubic_root_1 = torch.cos(theta)
        cubic_root_2 = torch.sin(theta)

        root_1 = cubic_root_1
        root_2 = -0.5 * cubic_root_1 - 0.5 * math.sqrt(3) * cubic_root_2
        root_3 = -0.5 * cubic_root_1 + 0.5 * math.sqrt(3) * cubic_root_2

        root_scale = 2 * torch.sqrt(-depressed_2[three_roots_mask])
        root_shift = (-inputs_b_[three_roots_mask] +
                      input_left_cumwidths[three_roots_mask])

        root_1 = root_1 * root_scale + root_shift
        root_2 = root_2 * root_scale + root_shift
        root_3 = root_3 * root_scale + root_shift

        # For each candidate root, 1.0 if it lies inside the selected bin
        # (within eps), else 0.0.
        root1_mask = ((input_left_cumwidths[three_roots_mask] - eps) <
                      root_1).float()
        root1_mask *= (
            root_1 < (input_right_cumwidths[three_roots_mask] + eps)).float()

        root2_mask = ((input_left_cumwidths[three_roots_mask] - eps) <
                      root_2).float()
        root2_mask *= (
            root_2 < (input_right_cumwidths[three_roots_mask] + eps)).float()

        root3_mask = ((input_left_cumwidths[three_roots_mask] - eps) <
                      root_3).float()
        root3_mask *= (
            root_3 < (input_right_cumwidths[three_roots_mask] + eps)).float()

        # Pick, per element, a root whose in-bin flag is largest.
        roots = torch.stack([root_1, root_2, root_3], dim=-1)
        masks = torch.stack([root1_mask, root2_mask, root3_mask], dim=-1)
        mask_index = torch.argsort(masks, dim=-1,
                                   descending=True)[..., 0][..., None]
        outputs[three_roots_mask] = torch.gather(roots,
                                                 dim=-1,
                                                 index=mask_index).view(-1)

        # Deal with a -> 0 (almost quadratic) cases.

        quadratic_mask = inputs_a.abs() < quadratic_threshold
        quad_a = inputs_b[quadratic_mask]
        quad_b = inputs_c[quadratic_mask]
        quad_c = inputs_d[quadratic_mask] - inputs[quadratic_mask]
        alpha = (-quad_b + torch.sqrt(quad_b.pow(2) - 4 * quad_a * quad_c)) / (
            2 * quad_a)
        outputs[quadratic_mask] = alpha + input_left_cumwidths[quadratic_mask]

        # Derivative of the forward cubic at the recovered point; its log is
        # negated to give the inverse's log-derivative.
        shifted_outputs = outputs - input_left_cumwidths
        logabsdet = -torch.log((3 * inputs_a * shifted_outputs.pow(2) +
                                2 * inputs_b * shifted_outputs + inputs_c))
    else:
        shifted_inputs = inputs - input_left_cumwidths
        outputs = (inputs_a * shifted_inputs.pow(3) +
                   inputs_b * shifted_inputs.pow(2) +
                   inputs_c * shifted_inputs + inputs_d)

        logabsdet = torch.log((3 * inputs_a * shifted_inputs.pow(2) +
                               2 * inputs_b * shifted_inputs + inputs_c))

    # Chain rule for the affine maps into and out of the unit interval; this
    # is log(1) == 0 whenever both intervals have the same length.
    log_scale = torch.log(
        torch.as_tensor((top - bottom) / (right - left),
                        dtype=logabsdet.dtype,
                        device=logabsdet.device))

    if inverse:
        outputs = outputs * (right - left) + left
        logabsdet = logabsdet - log_scale
    else:
        outputs = outputs * (top - bottom) + bottom
        logabsdet = logabsdet + log_scale

    return outputs, logabsdet