def test_proper_usage():
    """
    Check that requesting a one-year range from a GrowFactors object
    yields exactly one price-inflation rate and one wage-growth rate.
    """
    growfactors = GrowFactors()
    year = 2017
    inflation = growfactors.price_inflation_rates(year, year)
    assert len(inflation) == 1
    wages = growfactors.wage_growth_rates(year, year)
    assert len(wages) == 1
示例#2
0
def test_growfactors_csv_values():
    """
    Check that, when the earliest input-data year precedes the first
    policy-parameter year, every valid growth factor equals one in that
    earliest data year.
    """
    gfo = GrowFactors()
    first_data_year = min(Records.PUFCSV_YEAR, Records.CPSCSV_YEAR)
    if first_data_year >= Policy.JSON_START_YEAR:
        # nothing to check when data do not predate policy parameters
        return
    for factor_name in GrowFactors.VALID_NAMES:
        assert gfo.factor_value(factor_name, first_data_year) == 1
示例#3
0
def test_proper_usage():
    """
    Check the number of rates returned for inclusive multi-year ranges and
    that the 2013 AWAGE factor is greater than one.
    """
    factors = GrowFactors()
    price_rates = factors.price_inflation_rates(2013, 2020)
    assert len(price_rates) == 8  # 2013 through 2020 inclusive
    wage_rates = factors.wage_growth_rates(2013, 2021)
    assert len(wage_rates) == 9  # 2013 through 2021 inclusive
    awage_2013 = factors.factor_value('AWAGE', 2013)
    assert awage_2013 > 1.0
示例#4
0
def test_update_and_apply_growdiff():
    """
    Check that GrowDiff.update_growdiff stores the expected AWAGE
    differences and that GrowDiff.apply_to shifts the wage growth rates by
    those differences while leaving price inflation rates unchanged.
    """
    gdiff = GrowDiff()
    gdiff.update_growdiff({'AWAGE': {2014: 0.01, 2016: 0.02}})
    # expected per-year differences; the last value repeats to the end
    wage_diffs = [0.00, 0.01, 0.01, 0.02, 0.02]
    wage_diffs += [0.02] * (GrowDiff.DEFAULT_NUM_YEARS - len(wage_diffs))
    assert np.allclose(gdiff._AWAGE, wage_diffs, atol=0.0, rtol=0.0)
    # compare rates from untouched and growdiff-adjusted GrowFactors
    first_year = GrowDiff.JSON_START_YEAR
    num_years = GrowDiff.DEFAULT_NUM_YEARS
    last_year = first_year + num_years - 1
    untouched = GrowFactors()
    pir_before = untouched.price_inflation_rates(first_year, last_year)
    wgr_before = untouched.wage_growth_rates(first_year, last_year)
    adjusted = GrowFactors()
    gdiff.apply_to(adjusted)
    pir_after = adjusted.price_inflation_rates(first_year, last_year)
    wgr_after = adjusted.wage_growth_rates(first_year, last_year)
    wgr_expected = [wgr_before[k] + wage_diffs[k] for k in range(num_years)]
    assert np.allclose(pir_before, pir_after, atol=0.0, rtol=0.0)
    assert np.allclose(wgr_after, wgr_expected, atol=1.0e-9, rtol=0.0)
示例#5
0
def test_get_calculator_exception():
    """
    Check that get_calculator raises when calculator_start_year is one
    year past TC_LAST_YEAR.
    """
    iit_reform = {
        'II_rt1': {2017: 0.09},
        'II_rt2': {2017: 0.135},
        'II_rt3': {2017: 0.225},
        'II_rt4': {2017: 0.252},
        'II_rt5': {2017: 0.297},
        'II_rt6': {2017: 0.315},
        'II_rt7': {2017: 0.3564},
    }
    # Call without a surrounding ``assert``: if get_calculator returned a
    # falsy value, the AssertionError from ``assert`` would itself satisfy
    # pytest.raises(Exception) and hide the missing exception.
    with pytest.raises(Exception):
        get_micro_data.get_calculator(
            baseline=False,
            calculator_start_year=TC_LAST_YEAR + 1,
            reform=iit_reform,
            data='cps',
            gfactors=GrowFactors(),
            records_start_year=CPS_START_YEAR)
示例#6
0
def test_get_micro_data_get_calculator():
    """
    Check that get_calculator, given a seven-rate income-tax reform plus
    TAXDATA and WEIGHTS, returns a Calculator whose current year is 2017.
    """
    start_year = 2017
    reform = {
        'II_rt1': {2017: 0.09},
        'II_rt2': {2017: 0.135},
        'II_rt3': {2017: 0.225},
        'II_rt4': {2017: 0.252},
        'II_rt5': {2017: 0.297},
        'II_rt6': {2017: 0.315},
        'II_rt7': {2017: 0.3564},
    }
    calc = get_calculator(
        baseline=False,
        calculator_start_year=start_year,
        reform=reform,
        data=TAXDATA,
        gfactors=GrowFactors(),
        weights=WEIGHTS,
        records_start_year=CPS_START_YEAR,
    )
    assert calc.current_year == start_year
示例#7
0
def test_get_calculator():
    """
    Check that get_calculator with CPS data and a seven-rate reform returns
    a Calculator whose current year equals CPS_START_YEAR.
    """
    iit_reform = {
        'II_rt1': {2017: 0.09},
        'II_rt2': {2017: 0.135},
        'II_rt3': {2017: 0.225},
        'II_rt4': {2017: 0.252},
        'II_rt5': {2017: 0.297},
        'II_rt6': {2017: 0.315},
        'II_rt7': {2017: 0.3564},
    }
    calc = get_micro_data.get_calculator(
        baseline=False,
        calculator_start_year=2017,
        reform=iit_reform,
        data='cps',
        gfactors=GrowFactors(),
        records_start_year=CPS_START_YEAR,
    )
    assert calc.current_year == CPS_START_YEAR
示例#8
0
def test_get_calculator_cps(baseline, iit_reform):
    """
    Check that get_calculator with CPS data returns a Calculator whose
    current year equals CPS_START_YEAR, for the supplied baseline flag and
    reform.
    """
    calculator_kwargs = dict(
        baseline=baseline,
        calculator_start_year=2017,
        reform=iit_reform,
        data='cps',
        gfactors=GrowFactors(),
        records_start_year=CPS_START_YEAR,
    )
    calc = get_micro_data.get_calculator(**calculator_kwargs)
    assert calc.current_year == CPS_START_YEAR
示例#9
0
文件: tbi.py 项目: keiirizawa/OG-USA
def reform_warnings_errors(user_mods, using_puf):
    """
    Collect the warnings and errors generated by the specified user_mods.

    The reform_warnings_errors function assumes user_mods is a dictionary
    returned by the Calculator.read_json_param_objects() function.

    This function returns a dictionary containing five STR:STR subdictionaries,
    where the dictionary keys are: 'policy', 'behavior', 'consumption',
    'growdiff_baseline' and 'growdiff_response'; and the subdictionaries are:
    {'warnings': '<empty-or-message(s)>', 'errors': '<empty-or-message(s)>'}.
    Note that non-policy parameters have no warnings, so the 'warnings'
    string for the non-policy parameters is always empty.  Policy warnings
    are reported only when using_puf is True; policy errors are always
    reported.
    """
    rtn_dict = {'policy': {'warnings': '', 'errors': ''},
                'behavior': {'warnings': '', 'errors': ''},
                'consumption': {'warnings': '', 'errors': ''},
                'growdiff_baseline': {'warnings': '', 'errors': ''},
                'growdiff_response': {'warnings': '', 'errors': ''}}
    # create GrowDiff objects; a ValueError is recorded, not propagated
    gdiff_baseline = GrowDiff()
    try:
        gdiff_baseline.update_growdiff(user_mods['growdiff_baseline'])
    except ValueError as valerr_msg:
        rtn_dict['growdiff_baseline']['errors'] = str(valerr_msg)
    gdiff_response = GrowDiff()
    try:
        gdiff_response.update_growdiff(user_mods['growdiff_response'])
    except ValueError as valerr_msg:
        rtn_dict['growdiff_response']['errors'] = str(valerr_msg)
    # create GrowFactors object reflecting both sets of growth differences
    growfactors = GrowFactors()
    gdiff_baseline.apply_to(growfactors)
    gdiff_response.apply_to(growfactors)
    # create Policy object; collect its messages instead of printing or
    # raising them
    pol = Policy(gfactors=growfactors)
    try:
        pol.implement_reform(user_mods['policy'],
                             print_warnings=False,
                             raise_errors=False)
        if using_puf:
            rtn_dict['policy']['warnings'] = pol.parameter_warnings
        rtn_dict['policy']['errors'] = pol.parameter_errors
    except ValueError as valerr_msg:
        rtn_dict['policy']['errors'] = str(valerr_msg)
    # create Behavior object
    behv = Behavior()
    try:
        behv.update_behavior(user_mods['behavior'])
    except ValueError as valerr_msg:
        rtn_dict['behavior']['errors'] = str(valerr_msg)
    # create Consumption object
    consump = Consumption()
    try:
        consump.update_consumption(user_mods['consumption'])
    except ValueError as valerr_msg:
        rtn_dict['consumption']['errors'] = str(valerr_msg)
    # return composite dictionary of warnings/errors
    return rtn_dict
def test_update_and_apply_growdiff():
    """
    Verify the AWAGE differences stored by a GrowDiff object and the way
    applying them changes the wage growth rates, but not the price
    inflation rates, of a GrowFactors object.
    """
    growdiff = GrowDiff()
    awage_changes = {2014: 0.01, 2016: 0.02}
    growdiff.update_growdiff({'AWAGE': awage_changes})
    expected_wage_diffs = [0.00, 0.01, 0.01, 0.02, 0.02]
    padding = GrowDiff.DEFAULT_NUM_YEARS - len(expected_wage_diffs)
    expected_wage_diffs.extend([0.02] * padding)
    assert np.allclose(growdiff._AWAGE, expected_wage_diffs,
                       atol=0.0, rtol=0.0)
    syr = GrowDiff.JSON_START_YEAR
    nyrs = GrowDiff.DEFAULT_NUM_YEARS
    lyr = syr + nyrs - 1
    baseline_factors = GrowFactors()
    pir_pre = baseline_factors.price_inflation_rates(syr, lyr)
    wgr_pre = baseline_factors.wage_growth_rates(syr, lyr)
    shifted_factors = GrowFactors()
    growdiff.apply_to(shifted_factors)
    pir_pst = shifted_factors.price_inflation_rates(syr, lyr)
    wgr_pst = shifted_factors.wage_growth_rates(syr, lyr)
    expected_wgr_pst = [wgr_pre[year_idx] + expected_wage_diffs[year_idx]
                        for year_idx in range(nyrs)]
    assert np.allclose(pir_pre, pir_pst, atol=0.0, rtol=0.0)
    assert np.allclose(wgr_pst, expected_wgr_pst, atol=1.0e-9, rtol=0.0)
示例#11
0
def test_get_micro_data_get_calculator():
    """
    Check that get_calculator, given a four-rate 2017 reform and the small
    PIT data file, returns a Calculator whose current year is 2017.
    """
    rates = {
        '_rate1': [0.09],
        '_rate2': [0.135],
        '_rate3': [0.225],
        '_rate4': [0.252],
    }
    reform = {2017: rates}
    calc = get_calculator(
        baseline=False,
        calculator_start_year=2017,
        reform=reform,
        data='pitSmallData.csv',
        gfactors=GrowFactors(),
        records_start_year=2017,
    )
    assert calc.current_year == 2017
示例#12
0
def test_update_after_use():
    """
    Check that GrowFactors.update raises ValueError once rates have been
    read from the object.
    """
    factors = GrowFactors()
    # read the full range of price inflation rates before updating
    factors.price_inflation_rates(factors.first_year, factors.last_year)
    with pytest.raises(ValueError):
        factors.update('AWAGE', 2013, 0.01)
def test_correct_Records_instantiation(gst_sample):
    """
    Check GSTRecords construction from the GST sample, both with default
    arguments and with explicit gfactors, weights and start_year.
    """
    default_recs = GSTRecords(data=gst_sample)
    assert default_recs
    assert default_recs.current_year == default_recs.data_year
    default_recs.set_current_year(default_recs.data_year + 1)
    weights_file = os.path.join(GSTRecords.CUR_PATH,
                                GSTRecords.GST_WEIGHTS_FILENAME)
    weights = pd.read_csv(weights_file)
    explicit_recs = GSTRecords(
        data=gst_sample,
        gfactors=GrowFactors(),
        weights=weights,
        start_year=GSTRecords.GSTCSV_YEAR,
    )
    assert explicit_recs
    assert explicit_recs.current_year == explicit_recs.data_year
示例#14
0
def test_correct_Records_instantiation(pit_subsample):
    """
    Test Records construction from the PIT subsample, both with default
    arguments and with explicit gfactors, weights and start_year, checking
    that AGEGRP values lie in [0, 2] for each constructed object.
    """
    rec1 = Records(data=pit_subsample)
    assert rec1
    assert np.all(rec1.AGEGRP >= 0) and np.all(rec1.AGEGRP <= 2)
    assert rec1.current_year == rec1.data_year
    rec1.set_current_year(rec1.data_year + 1)
    wghts_path = os.path.join(Records.CUR_PATH, Records.PIT_WEIGHTS_FILENAME)
    wghts_df = pd.read_csv(wghts_path)
    rec2 = Records(data=pit_subsample,
                   gfactors=GrowFactors(),
                   weights=wghts_df,
                   start_year=Records.PITCSV_YEAR)
    assert rec2
    # validate the second object's AGEGRP values (rec1 was checked above)
    assert np.all(rec2.AGEGRP >= 0) and np.all(rec2.AGEGRP <= 2)
    assert rec2.current_year == rec2.data_year
示例#15
0
def test_correct_Records_instantiation(cit_fullsample):
    """
    Test CorpRecords construction from the full CIT sample, both with
    default arguments and with explicit gfactors, weights and start_year.
    """
    rec1 = CorpRecords(data=cit_fullsample)
    # basic sanity checks, mirroring the other Records-like tests
    # TODO: Add checks on the content of the records
    assert isinstance(rec1, CorpRecords)
    assert rec1.current_year == rec1.data_year
    rec1.set_current_year(rec1.data_year + 1)
    wghts_path = os.path.join(CorpRecords.CUR_PATH,
                              CorpRecords.CIT_WEIGHTS_FILENAME)
    wghts_df = pd.read_csv(wghts_path)
    rec2 = CorpRecords(data=cit_fullsample,
                       gfactors=GrowFactors(),
                       weights=wghts_df,
                       start_year=CorpRecords.CITCSV_YEAR)
    # TODO: Repeat content checks for records
    assert isinstance(rec2, CorpRecords)
    assert rec2.current_year == rec2.data_year
def test_correct_Records_instantiation(cps_subsample):
    """
    Check Records construction from the CPS subsample via cps_constructor
    and via the regular constructor with explicit gfactors and weights.
    The e00200 total of the cps_constructor object must be unchanged after
    advancing one year.
    """
    cps_recs = Records.cps_constructor(data=cps_subsample)
    assert cps_recs
    assert np.all(cps_recs.MARS != 0)
    assert cps_recs.current_year == cps_recs.data_year
    e00200_total = cps_recs.e00200.sum()
    cps_recs.set_current_year(cps_recs.data_year + 1)
    e00200_total_next = cps_recs.e00200.sum()
    assert e00200_total_next == e00200_total
    weights_file = os.path.join(Records.CUR_PATH,
                                Records.CPS_WEIGHTS_FILENAME)
    weights = pd.read_csv(weights_file)
    explicit_recs = Records(
        data=cps_subsample,
        exact_calculations=False,
        gfactors=GrowFactors(),
        weights=weights,
        start_year=Records.CPSCSV_YEAR,
    )
    assert explicit_recs
    assert np.all(explicit_recs.MARS != 0)
    assert explicit_recs.current_year == explicit_recs.data_year
示例#17
0
def test_get_calculator_exception():
    """
    Check that get_calculator raises when calculator_start_year is one
    year past TC_LAST_YEAR.
    """
    iit_reform = {
        "II_rt1": {2017: 0.09},
        "II_rt2": {2017: 0.135},
        "II_rt3": {2017: 0.225},
        "II_rt4": {2017: 0.252},
        "II_rt5": {2017: 0.297},
        "II_rt6": {2017: 0.315},
        "II_rt7": {2017: 0.3564},
    }
    # No ``assert`` around the call: an AssertionError from a falsy return
    # value would satisfy pytest.raises(Exception) and mask the failure.
    with pytest.raises(Exception):
        get_micro_data.get_calculator(
            baseline=False,
            calculator_start_year=TC_LAST_YEAR + 1,
            reform=iit_reform,
            data="cps",
            gfactors=GrowFactors(),
            records_start_year=CPS_START_YEAR,
        )
def test_update_and_apply_growdiff():
    """
    Check AWAGE updates on a five-year GrowDiff starting in 2013 and their
    effect on the wage growth rates of a GrowFactors object.
    """
    syr, nyrs = 2013, 5
    lyr = syr + nyrs - 1
    gdiff = GrowDiff(start_year=syr, num_years=nyrs)
    gdiff.update_growdiff({2014: {'_AWAGE': [0.01]},
                           2016: {'_AWAGE': [0.02]}})
    expected_wage_diffs = [0.00, 0.01, 0.01, 0.02, 0.02]
    assert_allclose(gdiff._AWAGE, expected_wage_diffs, atol=0.0, rtol=0.0)
    # unmodified factors are the reference for the comparisons below
    reference = GrowFactors()
    pir_pre = reference.price_inflation_rates(syr, lyr)
    wgr_pre = reference.wage_growth_rates(syr, lyr)
    modified = GrowFactors()
    gdiff.apply_to(modified)
    pir_pst = modified.price_inflation_rates(syr, lyr)
    wgr_pst = modified.wage_growth_rates(syr, lyr)
    expected_wgr_pst = [wgr_pre[pos] + expected_wage_diffs[pos]
                        for pos in range(nyrs)]
    assert_allclose(pir_pre, pir_pst, atol=0.0, rtol=0.0)
    assert_allclose(wgr_pst, expected_wgr_pst, atol=1.0e-9, rtol=0.0)
def test_correct_Records_instantiation(cps_subsample):
    """
    Check Records construction from the CPS subsample: first through
    cps_constructor without growth factors (so the e00200 total is
    unchanged after increment_year), then through the regular constructor
    with explicit weights and adjustment ratios.
    """
    unaged = Records.cps_constructor(data=cps_subsample, gfactors=None)
    assert unaged
    assert np.all(unaged.MARS != 0)
    assert unaged.current_year == unaged.data_year
    e00200_before = unaged.e00200.sum()
    unaged.increment_year()
    e00200_after = unaged.e00200.sum()
    assert e00200_after == e00200_before
    weights_file = os.path.join(Records.CODE_PATH,
                                Records.CPS_WEIGHTS_FILENAME)
    weights = pd.read_csv(weights_file)
    ratios_file = os.path.join(Records.CODE_PATH,
                               Records.PUF_RATIOS_FILENAME)
    ratios = pd.read_csv(ratios_file, index_col=0).transpose()
    aged = Records(
        data=cps_subsample,
        start_year=Records.CPSCSV_YEAR,
        gfactors=GrowFactors(),
        weights=weights,
        adjust_ratios=ratios,
        exact_calculations=False,
    )
    assert aged
    assert np.all(aged.MARS != 0)
    assert aged.current_year == aged.data_year
示例#20
0
文件: tbi.py 项目: keiirizawa/OG-USA
def calculators(year_n, start_year,
                use_puf_not_cps,
                use_full_sample,
                user_mods):
    """
    Build the pre-reform and post-reform Calculator objects for year_n.

    The user_mods argument is assumed to be the dictionary returned by the
    Calculator.read_json_param_objects() function.  The return value is
    the tuple (calc1, calc2) where
      calc1 is the pre-reform Calculator object for year_n, and
      calc2 is the post-reform Calculator object for year_n.
    Both objects are first advanced to start_year and then incremented
    year_n more times; calc_all() has been executed on neither of them.
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements

    check_user_mods(user_mods)

    # one Consumption instance is shared by both Calculator objects
    consump = Consumption()
    consump.update_consumption(user_mods['consumption'])

    # growth-difference assumptions for the baseline and the response
    growdiff_baseline = GrowDiff()
    growdiff_response = GrowDiff()
    base_assumps = user_mods['growdiff_baseline']
    resp_assumps = user_mods['growdiff_response']
    growdiff_baseline.update_growdiff(base_assumps)
    growdiff_response.update_growdiff(resp_assumps)

    # pre-reform factors get only the baseline differences; post-reform
    # factors get the baseline differences and then the response ones
    growfactors_pre = GrowFactors()
    growdiff_baseline.apply_to(growfactors_pre)
    growfactors_post = GrowFactors()
    growdiff_baseline.apply_to(growfactors_post)
    growdiff_response.apply_to(growfactors_post)

    # read the input data and pick the sampling scheme
    tbi_path = os.path.abspath(os.path.dirname(__file__))
    if use_puf_not_cps:
        # prefer the TaxBrain deployment path, else a local checkout
        input_path = 'puf.csv.gz'
        if not os.path.isfile(input_path):
            input_path = os.path.join(tbi_path, '..', '..', 'puf.csv')
        full_sample = pd.read_csv(input_path)
        sampling_frac, sampling_seed = 0.05, 2222
    else:
        # prefer the Tax-Calculator code path, else the package "egg"
        input_path = os.path.join(tbi_path, '..', 'cps.csv.gz')
        if os.path.isfile(input_path):
            full_sample = pd.read_csv(input_path)
        else:
            full_sample = read_egg_csv('cps.csv.gz')  # pragma: no cover
        sampling_frac, sampling_seed = 0.03, 180
    if use_full_sample:
        sample = full_sample
    else:
        sample = full_sample.sample(frac=sampling_frac,
                                    random_state=sampling_seed)

    # PUF data use the regular Records constructor; CPS data use the
    # CPS-specific alternate constructor
    make_records = Records if use_puf_not_cps else Records.cps_constructor

    # pre-reform Calculator, advanced to start_year
    recs1 = make_records(data=sample, gfactors=growfactors_pre)
    policy1 = Policy(gfactors=growfactors_pre)
    calc1 = Calculator(policy=policy1, records=recs1, consumption=consump)
    while calc1.current_year < start_year:
        calc1.increment_year()
    assert calc1.current_year == start_year

    # post-reform Calculator, advanced to start_year
    recs2 = make_records(data=sample, gfactors=growfactors_post)
    policy2 = Policy(gfactors=growfactors_post)
    policy2.implement_reform(user_mods['policy'])
    calc2 = Calculator(policy=policy2, records=recs2, consumption=consump)
    while calc2.current_year < start_year:
        calc2.increment_year()
    assert calc2.current_year == start_year

    # drop local references to objects now embedded in calc1 and calc2
    del sample, full_sample, consump
    del growdiff_baseline, growdiff_response
    del growfactors_pre, growfactors_post
    del recs1, recs2, policy1, policy2

    # advance both Calculator objects by year_n more years
    for _ in range(year_n):
        calc1.increment_year()
        calc2.increment_year()

    return calc1, calc2
def test_recs_class(recs_varinfo_file, cps_subsample):
    """
    Define a concrete Recs class derived from the abstract Data class and
    test its argument validation, year incrementing (with and without
    aging), and private reader methods.
    """

    class Recs(Data):
        """
        Concrete subclass of the abstract Data base class used by this test.
        """
        VARINFO_FILE_NAME = recs_varinfo_file.name
        VARINFO_FILE_PATH = ''

        def __init__(self, data, start_year, gfactors, weights):
            super().__init__(data, start_year, gfactors, weights)

        def _extrapolate(self, year):
            # age only e00300, using the AINTS growth factor
            self.e00300 *= self.gfactors.factor_value('AINTS', year)

    # argument combinations that must be rejected with ValueError
    invalid_arguments = [
        dict(data=list(), start_year=2000, gfactors=None, weights=None),
        dict(data=cps_subsample, start_year=list(),
             gfactors=None, weights=None),
        dict(data=cps_subsample, start_year=2000,
             gfactors=None, weights=''),
        dict(data=cps_subsample, start_year=2000,
             gfactors=GrowFactors(), weights=None),
        dict(data=cps_subsample, start_year=2000,
             gfactors='', weights=''),
    ]
    for bad_kwargs in invalid_arguments:
        with pytest.raises(ValueError):
            Recs(**bad_kwargs)

    # without gfactors or weights the data are not aged across years
    syr = 2014
    rec = Recs(data=cps_subsample, start_year=syr,
               gfactors=None, weights=None)
    assert isinstance(rec, Recs)
    assert np.all(rec.MARS != 0)
    assert rec.data_year == syr
    assert rec.current_year == syr
    interest_before = rec.e00300.sum()
    rec.increment_year()
    assert rec.data_year == syr
    assert rec.current_year == syr + 1
    interest_after = rec.e00300.sum()
    assert np.allclose([interest_before], [interest_after])
    del rec

    # with gfactors and weights the data are aged across years
    weights_file = os.path.join(GrowFactors.FILE_PATH, 'cps_weights.csv.gz')
    weights = pd.read_csv(weights_file)
    rec = Recs(data=cps_subsample, start_year=syr,
               gfactors=GrowFactors(), weights=weights)
    assert isinstance(rec, Recs)
    assert np.all(rec.MARS != 0)
    assert rec.data_year == syr
    assert rec.current_year == syr
    weight_total_before = rec.s006.sum()
    interest_before = rec.e00300.sum()
    rec.increment_year()
    assert rec.data_year == syr
    assert rec.current_year == syr + 1
    weight_total_after = rec.s006.sum()
    assert weight_total_after > weight_total_before
    interest_after = rec.e00300.sum()
    # because growfactor for e00300 was less than one in 2015, assert < below:
    assert interest_after < interest_before

    # private reader methods
    rec._read_data(data=None)
    rec._read_weights(weights=None)
    with pytest.raises(ValueError):
        rec._read_weights(weights=list())
示例#22
0
文件: tbi.py 项目: keiirizawa/OG-USA
def calculator_objects(year_n, start_year,
                       use_puf_not_cps,
                       use_full_sample,
                       user_mods,
                       behavior_allowed):
    """
    This function assumes that the specified user_mods is a dictionary
      returned by the Calculator.read_json_param_objects() function.
    This function returns (calc1, calc2) where
      calc1 is pre-reform Calculator object calculated for year_n, and
      calc2 is post-reform Calculator object calculated for year_n.
    Set behavior_allowed to False when generating static results or
      set behavior_allowed to True when generating dynamic results.

    Parameters:
      year_n: number of years to advance both Calculator objects beyond
        start_year before the final calculations.
      start_year: year both Calculator objects are first advanced to.
      use_puf_not_cps: True reads puf data; False reads cps data.
      use_full_sample: True uses every row of the input data; False uses
        a seeded random sample (5% with seed 2222 for puf, 3% with seed
        180 for cps).
      user_mods: dictionary with 'consumption', 'growdiff_baseline',
        'growdiff_response', 'policy' and 'behavior' entries.
      behavior_allowed: whether a behavioral response may be specified.

    Raises:
      ValueError if both behavior and growdiff_response have a response,
        or if behavior has a response while behavior_allowed is False.
    """
    # pylint: disable=too-many-arguments,too-many-locals
    # pylint: disable=too-many-branches,too-many-statements

    check_user_mods(user_mods)

    # specify Consumption instance
    consump = Consumption()
    consump_assumptions = user_mods['consumption']
    consump.update_consumption(consump_assumptions)

    # specify growdiff_baseline and growdiff_response
    growdiff_baseline = GrowDiff()
    growdiff_response = GrowDiff()
    growdiff_base_assumps = user_mods['growdiff_baseline']
    growdiff_resp_assumps = user_mods['growdiff_response']
    growdiff_baseline.update_growdiff(growdiff_base_assumps)
    growdiff_response.update_growdiff(growdiff_resp_assumps)

    # create pre-reform and post-reform GrowFactors instances
    # (pre-reform gets only baseline diffs; post-reform gets both)
    growfactors_pre = GrowFactors()
    growdiff_baseline.apply_to(growfactors_pre)
    growfactors_post = GrowFactors()
    growdiff_baseline.apply_to(growfactors_post)
    growdiff_response.apply_to(growfactors_post)

    # create sample pd.DataFrame from specified input file and sampling scheme
    tbi_path = os.path.abspath(os.path.dirname(__file__))
    if use_puf_not_cps:
        # first try TaxBrain deployment path
        input_path = 'puf.csv.gz'
        if not os.path.isfile(input_path):
            # otherwise try local Tax-Calculator deployment path
            input_path = os.path.join(tbi_path, '..', '..', 'puf.csv')
        sampling_frac = 0.05
        sampling_seed = 2222
    else:  # if using cps input not puf input
        # first try Tax-Calculator code path
        input_path = os.path.join(tbi_path, '..', 'cps.csv.gz')
        if not os.path.isfile(input_path):
            # otherwise read from taxcalc package "egg"
            input_path = None  # pragma: no cover
            full_sample = read_egg_csv('cps.csv.gz')  # pragma: no cover
        sampling_frac = 0.03
        sampling_seed = 180
    # input_path is None only when full_sample was already read from egg
    if input_path:
        full_sample = pd.read_csv(input_path)
    if use_full_sample:
        sample = full_sample
    else:
        # fixed random_state keeps the sample reproducible across calls
        sample = full_sample.sample(frac=sampling_frac,
                                    random_state=sampling_seed)

    # create pre-reform Calculator instance
    if use_puf_not_cps:
        recs1 = Records(data=sample,
                        gfactors=growfactors_pre)
    else:
        recs1 = Records.cps_constructor(data=sample,
                                        gfactors=growfactors_pre)
    policy1 = Policy(gfactors=growfactors_pre)
    calc1 = Calculator(policy=policy1, records=recs1, consumption=consump)
    while calc1.current_year < start_year:
        calc1.increment_year()
    # NOTE(review): calc1 is calculated here at start_year and again after
    # the year_n increments below; the purpose of this first calc_all()
    # is not evident from this function alone -- confirm it is needed.
    calc1.calc_all()
    assert calc1.current_year == start_year

    # specify Behavior instance
    behv = Behavior()
    behavior_assumps = user_mods['behavior']
    behv.update_behavior(behavior_assumps)

    # always prevent both behavioral response and growdiff response
    if behv.has_any_response() and growdiff_response.has_any_response():
        msg = 'BOTH behavior AND growdiff_response HAVE RESPONSE'
        raise ValueError(msg)

    # optionally prevent behavioral response
    if behv.has_any_response() and not behavior_allowed:
        msg = 'A behavior RESPONSE IS NOT ALLOWED'
        raise ValueError(msg)

    # create post-reform Calculator instance
    if use_puf_not_cps:
        recs2 = Records(data=sample,
                        gfactors=growfactors_post)
    else:
        recs2 = Records.cps_constructor(data=sample,
                                        gfactors=growfactors_post)
    policy2 = Policy(gfactors=growfactors_post)
    policy_reform = user_mods['policy']
    policy2.implement_reform(policy_reform)
    calc2 = Calculator(policy=policy2, records=recs2,
                       consumption=consump, behavior=behv)
    while calc2.current_year < start_year:
        calc2.increment_year()
    assert calc2.current_year == start_year

    # delete objects now embedded in calc1 and calc2
    del sample
    del full_sample
    del consump
    del growdiff_baseline
    del growdiff_response
    del growfactors_pre
    del growfactors_post
    del behv
    del recs1
    del recs2
    del policy1
    del policy2

    # increment Calculator objects for year_n years and calculate
    for _ in range(0, year_n):
        calc1.increment_year()
        calc2.increment_year()
    calc1.calc_all()
    # with a behavioral response, calc2 is replaced by the object returned
    # from Behavior.response (presumably calc2 with the response applied,
    # computed relative to calc1 -- confirm against Behavior.response)
    if calc2.behavior_has_response():
        calc2 = Behavior.response(calc1, calc2)
    else:
        calc2.calc_all()

    # return calculated Calculator objects
    return (calc1, calc2)
示例#23
0
def test_improper_usage(bad_gf_file):
    """
    Check that GrowFactors raises ValueError for bad constructor arguments,
    for year ranges that extend past either end or are reversed, and for
    unknown factor names or out-of-range years in factor_value.
    """
    for bad_source in (dict(), bad_gf_file.name):
        with pytest.raises(ValueError):
            GrowFactors(bad_source)
    gfo = GrowFactors()
    fyr = gfo.first_year
    lyr = gfo.last_year
    bad_ranges = ((fyr - 1, lyr), (fyr, lyr + 1), (lyr, fyr))
    for rates_method in (gfo.price_inflation_rates, gfo.wage_growth_rates):
        for range_start, range_end in bad_ranges:
            with pytest.raises(ValueError):
                rates_method(range_start, range_end)
    bad_lookups = (('BADNAME', fyr), ('AWAGE', fyr - 1), ('AWAGE', lyr + 1))
    for factor_name, year in bad_lookups:
        with pytest.raises(ValueError):
            gfo.factor_value(factor_name, year)