Ejemplo n.º 1
0
 def test_parameter_passing(self):
     """Check that ``mc_iters`` and ``mc_agg`` given to a constructor are stored on the estimator.

     Every estimator is only constructed here, never fit; the test reads the
     two attributes straight back from the new instance.
     """
     for gen in [DML, NonParamDML]:
         est = gen(model_y=LinearRegression(), model_t=LinearRegression(),
                   model_final=LinearRegression(),
                   mc_iters=2, mc_agg='median')
         assert est.mc_iters == 2
         assert est.mc_agg == 'median'
     for gen in [LinearDML, SparseLinearDML, KernelDML, ForestDML]:
         est = gen(model_y=LinearRegression(), model_t=LinearRegression(),
                   mc_iters=2, mc_agg='median')
         assert est.mc_iters == 2
         assert est.mc_agg == 'median'
     for gen in [DRLearner, LinearDRLearner, SparseLinearDRLearner, ForestDRLearner]:
         est = gen(mc_iters=2, mc_agg='median')
         assert est.mc_iters == 2
         assert est.mc_agg == 'median'
     # The IV estimators take differently named nuisance models, so they are
     # built up front and the loop iterates over instances rather than classes.
     # The loop variable must be ``est``: with any other name the assertions
     # would silently re-check the estimator left over from the previous loop.
     for est in [DMLATEIV(model_Y_W=LinearRegression(),
                          model_T_W=LinearRegression(),
                          model_Z_W=LinearRegression(), mc_iters=2, mc_agg='median'),
                 ProjectedDMLATEIV(model_Y_W=LinearRegression(),
                                   model_T_W=LinearRegression(),
                                   model_T_WZ=LinearRegression(), mc_iters=2, mc_agg='median'),
                 DMLIV(model_Y_X=LinearRegression(),
                       model_T_X=LinearRegression(),
                       model_T_XZ=LinearRegression(),
                       model_final=LinearRegression(), mc_iters=2, mc_agg='median'),
                 NonParamDMLIV(model_Y_X=LinearRegression(),
                               model_T_X=LinearRegression(),
                               model_T_XZ=LinearRegression(),
                               model_final=LinearRegression(), mc_iters=2, mc_agg='median'),
                 IntentToTreatDRIV(model_Y_X=LinearRegression(),
                                   model_T_XZ=LinearRegression(),
                                   flexible_model_effect=LinearRegression(), mc_iters=2, mc_agg='median'),
                 LinearIntentToTreatDRIV(model_Y_X=LinearRegression(),
                                         model_T_XZ=LinearRegression(),
                                         flexible_model_effect=LinearRegression(),
                                         mc_iters=2, mc_agg='median')]:
         assert est.mc_iters == 2
         assert est.mc_agg == 'median'
Ejemplo n.º 2
0
 def test_orthoiv_random_state(self):
     """Run the shared random-state check against every orthogonal IV estimator.

     Each estimator is built with small random forests as nuisance models, all
     seeded with ``random_state=123``, and handed to
     ``TestRandomState._test_random_state`` (presumably a reproducibility
     check -- its body is not visible here).  The treatment array is reused as
     the instrument (``Z=T``) in every call.
     """
     Y, T, X, W, X_test = self._make_data(500, 2)

     def forest_reg():
         # fresh regressor per call so that no two estimators share a model object
         return RandomForestRegressor(n_estimators=10,
                                      max_depth=4,
                                      random_state=123)

     def forest_clf():
         # fresh classifier per call, same hyper-parameters as the regressor
         return RandomForestClassifier(n_estimators=10,
                                       max_depth=4,
                                       random_state=123)

     crossfit_opts = dict(cv=2, random_state=123)
     discrete_opts = dict(discrete_treatment=True,
                          discrete_instrument=True,
                          **crossfit_opts)

     # ATE estimators: controls W only, no features, hence no X_test
     ate_estimators = (
         DMLATEIV(model_Y_W=forest_reg(),
                  model_T_W=forest_clf(),
                  model_Z_W=forest_clf(),
                  **discrete_opts),
         ProjectedDMLATEIV(model_Y_W=forest_reg(),
                           model_T_W=forest_clf(),
                           model_T_WZ=forest_clf(),
                           **discrete_opts),
     )
     for candidate in ate_estimators:
         TestRandomState._test_random_state(candidate, None, Y, T, W=W, Z=T)

     # CATE estimators: features X only
     cate_estimators = (
         DMLIV(model_Y_X=forest_reg(),
               model_T_X=forest_clf(),
               model_T_XZ=forest_clf(),
               model_final=LinearRegression(fit_intercept=False),
               **discrete_opts),
         NonParamDMLIV(model_Y_X=forest_reg(),
                       model_T_X=forest_clf(),
                       model_T_XZ=forest_clf(),
                       model_final=LinearRegression(),
                       **discrete_opts),
     )
     for candidate in cate_estimators:
         TestRandomState._test_random_state(candidate, X_test, Y, T, X=X, Z=T)

     # intent-to-treat estimators: both features X and controls W
     itt_estimators = (
         IntentToTreatDRIV(model_Y_X=forest_reg(),
                           model_T_XZ=forest_clf(),
                           flexible_model_effect=forest_reg(),
                           **crossfit_opts),
         LinearIntentToTreatDRIV(model_Y_X=forest_reg(),
                                 model_T_XZ=forest_clf(),
                                 flexible_model_effect=forest_reg(),
                                 **crossfit_opts),
     )
     for candidate in itt_estimators:
         TestRandomState._test_random_state(candidate, X_test, Y, T, X=X, W=W, Z=T)
Ejemplo n.º 3
0
    def test_orthoiv(self):
        """Exercise refitting of the orthogonal IV estimators after their settings are changed.

        For each estimator the test fits with ``cache_values=True``, mutates one
        or more attributes (nuisance models, featurizer, final model,
        ``fit_cate_intercept``), then calls ``refit_final()`` or ``fit()`` again
        and asserts that the fitted attributes reflect the new settings.
        """
        y, T, X, W = self._get_data()
        # the treatment itself doubles as the instrument
        Z = T.copy()

        # --- DMLATEIV: swapping nuisance models takes effect on the next full fit ---
        est = DMLATEIV(model_Y_W=LinearRegression(),
                       model_T_W=LinearRegression(),
                       model_Z_W=LinearRegression(),
                       mc_iters=2)
        est.fit(y, T, W=W, Z=Z, cache_values=True)
        # refitting the final stage with nothing changed must simply not raise
        est.refit_final()
        est.model_Y_W = Lasso()
        est.model_T_W = ElasticNet()
        est.model_Z_W = WeightedLasso()
        est.fit(y, T, W=W, Z=Z, cache_values=True)
        assert isinstance(est.models_nuisance_[0]._model_Y_W._model, Lasso)
        assert isinstance(est.models_nuisance_[0]._model_T_W._model, ElasticNet)
        assert isinstance(est.models_nuisance_[0]._model_Z_W._model, WeightedLasso)

        # --- ProjectedDMLATEIV: same scenario, with model_T_WZ instead of model_Z_W ---
        est = ProjectedDMLATEIV(model_Y_W=LinearRegression(),
                                model_T_W=LinearRegression(),
                                model_T_WZ=LinearRegression(),
                                mc_iters=2)
        est.fit(y, T, W=W, Z=Z, cache_values=True)
        est.refit_final()
        est.model_Y_W = Lasso()
        est.model_T_W = ElasticNet()
        est.model_T_WZ = WeightedLasso()
        est.fit(y, T, W=W, Z=Z, cache_values=True)
        assert isinstance(est.models_nuisance_[0]._model_Y_W._model, Lasso)
        assert isinstance(est.models_nuisance_[0]._model_T_W._model, ElasticNet)
        assert isinstance(est.models_nuisance_[0]._model_T_WZ._model, WeightedLasso)

        # --- DMLIV: featurizer / intercept changes via refit_final, nuisance changes via fit ---
        est = DMLIV(model_Y_X=LinearRegression(),
                    model_T_X=LinearRegression(),
                    model_T_XZ=LinearRegression(),
                    model_final=LinearRegression(fit_intercept=False),
                    mc_iters=2)
        est.fit(y, T, X=X, Z=Z, cache_values=True)
        # without a featurizer there is one coefficient per column of X
        np.testing.assert_equal(len(est.coef_), X.shape[1])
        est.featurizer = PolynomialFeatures(degree=2, include_bias=False)
        est.refit_final()
        # NOTE(review): d**2 equals the degree-2, no-bias polynomial feature count
        # only for particular d (e.g. d == 3) -- depends on what _get_data returns
        np.testing.assert_equal(len(est.coef_), X.shape[1]**2)
        # bare attribute access: the value is discarded, it only has to not raise
        est.intercept_
        est.fit_cate_intercept = False
        # still accessible: the changed flag has no effect until the final stage is refit
        est.intercept_
        est.refit_final()
        with pytest.raises(AttributeError):
            est.intercept_
        est.model_Y_X = Lasso()
        est.model_T_X = ElasticNet()
        est.model_T_XZ = WeightedLasso()
        est.fit(y, T, X=X, Z=Z, cache_values=True)
        # here the swapped nuisance models are checked through the public accessors
        assert isinstance(est.models_Y_X[0], Lasso)
        assert isinstance(est.models_T_X[0], ElasticNet)
        assert isinstance(est.models_T_XZ[0], WeightedLasso)

        # --- DMLIV again on a fresh instance: identical steps, but the swapped
        # nuisance models are checked through models_nuisance_ instead ---
        est = DMLIV(model_Y_X=LinearRegression(),
                    model_T_X=LinearRegression(),
                    model_T_XZ=LinearRegression(),
                    model_final=LinearRegression(fit_intercept=False),
                    mc_iters=2)
        est.fit(y, T, X=X, Z=Z, cache_values=True)
        np.testing.assert_equal(len(est.coef_), X.shape[1])
        est.featurizer = PolynomialFeatures(degree=2, include_bias=False)
        est.refit_final()
        np.testing.assert_equal(len(est.coef_), X.shape[1]**2)
        est.intercept_
        est.fit_cate_intercept = False
        est.intercept_
        est.refit_final()
        with pytest.raises(AttributeError):
            est.intercept_
        est.model_Y_X = Lasso()
        est.model_T_X = ElasticNet()
        est.model_T_XZ = WeightedLasso()
        est.fit(y, T, X=X, Z=Z, cache_values=True)
        assert isinstance(est.models_nuisance_[0]._model_Y_X._model, Lasso)
        assert isinstance(est.models_nuisance_[0]._model_T_X._model, ElasticNet)
        assert isinstance(est.models_nuisance_[0]._model_T_XZ._model, WeightedLasso)

        # --- NonParamDMLIV: final model and featurizer can both be replaced via refit_final ---
        est = NonParamDMLIV(model_Y_X=LinearRegression(),
                            model_T_X=LinearRegression(),
                            model_T_XZ=LinearRegression(),
                            model_final=LinearRegression(fit_intercept=True),
                            mc_iters=2)
        est.fit(y, T, X=X, Z=Z, cache_values=True)
        est.featurizer = PolynomialFeatures(degree=2, include_bias=False)
        est.model_final = WeightedLasso()
        est.refit_final()
        assert isinstance(est.model_cate, WeightedLasso)
        assert isinstance(est.featurizer_, PolynomialFeatures)

        # --- IntentToTreatDRIV: model_final is left unset, so the fitted final model
        # is expected to be of the flexible_model_effect type ---
        est = IntentToTreatDRIV(model_Y_X=LinearRegression(), model_T_XZ=LogisticRegression(),
                                flexible_model_effect=LinearRegression())
        est.fit(y, T, X=X, W=W, Z=Z, cache_values=True)
        assert est.model_final is None
        assert isinstance(est.model_final_, LinearRegression)
        est.flexible_model_effect = Lasso()
        est.refit_final()
        assert est.model_final is None
        assert isinstance(est.model_final_, Lasso)
        est.model_final = Lasso()
        est.refit_final()
        assert isinstance(est.model_final, Lasso)
        assert isinstance(est.model_final_, Lasso)
        # refit_final leaves the nuisance stage alone: the preliminary effect model
        # is still the LinearRegression from the first fit ...
        assert isinstance(est.models_nuisance_[0]._prel_model_effect.model_final_, LinearRegression)
        est.fit(y, T, X=X, W=W, Z=Z, cache_values=True)
        # ... and only a full fit picks up the new flexible_model_effect
        assert isinstance(est.models_nuisance_[0]._prel_model_effect.model_final_, Lasso)

        # --- LinearIntentToTreatDRIV: intercept handling and the read-only final model ---
        est = LinearIntentToTreatDRIV(model_Y_X=LinearRegression(), model_T_XZ=LogisticRegression(),
                                      flexible_model_effect=LinearRegression())
        est.fit(y, T, X=X, W=W, Z=Z, cache_values=True)
        est.fit_cate_intercept = False
        # before the final stage is refit, the intercept and its interval are still available
        est.intercept_
        est.intercept__interval()
        est.refit_final()
        with pytest.raises(AttributeError):
            est.intercept_
        with pytest.raises(AttributeError):
            est.intercept__interval()
        # assigning model_final on this estimator is expected to be rejected
        with pytest.raises(ValueError):
            est.model_final = LinearRegression()
        est.flexible_model_effect = Lasso()
        est.fit(y, T, X=X, W=W, Z=Z, cache_values=True)
        assert isinstance(est.models_nuisance_[0]._prel_model_effect.model_final_, Lasso)
Ejemplo n.º 4
0
    def test_bad_splits_discrete(self):
        """
        Check that ``fit`` raises when a crossfit split leaves out some value of a discrete
        treatment or instrument, while the same split is accepted when both are continuous.
        """
        def lone_split():
            # a single hand-made fold, presumably (train, test) index arrays; the first
            # array selects samples 4..7, where ``skewed`` only ever takes the value 1
            return [(np.arange(4, 8), np.arange(4))]

        Y = np.array([2, 3, 1, 3, 2, 1, 1, 1])
        skewed = np.array([2, 2, 1, 2, 1, 1, 1, 1])
        W = np.ones((8, 1))
        balanced = np.array([1, 2, 3, 1, 2, 3, 1, 2])

        # imbalance ok with continuous instrument/treatment
        est = DMLATEIV(Lasso(), Lasso(), Lasso(), n_splits=lone_split())
        est.fit(Y, T=skewed, Z=skewed, W=W)

        # discrete treatment with the skewed values: fit must fail
        est = DMLATEIV(Lasso(), LogisticRegression(), Lasso(),
                       n_splits=lone_split(),
                       discrete_treatment=True)
        with pytest.raises(AttributeError):
            est.fit(Y, T=skewed, Z=balanced, W=W)

        # discrete instrument with the skewed values: fit must fail as well
        est = DMLATEIV(Lasso(), Lasso(), LogisticRegression(),
                       n_splits=lone_split(),
                       discrete_instrument=True)
        with pytest.raises(AttributeError):
            est.fit(Y, T=balanced, Z=skewed, W=W)
Ejemplo n.º 5
0
    def test_cate_api(self):
        """Test that we correctly implement the CATE API.

        Sweeps over treatment, outcome, feature and instrument dimensions (and
        discrete vs. continuous T/Z), fits each IV estimator on random data and
        checks the shapes returned by ``const_marginal_effect``,
        ``marginal_effect``, ``effect`` and, when inference is requested, their
        ``*_interval`` counterparts.  It also checks that every estimator can be
        pickled before and after fitting and that ``score`` can be called.
        """
        n = 30

        def size(n, d):
            # a negative d denotes a 1-d vector rather than an (n, d) matrix
            return (n, d) if d >= 0 else (n, )

        def make_random(is_discrete, d):
            # random data of width d; d=None means "this input is not supplied"
            if d is None:
                return None
            sz = size(n, d)
            if not is_discrete:
                return np.random.normal(size=sz)
            while True:
                arr = np.random.choice(['a', 'b', 'c'], size=sz)
                # ensure that we've got at least two of every row
                _, counts = np.unique(arr, return_counts=True, axis=0)
                if len(counts) == 3**(d if d > 0 else 1) and counts.min() > 1:
                    return arr

        def eff_shape(n, d_y):
            return (n, ) + ((d_y, ) if d_y > 0 else ())

        def marg_eff_shape(n, d_y, d_t_final):
            return ((n, ) +
                    ((d_y, ) if d_y > 0 else ()) +
                    ((d_t_final, ) if d_t_final > 0 else ()))

        # since T isn't passed to const_marginal_effect, defaults to one row if X is None
        def const_marg_eff_shape(n, d_x, d_y, d_t_final):
            return ((n if d_x else 1, ) +
                    ((d_y, ) if d_y > 0 else ()) +
                    ((d_t_final, ) if d_t_final > 0 else ()))

        for d_t in [2, 1, -1]:
            n_t = d_t if d_t > 0 else 1
            for discrete_t in [True, False] if n_t == 1 else [False]:
                for d_y in [3, 1, -1]:
                    for d_q in [2, None]:
                        for d_z in [2, 1]:
                            if d_z < n_t:
                                continue
                            for discrete_z in [True, False] if d_z == 1 else [False]:
                                Z1, Q, Y, T1 = [
                                    make_random(is_discrete, d)
                                    for is_discrete, d in [(discrete_z, d_z),
                                                           (False, d_q),
                                                           (False, d_y),
                                                           (discrete_t, d_t)]
                                ]
                                if discrete_t and discrete_z:
                                    # need to make sure we get all *joint* combinations:
                                    # draw both columns together and give one to each of
                                    # Z and T (taking column 0 for both would make them
                                    # identical, i.e. only the diagonal combinations)
                                    arr = make_random(True, 2)
                                    Z1 = arr[:, 0].reshape(size(n, d_z))
                                    T1 = arr[:, 1].reshape(size(n, d_t))

                                d_t_final1 = 2 if discrete_t else d_t

                                if discrete_t:
                                    # IntentToTreat only supports binary treatments/instruments
                                    T2 = T1.copy()
                                    T2[T1 == 'c'] = np.random.choice(
                                        ['a', 'b'], size=np.count_nonzero(T1 == 'c'))
                                    d_t_final2 = 1
                                if discrete_z:
                                    # IntentToTreat only supports binary treatments/instruments
                                    Z2 = Z1.copy()
                                    Z2[Z1 == 'c'] = np.random.choice(
                                        ['a', 'b'], size=np.count_nonzero(Z1 == 'c'))

                                effect_shape = eff_shape(n, d_y)

                                model_t = LogisticRegression() if discrete_t else Lasso()
                                model_z = LogisticRegression() if discrete_z else Lasso()

                                # TODO: add stratification to bootstrap so that we can use it
                                # even with discrete treatments
                                all_infs = [None]
                                if not (discrete_t or discrete_z):
                                    all_infs.append(BootstrapInference(1))

                                # (estimator, supports multi-dimensional outcome, inference methods)
                                estimators = [
                                    (DMLATEIV(model_Y_W=Lasso(),
                                              model_T_W=model_t,
                                              model_Z_W=model_z,
                                              discrete_treatment=discrete_t,
                                              discrete_instrument=discrete_z),
                                     True, all_infs),
                                    (ProjectedDMLATEIV(model_Y_W=Lasso(),
                                                       model_T_W=model_t,
                                                       model_T_WZ=model_t,
                                                       discrete_treatment=discrete_t,
                                                       discrete_instrument=discrete_z),
                                     False, all_infs),
                                    (DMLIV(model_Y_X=Lasso(),
                                           model_T_X=model_t,
                                           model_T_XZ=model_t,
                                           model_final=Lasso(),
                                           discrete_treatment=discrete_t,
                                           discrete_instrument=discrete_z),
                                     False, all_infs)
                                ]

                                if d_q and discrete_t and discrete_z:
                                    # IntentToTreat requires X
                                    estimators.append(
                                        (LinearIntentToTreatDRIV(model_Y_X=Lasso(),
                                                                 model_T_XZ=model_t,
                                                                 flexible_model_effect=WeightedLasso(),
                                                                 n_splits=2),
                                         False, all_infs + ['statsmodels']))

                                for est, multi, infs in estimators:
                                    # Parenthesized exactly as Python parses the original
                                    # un-parenthesized condition: any multi-dimensional T or Z
                                    # is skipped for every estimator.
                                    # NOTE(review): the intent may have been
                                    # `not multi and (d_y > 1 or d_t > 1 or d_z > 1)` -- confirm
                                    # before widening coverage.
                                    if (not multi and d_y > 1) or d_t > 1 or d_z > 1:
                                        continue

                                    # ensure we can serialize unfit estimator
                                    pickle.dumps(est)

                                    d_ws = [None]
                                    if isinstance(est, LinearIntentToTreatDRIV):
                                        d_ws.append(2)

                                    for d_w in d_ws:
                                        W = make_random(False, d_w)

                                        for inf in infs:
                                            with self.subTest(d_z=d_z,
                                                              d_x=d_q,
                                                              d_y=d_y,
                                                              d_t=d_t,
                                                              discrete_t=discrete_t,
                                                              discrete_z=discrete_z,
                                                              est=est,
                                                              inf=inf):
                                                Z = Z1
                                                T = T1
                                                d_t_final = d_t_final1
                                                X = Q
                                                d_x = d_q

                                                if isinstance(est, (DMLATEIV, ProjectedDMLATEIV)):
                                                    # these support only W but not X
                                                    W = Q
                                                    X = None
                                                    d_x = None
                                                    data_kwargs = {'Z': Z, 'W': W}
                                                else:
                                                    # these support only binary, not general discrete T and Z
                                                    if discrete_t:
                                                        T = T2
                                                        d_t_final = d_t_final2

                                                    if discrete_z:
                                                        Z = Z2

                                                    data_kwargs = {'Z': Z, 'X': X}
                                                    if isinstance(est, LinearIntentToTreatDRIV):
                                                        # the only estimator here given both X and W
                                                        data_kwargs['W'] = W

                                                marginal_effect_shape = marg_eff_shape(
                                                    n, d_y, d_t_final)
                                                const_marginal_effect_shape = const_marg_eff_shape(
                                                    n, d_x, d_y, d_t_final)

                                                est.fit(Y, T, inference=inf, **data_kwargs)

                                                # ensure we can serialize fit estimator
                                                pickle.dumps(est)

                                                # make sure we can call the marginal_effect and effect methods
                                                const_marg_eff = est.const_marginal_effect(X)
                                                marg_eff = est.marginal_effect(T, X)
                                                self.assertEqual(shape(marg_eff),
                                                                 marginal_effect_shape)
                                                self.assertEqual(shape(const_marg_eff),
                                                                 const_marginal_effect_shape)

                                                # without X the constant marginal effect has a single
                                                # row, so compare it with the first marginal-effect row
                                                np.testing.assert_array_equal(
                                                    marg_eff if d_x else marg_eff[0:1],
                                                    const_marg_eff)

                                                # baseline treatment: category 'a' or zero
                                                if discrete_t:
                                                    T0 = np.full_like(T, 'a')
                                                else:
                                                    T0 = np.zeros_like(T)
                                                eff = est.effect(X, T0=T0, T1=T)
                                                self.assertEqual(shape(eff), effect_shape)

                                                # TODO: add tests for extra properties like coef_ where they exist

                                                if inf is not None:
                                                    # intervals carry a leading axis of size 2
                                                    const_marg_eff_int = est.const_marginal_effect_interval(X)
                                                    marg_eff_int = est.marginal_effect_interval(T, X)
                                                    self.assertEqual(shape(marg_eff_int),
                                                                     (2, ) + marginal_effect_shape)
                                                    self.assertEqual(shape(const_marg_eff_int),
                                                                     (2, ) + const_marginal_effect_shape)
                                                    self.assertEqual(
                                                        shape(est.effect_interval(X, T0=T0, T1=T)),
                                                        (2, ) + effect_shape)

                                                est.score(Y, T, **data_kwargs)

                                                # make sure we can call effect with implied scalar treatments,
                                                # no matter the dimensions of T, and also that we warn when there
                                                # are multiple treatments
                                                if d_t > 1:
                                                    cm = self.assertWarns(Warning)
                                                else:
                                                    # ExitStack can be used as a "do nothing" ContextManager
                                                    cm = ExitStack()
                                                with cm:
                                                    effect_shape2 = ((n if d_x else 1, ) +
                                                                     ((d_y, ) if d_y > 0 else ()))
                                                    if discrete_t:
                                                        eff = est.effect(X, T0='a', T1='b')
                                                    else:
                                                        eff = est.effect(X)
                                                    self.assertEqual(shape(eff), effect_shape2)