示例#1
0
def test_add_vars_in_df():
    """Check that unused DataFrame columns do not change the PLR estimates.

    One data object is built from the full simulated frame and a second one
    from a frame restricted to the columns that are actually referenced.
    Both models are fitted on the same sample splitting, and their
    coefficients and standard errors are compared.
    """
    np.random.seed(3141)
    covariates = ('X1', 'X11', 'X13')
    sim_frame = make_plr_CCDDHNR2018(n_obs=100, return_type='DataFrame')

    data_all_cols = DoubleMLData(sim_frame, 'y', 'd', list(covariates))
    used_cols_only = sim_frame[list(covariates) + ['y', 'd']]
    data_used_cols = DoubleMLData(used_cols_only, 'y', 'd', list(covariates))

    model_all_cols = DoubleMLPLR(data_all_cols, Lasso(), Lasso())
    # The second model must not draw its own folds; it reuses the partition
    # of the first model so that both fits see identical splits.
    model_used_cols = DoubleMLPLR(data_used_cols, Lasso(), Lasso(),
                                  draw_sample_splitting=False)
    model_used_cols.set_sample_splitting(model_all_cols.smpls)

    for model in (model_all_cols, model_used_cols):
        model.fit()

    for estimate in ('coef', 'se'):
        assert np.allclose(getattr(model_all_cols, estimate),
                           getattr(model_used_cols, estimate),
                           rtol=1e-9, atol=1e-4)
def test_doubleml_draw_vs_set():
    """Compare internally drawn sample splits with externally set ones.

    A reference model (created with ``n_folds=7, n_rep=8``) repeatedly gets
    the partition of a freshly drawn model via ``set_sample_splitting``.
    After every call, ``_assert_resampling_pars`` is run on the drawn model
    and the reference model.

    Depending on the scenario, the partition is handed over as the complete
    ``smpls`` object, as ``smpls[0]``, or as ``smpls[0][0]``.
    """
    np.random.seed(3141)
    reference = DoubleMLPLR(dml_data, ml_g, ml_m, n_folds=7, n_rep=8)

    # Forms in which the drawn partition is passed on, from the complete
    # object down to progressively unwrapped pieces.
    partition_forms = (
        lambda smpls: smpls,
        lambda smpls: smpls[0],
        lambda smpls: smpls[0][0],
    )

    # (n_folds, n_rep, apply_cross_fitting, number of partition forms tried)
    scenarios = (
        (1, 1, False, 3),
        (2, 1, False, 3),
        (2, 1, True, 2),
        (5, 1, True, 2),
        (5, 3, True, 1),
        (2, 4, False, 1),
    )

    for folds, reps, cross_fit, n_forms in scenarios:
        drawn = DoubleMLPLR(dml_data,
                            ml_g,
                            ml_m,
                            n_folds=folds,
                            n_rep=reps,
                            apply_cross_fitting=cross_fit)
        for pick in partition_forms[:n_forms]:
            reference.set_sample_splitting(pick(drawn.smpls))
            _assert_resampling_pars(drawn, reference)
# Module-level reference models: each one is constructed with default
# settings for score and sample splitting and fitted right away.
# NOTE(review): the ``dml_data_*`` objects and ``dml_plr`` are defined outside
# this excerpt -- presumably just above; confirm before moving this code.
dml_pliv = DoubleMLPLIV(dml_data_pliv, Lasso(), Lasso(), Lasso())
dml_pliv.fit()
dml_irm = DoubleMLIRM(dml_data_irm, Lasso(), LogisticRegression())
dml_irm.fit()
dml_iivm = DoubleMLIIVM(dml_data_iivm, Lasso(), LogisticRegression(),
                        LogisticRegression())
dml_iivm.fit()

# Fit models with callable scores.
# Each callable-score model receives the bound ``_score_elements`` method of
# the corresponding reference model as its ``score`` argument. It skips drawing
# its own sample splitting and takes over the reference model's ``smpls``.
plr_score = dml_plr._score_elements
dml_plr_callable_score = DoubleMLPLR(dml_data_plr,
                                     Lasso(),
                                     Lasso(),
                                     score=plr_score,
                                     draw_sample_splitting=False)
# Reuse the partition of the reference PLR model.
dml_plr_callable_score.set_sample_splitting(dml_plr.smpls)
dml_plr_callable_score.fit(store_predictions=True)

# Same setup for the IRM model, with the score taken from ``dml_irm``.
irm_score = dml_irm._score_elements
dml_irm_callable_score = DoubleMLIRM(dml_data_irm,
                                     Lasso(),
                                     LogisticRegression(),
                                     score=irm_score,
                                     draw_sample_splitting=False)
# Reuse the partition of the reference IRM model.
dml_irm_callable_score.set_sample_splitting(dml_irm.smpls)
dml_irm_callable_score.fit(store_predictions=True)

iivm_score = dml_iivm._score_elements
dml_iivm_callable_score = DoubleMLIIVM(dml_data_iivm,
                                       Lasso(),
                                       LogisticRegression(),