예제 #1
0
 def test_orthoiv(self):
     """Fit LinearIntentToTreatDRIV on pandas inputs and check the summaries.

     Outcome, treatment, instrument and features are all selected from the
     shared ``TestPandasIntegration.df`` frame, so the estimator receives
     pandas objects rather than raw arrays.
     """
     frame = TestPandasIntegration.df
     Y = frame[TestPandasIntegration.outcome]
     T = frame[TestPandasIntegration.bin_treat]
     Z = frame[TestPandasIntegration.instrument]
     X = frame[TestPandasIntegration.features]

     # Test LinearIntentToTreatDRIV
     outcome_model = GradientBoostingRegressor()
     treatment_model = GradientBoostingClassifier()
     effect_model = GradientBoostingRegressor()
     est = LinearIntentToTreatDRIV(
         model_y_xw=outcome_model,
         model_t_xwz=treatment_model,
         flexible_model_effect=effect_model,
     )
     est.fit(Y, T, Z=Z, X=X, inference='statsmodels')

     # Results are not inspected; the calls only need to succeed, and the
     # interval must unpack into exactly two parts.
     _point_estimates = est.effect(X)
     _lower, _upper = est.effect_interval(X, alpha=0.05)

     # Check input names propagate
     summary = est.summary()
     self._check_input_names(summary)
     population_summary = est.effect_inference(X).population_summary()
     self._check_popsum_names(population_summary)
예제 #2
0
    def test_can_use_statsmodel_inference(self):
        """Test that we can use statsmodels to generate confidence intervals.

        Fits a LinearIntentToTreatDRIV on a small hand-written data set and
        then, for ``effect``, ``const_marginal_effect``, ``coef_`` and
        ``intercept_``, checks that the ``alpha=0.05`` interval is a
        (lower, upper) pair that brackets the point estimate and has nonzero
        width for at least one entry.
        """

        def assert_interval_brackets(interval, point, check_shape=True):
            # Shared checks applied to every (interval, point estimate) pair.
            assert len(interval) == 2
            lo, hi = interval
            if check_shape:
                assert lo.shape == hi.shape == point.shape
            assert np.all(lo <= point)
            assert np.all(point <= hi)
            # for at least some of the examples, the CI should have nonzero width
            assert np.any(lo < hi)

        est = LinearIntentToTreatDRIV(model_Y_X=LinearRegression(),
                                      model_T_XZ=LogisticRegression(C=1000),
                                      flexible_model_effect=WeightedLasso())
        est.fit(np.array([1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2]),
                np.array([1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2]),
                Z=np.array([1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]),
                X=np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5,
                            6]).reshape(-1, 1))

        # Evaluation fixtures reused by the interval and point-estimate calls.
        X_eval = np.ones((9, 1))
        T0 = np.array([1, 1, 1, 2, 2, 2, 1, 1, 1])
        T1 = np.array([1, 2, 1, 1, 2, 2, 2, 2, 1])

        interval = est.effect_interval(X_eval, T0=T0, T1=T1, alpha=0.05)
        point = est.effect(X_eval, T0=T0, T1=T1)
        assert_interval_brackets(interval, point)

        interval = est.const_marginal_effect_interval(X_eval, alpha=0.05)
        point = est.const_marginal_effect(X_eval)
        assert_interval_brackets(interval, point)

        interval = est.coef__interval(alpha=0.05)
        point = est.coef_
        assert_interval_brackets(interval, point)

        # The original test did not compare shapes for the intercept, so the
        # shape check stays disabled here.
        interval = est.intercept__interval(alpha=0.05)
        point = est.intercept_
        assert_interval_brackets(interval, point, check_shape=False)