def combine_tax_scales(
        node: ParameterNodeAtInstant,
        combined_tax_scales: typing.Optional[taxscales.MarginalRateTaxScale] = None,
        ) -> typing.Optional[taxscales.MarginalRateTaxScale]:
    """
    Combine all the MarginalRateTaxScales in the node into a single
    MarginalRateTaxScale.

    :param node: Parameter node whose children are iterated by name; only
        children that are ``MarginalRateTaxScale`` instances are combined,
        every other child is skipped (and logged at INFO level).
    :param combined_tax_scales: Optional accumulator scale. When given, each
        matching child is passed to its ``add_tax_scale`` and the same object
        is returned.
    :returns: ``combined_tax_scales`` untouched when ``node`` is ``None`` or
        empty; otherwise the accumulator after all matching children were
        added. When no accumulator is given, a new scale named after the
        node's first child and seeded with a ``(0, 0)`` bracket is used.
    """

    # The first child name serves both as the emptiness check and as the
    # name of a freshly created accumulator.
    name = next(iter(node or []), None)

    if name is None:
        return combined_tax_scales

    if combined_tax_scales is None:
        combined_tax_scales = taxscales.MarginalRateTaxScale(name = name)
        # Seed the new accumulator with a threshold-0, rate-0 bracket.
        combined_tax_scales.add_bracket(0, 0)

    for child_name in node:
        child = node[child_name]

        if isinstance(child, taxscales.MarginalRateTaxScale):
            combined_tax_scales.add_tax_scale(child)

        else:
            # Lazy %-formatting: ``child`` (possibly a large parameter node)
            # is only rendered when an INFO record is actually emitted.
            log.info(
                "Skipping %s with value %s "
                "because it is not a marginal rate tax scale",
                child_name,
                child,
                )

    return combined_tax_scales
def test_marginal_rates():
    """Each base should report the marginal rate of the bracket it falls in."""
    scale = taxscales.MarginalRateTaxScale()
    for threshold, rate in ((0, 0), (100, 0.1), (200, 0.2)):
        scale.add_bracket(threshold, rate)
    bases = numpy.array([0, 10, 50, 125, 250])

    marginal = scale.marginal_rates(bases)

    tools.assert_near(marginal, [0, 0, 0, 0.1, 0.2])
def test_inverse():
    """
    Inverting a scale must map net tax bases back to the gross tax bases.

    The scale uses non-zero marginal rates on purpose. With only zero rates
    the net base equals the gross base, so an identity "inverse" would pass
    and the test would not exercise ``inverse`` at all.
    """
    gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6])
    tax_scale = taxscales.MarginalRateTaxScale()
    tax_scale.add_bracket(0, 0)
    tax_scale.add_bracket(1, 0.1)
    tax_scale.add_bracket(3, 0.05)
    net_tax_base = gross_tax_base - tax_scale.calc(gross_tax_base)

    result = tax_scale.inverse()

    # Non-zero rates introduce float rounding, so this uses the same margin
    # as the scaled-inverse test.
    tools.assert_near(result.calc(net_tax_base), gross_tax_base, 1e-13)
def test_scale_tax_scales():
    """``scale_tax_scales`` should multiply every threshold by the factor."""
    factor = 12.345
    scale = taxscales.MarginalRateTaxScale()
    for threshold in (1, 2, 3):
        scale.add_bracket(threshold, 0)
    expected_thresholds = numpy.array([1, 2, 3]) * factor

    scaled = scale.scale_tax_scales(factor)

    tools.assert_near(scaled.thresholds, expected_thresholds)
def test_calc_without_round():
    """Without a rounding option, tax follows the unrounded base exactly."""
    scale = taxscales.MarginalRateTaxScale()
    for threshold, rate in ((0, 0), (100, 0.1)):
        scale.add_bracket(threshold, rate)
    bases = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005])
    expected = [10, 10.02, 10.0002, 10.06, 10.0006, 10.05, 10.0005]

    taxes = scale.calc(bases)

    tools.assert_near(taxes, expected, absolute_error_margin = 1e-10)
def test_calc_when_round_is_1():
    """With ``round_base_decimals = 1`` the taxes carry at most one decimal."""
    scale = taxscales.MarginalRateTaxScale()
    for threshold, rate in ((0, 0), (100, 0.1)):
        scale.add_bracket(threshold, rate)
    bases = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005])
    expected = [10, 10.0, 10.0, 10.1, 10.0, 10, 10.0]

    taxes = scale.calc(bases, round_base_decimals = 1)

    tools.assert_near(taxes, expected, absolute_error_margin = 1e-10)
def test_calc():
    """Tax accumulates bracket by bracket and stops growing in a 0% top bracket."""
    brackets = ((0, 0), (1, 0.1), (2, 0.2), (3, 0))
    scale = taxscales.MarginalRateTaxScale()
    for threshold, rate in brackets:
        scale.add_bracket(threshold, rate)
    bases = numpy.array([1, 1.5, 2, 2.5, 3.0, 4.0])

    taxes = scale.calc(bases)

    tools.assert_near(
        taxes,
        [0, 0.05, 0.1, 0.2, 0.3, 0.3],
        absolute_error_margin = 1e-10,
        )
# Example #8
# 0
    def _get_at_instant(self, instant):
        """
        Build the tax scale this parameter describes at ``instant``.

        The tax scale class is chosen by checking, in order:

        - metadata ``type`` equal to ``'single_amount'``: ``SingleAmountTaxScale``;
        - any bracket with an ``amount``: ``MarginalAmountTaxScale``;
        - any bracket with an ``average_rate``: ``LinearAverageRateTaxScale``;
        - otherwise: ``MarginalRateTaxScale``.

        Brackets lacking the value key or ``threshold`` are skipped. For the
        two rate-based scales, each rate is multiplied by the bracket's
        ``base`` when present, and by ``1.`` otherwise.
        """
        snapshots = [
            bracket.get_at_instant(instant) for bracket in self.brackets
            ]

        def fill_with_amounts(scale):
            # Add (threshold, amount) for every bracket that has both keys.
            for snapshot in snapshots:
                keys = snapshot._children
                if 'amount' not in keys or 'threshold' not in keys:
                    continue
                amount = snapshot.amount
                scale.add_bracket(snapshot.threshold, amount)
            return scale

        def fill_with_rates(scale, rate_key):
            # Add (threshold, rate * base) for every bracket that has both
            # ``rate_key`` and ``threshold``.
            for snapshot in snapshots:
                keys = snapshot._children
                multiplier = snapshot.base if 'base' in keys else 1.
                if rate_key not in keys or 'threshold' not in keys:
                    continue
                rate = getattr(snapshot, rate_key)
                scale.add_bracket(snapshot.threshold, rate * multiplier)
            return scale

        if self.metadata.get('type') == 'single_amount':
            return fill_with_amounts(taxscales.SingleAmountTaxScale())
        if any('amount' in snapshot._children for snapshot in snapshots):
            return fill_with_amounts(taxscales.MarginalAmountTaxScale())
        if any('average_rate' in snapshot._children for snapshot in snapshots):
            return fill_with_rates(
                taxscales.LinearAverageRateTaxScale(),
                'average_rate',
                )
        return fill_with_rates(taxscales.MarginalRateTaxScale(), 'rate')
def test_to_average():
    """Converting a marginal rate scale yields the expected average rate scale."""
    tax_base = numpy.array([1, 1.5, 2, 2.5])
    tax_scale = taxscales.MarginalRateTaxScale()
    tax_scale.add_bracket(0, 0)
    tax_scale.add_bracket(1, 0.1)
    tax_scale.add_bracket(2, 0.2)

    result = tax_scale.to_average()

    # Note: assert_near doesn't work for inf, so thresholds use plain equality.
    assert result.thresholds == [0, 1, 2, numpy.inf]
    # ``assert x, y`` only checks that ``x`` is truthy (``y`` is the failure
    # message), so the expected rates must be compared explicitly.
    tools.assert_near(
        result.rates,
        [0, 0, 0.05, 0.2],
        absolute_error_margin = 1e-10,
        )
    tools.assert_near(
        result.calc(tax_base),
        [0, 0.0375, 0.1, 0.125],
        absolute_error_margin = 1e-10,
        )
def test_inverse_scaled_marginal_tax_scales():
    """The inverse of a scaled scale maps scaled net bases back to gross bases."""
    factor = 12.345
    gross = numpy.array([1, 2, 3, 4, 5, 6]) * factor
    unscaled = taxscales.MarginalRateTaxScale()
    for threshold, rate in ((0, 0), (1, 0.1), (3, 0.05)):
        unscaled.add_bracket(threshold, rate)
    scaled = unscaled.scale_tax_scales(factor)
    net = gross - scaled.calc(gross)

    inverse = scaled.inverse()

    tools.assert_near(inverse.calc(net), gross, 1e-13)
# Example #11
# 0
    def to_marginal(self) -> taxscales.MarginalRateTaxScale:
        """
        Convert this average rate tax scale into a marginal rate tax scale.

        Each average ``rate`` at a finite ``threshold`` is read as a total tax
        of ``rate * threshold``. The marginal rate of the segment between two
        consecutive finite thresholds is the slope of that total. The walk
        starts from threshold 0 and reads only ``thresholds[1:]`` and
        ``rates[1:]``, so the first bracket is assumed to start at 0. The
        open-ended last bracket uses the last average rate visited.

        :returns: A new ``MarginalRateTaxScale`` with this scale's ``name``,
            ``option`` and ``unit``. It has no brackets when this scale has no
            rates. For a single-bracket scale it has one bracket at 0 with
            that rate.
        """
        marginal_tax_scale = taxscales.MarginalRateTaxScale(
            name = self.name,
            option = self.option,
            unit = self.unit,
            )

        # Nothing to convert; the final bracket below would have no rate.
        if not self.rates:
            return marginal_tax_scale

        previous_i = 0
        previous_threshold = 0
        # Used when the loop body never runs (single bracket): a lone average
        # rate is treated as constant, hence it is also the marginal rate.
        rate = self.rates[0]

        for threshold, rate in zip(self.thresholds[1:], self.rates[1:]):
            # An infinite threshold has no finite segment to convert. Its rate
            # is still picked up by the final bracket below.
            if threshold != float("Inf"):
                # Total tax at this threshold.
                i = rate * threshold
                marginal_tax_scale.add_bracket(
                    previous_threshold,
                    (i - previous_i) / (threshold - previous_threshold),
                    )
                previous_i = i
                previous_threshold = threshold

        # Open-ended last bracket, taxed at the last average rate seen.
        marginal_tax_scale.add_bracket(previous_threshold, rate)

        return marginal_tax_scale