def test_can_use_interpreters(self):
    """Exercise both single-tree interpreters across input shapes.

    For every combination of 1-D and column-vector treatment/outcome
    arrays, a discrete-treatment ``LinearDML`` is fitted and each
    interpreter is checked to (a) raise when ``plot`` is called before
    ``interpret`` and (b) plot, render and export once ``interpret``
    has been called.
    """
    n = 100
    shapes = [(n, ), (n, 1)]
    # Same iteration order as two nested loops: treatment shape is the outer axis.
    shape_pairs = [(ts, ys) for ts in shapes for ys in shapes]

    for t_shape, y_shape in shape_pairs:
        X = np.random.normal(size=(n, 4))
        T = np.random.binomial(1, 0.5, size=t_shape)
        # Treated rows get +1 or -1 depending on the sign of the first
        # feature; untreated rows get 0.
        sign = 2 * (X[:, 0] > 0) - 1
        Y = (T.flatten() * sign).reshape(y_shape)

        est = LinearDML(
            model_y=LinearRegression(),
            model_t=LogisticRegression(),
            discrete_treatment=True,
        )
        est.fit(Y, T, X=X)

        interpreters = [
            SingleTreeCateInterpreter(),
            SingleTreePolicyInterpreter(),
        ]
        for intrp in interpreters:
            with self.subTest(t_shape=t_shape, y_shape=y_shape, intrp=intrp):
                # prior to calling interpret, can't plot, render, etc.
                with self.assertRaises(Exception):
                    intrp.plot()

                intrp.interpret(est, X)
                intrp.plot()
                intrp.render('tmp.pdf', view=False)
                intrp.export_graphviz()
Esempio n. 2
0
    def test_cate_uncertainty_needs_inference(self):
        """Check that uncertainty-aware interpretation requires a fit with inference.

        An estimator fitted with ``inference=None`` can be interpreted by a
        default ``SingleTreeCateInterpreter`` but must raise for one built
        with ``include_model_uncertainty=True``; after refitting with the
        default inference setting the same interpreter succeeds.
        """
        n = 100
        features = np.random.normal(size=(n, 4))
        treatment = np.random.binomial(1, 0.5, size=(n,))
        outcome = np.random.normal(size=(n,))

        est = LinearDML(discrete_treatment=True)
        est.fit(outcome, treatment, X=features, inference=None)

        # can interpret without uncertainty
        plain_intrp = SingleTreeCateInterpreter()
        plain_intrp.interpret(est, features)

        uncertain_intrp = SingleTreeCateInterpreter(
            include_model_uncertainty=True)
        # can't interpret with uncertainty if inference wasn't used during fit
        with self.assertRaises(Exception):
            uncertain_intrp.interpret(est, features)

        # can interpret with uncertainty if we refit
        est.fit(outcome, treatment, X=features)
        uncertain_intrp.interpret(est, features)
 def test_can_use_interpreters(self):
     """Run both single-tree interpreters for each treatment/outcome shape.

     Each of the four combinations of 1-D and column-vector ``T`` and ``Y``
     is fitted with a discrete-treatment ``LinearDMLCateEstimator``; every
     interpreter must raise on ``plot`` before ``interpret`` and then
     plot, render and export after ``interpret``.
     """
     n = 100
     candidate_shapes = ((n, ), (n, 1))
     # Treatment shape varies slowest, matching a nested t-then-y loop.
     combos = [(t, y) for t in candidate_shapes for y in candidate_shapes]

     for t_shape, y_shape in combos:
         X = np.random.normal(size=(n, 4))
         T = np.random.binomial(1, 0.5, size=t_shape)
         Y = np.random.normal(size=y_shape)

         est = LinearDMLCateEstimator(discrete_treatment=True)
         est.fit(Y, T, X)

         interpreters = [
             SingleTreeCateInterpreter(),
             SingleTreePolicyInterpreter(),
         ]
         for intrp in interpreters:
             with self.subTest(t_shape=t_shape, y_shape=y_shape, intrp=intrp):
                 # prior to calling interpret, can't plot, render, etc.
                 with self.assertRaises(Exception):
                     intrp.plot()

                 intrp.interpret(est, X)
                 intrp.plot()
                 intrp.render('tmp.pdf', view=False)
                 intrp.export_graphviz()
Esempio n. 4
0
    "point_estimate", "pvalue", "stderr"
]]
# Post-process the T2 effect-inference summary table into a single stacked
# column of formatted strings.
# NOTE(review): dml2_summary, X_eval, stars, format_float and surr_parenthesis
# are defined outside this excerpt. From their names the helpers presumably
# map a p-value to significance stars, format a float to a fixed number of
# digits, and wrap a value in parentheses -- confirm against their definitions.
dml2_summary.index = X_eval.index
dml2_summary["star"] = dml2_summary["pvalue"].apply(stars)
dml2_summary["point_estimate"] = dml2_summary["point_estimate"].apply(
    format_float, digits=4)
# Append the star marker to the already-formatted point estimate
# (.str.cat implies format_float returns strings -- TODO confirm).
dml2_summary["point_estimate"] = dml2_summary["point_estimate"].str.cat(
    dml2_summary["star"])
dml2_summary["stderr"] = dml2_summary["stderr"].apply(surr_parenthesis,
                                                      digits=4)
# Keep only the two display columns and stack them into one Series.
dml2_summary = dml2_summary[["point_estimate", "stderr"]].stack()
dml2_summary.name = "Ameaça"

# Decision-tree interpretation for T2
# (uses whatever `dml` was last fitted on above this excerpt -- presumably T2).
interp = SingleTreeCateInterpreter(include_model_uncertainty=False,
                                   max_depth=3,
                                   min_samples_leaf=10)
interp.interpret(dml, X2)
fig, ax1 = plt.subplots(figsize=(25, 6))
interp.plot(feature_names=X_cols, fontsize=12, ax=ax1)
# NOTE(review): assumes the "Figs" directory already exists -- confirm.
fig.savefig("Figs/fig_tree_dml.png")

# DML for T3
# Refits the same `dml` estimator object on the T3 data, replacing the
# previous fit.
dml.fit(Y3, T3, X3, inference='auto')
# Effects of moving treatment from 0 to 1, on X3 and on X3_treat
# (presumably the treated subsample -- defined outside this excerpt).
dml3_eff = dml.effect(X3, T0=0, T1=1)
dml3_eff_treat = dml.effect(X3_treat, T0=0, T1=1)
# Report the mean effect over each sample as ATE / ATT.
print(
    f"ATE T3 por DML: {np.mean(dml3_eff)}\nATT T3 por DML: {np.mean(dml3_eff_treat)}"
)
# Effect inference evaluated at the X_eval rows (passed as a raw array).
dml3_inf = dml.effect_inference(X_eval.values)
dml3_summary = dml3_inf.summary_frame(alpha=0.05)[[
    def test_random_cate_settings(self):
        """Verify that we can call methods on the CATE interpreter with various combinations of inputs.

        Runs 100 randomized trials. In each trial ``self.coinflip`` decides
        the treatment/outcome shapes, whether the treatment is discrete
        (three levels drawn from ``binomial(2, 0.5)``) or continuous, whether
        the outcome is duplicated into two columns, and which optional keyword
        arguments are passed to ``fit``, to the interpreter constructors, and
        to ``interpret`` / ``plot`` / ``render`` / ``export_graphviz``.

        Each trial runs inside ``self.subTest`` so a failure reports the
        settings that produced it. For the policy interpreter, an
        ``AttributeError`` is tolerated only when its message contains
        ``"samples should"``; any other exception fails the sub-test.
        """
        n = 100
        # Number of rows in the evaluation set X2; array-valued
        # sample_treatment_costs below use the same row count.
        n_eval = 10
        for _ in range(100):
            t_shape = (n, ) if self.coinflip() else (n, 1)
            y_shape = (n, ) if self.coinflip() else (n, 1)
            discrete_t = self.coinflip()
            X = np.random.normal(size=(n, 4))
            X2 = np.random.normal(size=(n_eval, 4))
            if discrete_t:
                T = np.random.binomial(2, 0.5, size=t_shape)
            else:
                T = np.random.normal(size=t_shape)
            # Outcome is +/-1 by the sign of X[:, 0] where T == 1, +/-1 by the
            # sign of X[:, 1] where T == 2, and 0 elsewhere.
            t_flat = T.flatten()
            Y = ((t_flat == 1) * (2 * (X[:, 0] > 0) - 1) +
                 (t_flat == 2) * (2 * (X[:, 1] > 0) - 1)).reshape(y_shape)

            if self.coinflip():
                # Duplicate the outcome into two identical columns.
                y_shape = (n, 2)
                Y = np.tile(Y.reshape((-1, 1)), (1, 2))

            est = LinearDML(model_y=LinearRegression(),
                            model_t=LogisticRegression()
                            if discrete_t else LinearRegression(),
                            discrete_treatment=discrete_t)

            fit_kwargs = {}
            cate_init_kwargs = {}
            policy_init_kwargs = {}
            intrp_kwargs = {}
            policy_intrp_kwargs = {}
            common_kwargs = {}
            plot_kwargs = {}
            render_kwargs = {}
            export_kwargs = {}

            # Uncertainty-aware options are only requested when the fit keeps
            # its default inference; otherwise inference is disabled.
            if self.coinflip():
                cate_init_kwargs.update(include_model_uncertainty=True)
                policy_init_kwargs.update(risk_level=0.1)
            else:
                fit_kwargs.update(inference=None)

            if self.coinflip():
                cate_init_kwargs.update(uncertainty_level=0.01)

            if self.coinflip():
                policy_init_kwargs.update(risk_seeking=True)

            # Treatment costs: a scalar, a per-sample array, or omitted.
            if self.coinflip(1 / 3):
                policy_intrp_kwargs.update(sample_treatment_costs=0.1)
            elif self.coinflip():
                if discrete_t:
                    cost_shape = (n_eval, 2)
                elif self.coinflip():
                    cost_shape = (n_eval, 1)
                else:
                    cost_shape = (n_eval, )
                policy_intrp_kwargs.update(
                    sample_treatment_costs=np.random.normal(size=cost_shape))

            if self.coinflip():
                common_kwargs.update(feature_names=['A', 'B', 'C', 'D'])

            if self.coinflip():
                common_kwargs.update(filled=False)

            if self.coinflip():
                common_kwargs.update(rounded=False)

            if self.coinflip():
                common_kwargs.update(precision=1)

            if self.coinflip():
                render_kwargs.update(rotate=True)
                export_kwargs.update(rotate=True)

            if self.coinflip():
                treatment_names = ['control gp', 'treated gp']
                if discrete_t:
                    # The discrete treatment has three levels (0, 1, 2).
                    treatment_names.append('more gp')
                for graph_kwargs in (render_kwargs, export_kwargs):
                    # Each dict gets its own list so they never alias.
                    graph_kwargs.update(leaves_parallel=False,
                                        treatment_names=list(treatment_names))

            if self.coinflip():
                render_kwargs.update(format='png')

            if self.coinflip():
                export_kwargs.update(out_file='out')

            if self.coinflip(0.95):  # don't launch files most of the time
                render_kwargs.update(view=False)

            with self.subTest(t_shape=t_shape,
                              y_shape=y_shape,
                              discrete_t=discrete_t,
                              fit_kwargs=fit_kwargs,
                              cate_init_kwargs=cate_init_kwargs,
                              policy_init_kwargs=policy_init_kwargs,
                              policy_intrp_kwargs=policy_intrp_kwargs,
                              intrp_kwargs=intrp_kwargs,
                              common_kwargs=common_kwargs,
                              plot_kwargs=plot_kwargs,
                              render_kwargs=render_kwargs,
                              export_kwargs=export_kwargs):
                # Fold the shared options into each call-specific dict.
                plot_kwargs.update(common_kwargs)
                render_kwargs.update(common_kwargs)
                export_kwargs.update(common_kwargs)
                policy_intrp_kwargs.update(intrp_kwargs)

                est.fit(Y, T, X=X, **fit_kwargs)

                intrp = SingleTreeCateInterpreter(**cate_init_kwargs)
                intrp.interpret(est, X2, **intrp_kwargs)
                intrp.plot(**plot_kwargs)
                intrp.render('outfile', **render_kwargs)
                intrp.export_graphviz(**export_kwargs)

                intrp = SingleTreePolicyInterpreter(**policy_init_kwargs)
                try:
                    intrp.interpret(est, X2, **policy_intrp_kwargs)
                    intrp.plot(**plot_kwargs)
                    intrp.render('outfile', **render_kwargs)
                    intrp.export_graphviz(**export_kwargs)
                except AttributeError as e:
                    # Only the known "samples should ..." AttributeError is
                    # acceptable here. assertIn (unlike a bare assert) still
                    # runs under `python -O` and reports the actual message.
                    self.assertIn("samples should", str(e))