Code example #1
from econml.dml import CausalForestDML


class CausalForestEstimator(BaseITEEstimator):
    # BaseITEEstimator is the surrounding project's base class and is defined elsewhere.

    def fit(self, **kwargs):
        # Construct the underlying causal forest lazily on the first call.
        if self.model is None:
            self.model = CausalForestDML()
        self.model.fit(Y=kwargs["y_training"],
                       T=kwargs["t_training"],
                       X=kwargs["X_training"])

    def predict(self, **kwargs):
        # effect_inference returns an inference object with point estimates and uncertainty.
        preds = self.model.effect_inference(kwargs["X"])
        if "return_mean" in kwargs:
            # The mere presence of the return_mean key selects point estimates only.
            out = preds.pred
        else:
            # Otherwise return the point estimates together with their variances.
            out = (preds.pred, preds.var)
        return out
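
A minimal usage sketch of this wrapper follows. It is illustrative only: the synthetic data, and the assumption that BaseITEEstimator can be constructed without arguments and leaves self.model as None, are not part of the original snippet; only the keyword names (y_training, t_training, X_training, X, return_mean) come from the code above.

import numpy as np

# Assumed synthetic data: binary treatment with a heterogeneous effect.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
t = rng.binomial(1, 0.5, size=500)
y = X[:, 0] + t * (1.0 + X[:, 1]) + rng.normal(size=500)

# Assumes the (unshown) BaseITEEstimator constructor initialises self.model to None.
est = CausalForestEstimator()
est.fit(y_training=y, t_training=t, X_training=X)

# The presence of return_mean selects point estimates only.
tau_hat = est.predict(X=X, return_mean=True)

# Without it, predict returns (point estimates, variances) as coded above.
tau_hat, tau_var = est.predict(X=X)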
Code example #2
File: test_dowhy.py, Project: vishalbelsare/EconML
    def test_dowhy(self):
        def reg():
            return LinearRegression()

        def clf():
            return LogisticRegression()

        Y, T, X, W, Z = self._get_data()
        # test at least one estimator from each category
        models = {"dml": LinearDML(model_y=reg(), model_t=clf(), discrete_treatment=True,
                                   linear_first_stages=False),
                  "dr": DRLearner(model_propensity=clf(), model_regression=reg(),
                                  model_final=reg()),
                  "forestdr": ForestDRLearner(model_propensity=clf(), model_regression=reg()),
                  "xlearner": XLearner(models=reg(), cate_models=reg(), propensity_model=clf()),
                  "cfdml": CausalForestDML(model_y=reg(), model_t=clf(), discrete_treatment=True),
                  "orf": DROrthoForest(n_trees=10, propensity_model=clf(), model_Y=reg()),
                  "orthoiv": OrthoIV(model_y_xw=reg(),
                                     model_t_xw=clf(),
                                     model_z_xw=reg(),
                                     discrete_treatment=True,
                                     discrete_instrument=False),
                  "dmliv": DMLIV(fit_cate_intercept=True,
                                 discrete_treatment=True,
                                 discrete_instrument=False),
                  "driv": LinearDRIV(flexible_model_effect=StatsModelsLinearRegression(fit_intercept=False),
                                     fit_cate_intercept=True,
                                     discrete_instrument=False,
                                     discrete_treatment=True)}
        for name, model in models.items():
            with self.subTest(name=name):
                est = model
                if name == "xlearner":
                    est_dowhy = est.dowhy.fit(Y, T, X=np.hstack((X, W)), W=None)
                elif name in ["orthoiv", "dmliv", "driv"]:
                    est_dowhy = est.dowhy.fit(Y, T, Z=Z, X=X, W=W)
                else:
                    est_dowhy = est.dowhy.fit(Y, T, X=X, W=W)
                # test causal graph
                est_dowhy.view_model()
                # test refutation estimate
                est_dowhy.refute_estimate(method_name="random_common_cause")
                if name != "orf":
                    est_dowhy.refute_estimate(method_name="add_unobserved_common_cause",
                                              confounders_effect_on_treatment="binary_flip",
                                              confounders_effect_on_outcome="linear",
                                              effect_strength_on_treatment=0.1,
                                              effect_strength_on_outcome=0.1,)
                    est_dowhy.refute_estimate(method_name="placebo_treatment_refuter", placebo_type="permute",
                                              num_simulations=3)
                    est_dowhy.refute_estimate(method_name="data_subset_refuter", subset_fraction=0.8,
                                              num_simulations=3)
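
The same wrapper pattern can be exercised outside the test harness. The sketch below is a hedged illustration: the synthetic data is assumed, while the LinearDML construction and the est.dowhy.fit / view_model / refute_estimate calls mirror the "dml" branch of the test above.

import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression
from econml.dml import LinearDML

# Assumed synthetic data: binary treatment, continuous outcome, extra controls W.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
W = rng.normal(size=(1000, 2))
T = rng.binomial(1, 0.5, size=1000)
Y = X[:, 0] * T + W[:, 0] + rng.normal(size=1000)

est = LinearDML(model_y=LinearRegression(), model_t=LogisticRegression(),
                discrete_treatment=True, linear_first_stages=False)

# Fit through EconML's DoWhy wrapper, as in the test.
est_dowhy = est.dowhy.fit(Y, T, X=X, W=W)
est_dowhy.view_model()                                        # render the causal graph
est_dowhy.refute_estimate(method_name="random_common_cause")  # basic refutation check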
Code example #3
 def test_dml_random_state(self):
     Y, T, X, W, X_test = TestRandomState._make_data(500, 2)
     for est in [
             NonParamDML(model_y=RandomForestRegressor(n_estimators=10,
                                                       max_depth=4,
                                                       random_state=123),
                         model_t=RandomForestClassifier(n_estimators=10,
                                                        max_depth=4,
                                                        random_state=123),
                         model_final=RandomForestRegressor(
                             max_depth=3,
                             n_estimators=10,
                             min_samples_leaf=100,
                             bootstrap=True,
                             random_state=123),
                         discrete_treatment=True,
                         n_splits=2,
                         random_state=123),
             CausalForestDML(
                 model_y=RandomForestRegressor(n_estimators=10,
                                               max_depth=4,
                                               random_state=123),
                 model_t=RandomForestClassifier(n_estimators=10,
                                                max_depth=4,
                                                random_state=123),
                 n_estimators=8,
                 discrete_treatment=True,
                 cv=2,
                 random_state=123),
             LinearDML(model_y=RandomForestRegressor(n_estimators=10,
                                                     max_depth=4,
                                                     random_state=123),
                       model_t=RandomForestClassifier(n_estimators=10,
                                                      max_depth=4,
                                                      random_state=123),
                       discrete_treatment=True,
                       n_splits=2,
                       random_state=123),
             SparseLinearDML(discrete_treatment=True,
                             n_splits=2,
                             random_state=123),
             KernelDML(discrete_treatment=True,
                       n_splits=2,
                       random_state=123)
     ]:
         TestRandomState._test_random_state(est, X_test, Y, T, X=X, W=W)
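
TestRandomState._make_data and _test_random_state are helpers of the test class and are not shown here; presumably the latter verifies that pinning random_state makes the fit reproducible. The standalone sketch below illustrates that kind of check: the data generation and the exact comparison are assumptions, while the CausalForestDML construction mirrors the entry in the loop above.

import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from econml.dml import CausalForestDML

# Assumed synthetic data; the real test uses TestRandomState._make_data(500, 2).
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
W = rng.normal(size=(500, 2))
T = rng.binomial(1, 0.5, size=500)
Y = X[:, 0] * T + W[:, 0] + rng.normal(size=500)
X_test = rng.normal(size=(50, 4))

def fit_and_predict():
    # Same construction as the CausalForestDML entry above, with every random_state pinned.
    est = CausalForestDML(model_y=RandomForestRegressor(n_estimators=10, max_depth=4,
                                                        random_state=123),
                          model_t=RandomForestClassifier(n_estimators=10, max_depth=4,
                                                         random_state=123),
                          n_estimators=8, discrete_treatment=True, cv=2, random_state=123)
    est.fit(Y, T, X=X, W=W)
    return est.effect(X_test)

# Two independent fits with identical seeds should yield matching effect estimates.
np.testing.assert_allclose(fit_and_predict(), fit_and_predict())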