示例#1
0
 def test_orf(self):
     """Fit both orthogonal random forests on pandas inputs and sanity-check outputs.

     DMLOrthoForest is fit with the continuous treatment column and DROrthoForest
     with the binary treatment column, both using BLB inference. For each, the
     effects and 95% effect intervals on X are computed and checked for
     consistent shapes, and the population-summary names are validated.
     """
     # Single outcome only, ORF does not support multiple outcomes
     df = TestPandasIntegration.df
     X = df[TestPandasIntegration.features]
     W = df[TestPandasIntegration.controls]
     Y = df[TestPandasIntegration.outcome]

     def fit_and_check(est, T):
         """Fit `est` on (Y, T, X, W) with BLB inference and check its effect outputs."""
         est.fit(Y, T, X=X, W=W, inference='blb')
         treatment_effects = est.effect(X)
         lb, ub = est.effect_interval(X, alpha=0.05)
         # Previously these results were computed but never checked: require one
         # effect per sample and interval bounds shaped like the point estimates.
         self.assertEqual(treatment_effects.shape[0], X.shape[0])
         self.assertEqual(lb.shape, treatment_effects.shape)
         self.assertEqual(ub.shape, treatment_effects.shape)
         self._check_popsum_names(est.effect_inference(X).population_summary())

     # Test DMLOrthoForest (continuous treatment)
     fit_and_check(DMLOrthoForest(n_trees=100,
                                  max_depth=2,
                                  model_T=WeightedLasso(),
                                  model_Y=WeightedLasso()),
                   df[TestPandasIntegration.cont_treat])
     # Test DROrthoForest (binary treatment)
     fit_and_check(DROrthoForest(n_trees=100, max_depth=2),
                   df[TestPandasIntegration.bin_treat])
示例#2
0
    def test_effect_shape(self):
        """Check output shapes of the effect / inference API across forest estimators.

        DROrthoForest and CausalForest are fit with a discrete three-arm treatment.
        DMLOrthoForest (with and without global residualization) and CausalForest
        are fit with a continuous treatment. Each is fit on 1-D and 2-D outcomes
        and treatments, and every effect, marginal-effect and interval method is
        asserted to return arrays of the expected shape.
        """
        from sklearn.dummy import DummyClassifier, DummyRegressor

        np.random.seed(123)
        n = 40  # number of raw samples
        d = 4  # number of binary features

        # Random binary segments; all d features are used as heterogeneity features X
        # (no controls W are passed in this test).
        X = np.random.binomial(1, .5, size=(n, d))
        # Treatment with values {0, 1, 2}. When treated as discrete this gives two
        # non-baseline arms, hence the trailing dimension 2 in the discrete shapes.
        T = np.random.binomial(2, .5, size=(n, ))
        # Outcome with treatment effect heterogeneity driven by the first feature,
        # confounding on the first feature, and heteroskedastic errors.
        y = (-1 + 2 * X[:, 0]) * T + X[:, 0] + (
            1 * X[:, 0] + 1) * np.random.normal(0, 1, size=(n, ))

        y_2d = y.reshape(-1, 1)
        T_2d = T.reshape(-1, 1)
        X_test = X[:3]

        def check_shapes(est, cme_shape, effect_shape):
            """Assert the shape of every effect-related output of a fitted `est`.

            cme_shape: expected shape of const_marginal_effect, marginal_effect and
                all of their intervals on X_test.
            effect_shape: expected shape of effect and its intervals on X_test.
            """
            assert est.const_marginal_effect(
                X_test).shape == cme_shape, "Const Marginal Effect dimension incorrect"
            assert est.marginal_effect(
                1, X_test).shape == cme_shape, "Marginal Effect dimension incorrect"
            assert est.effect(X_test).shape == effect_shape, "Effect dimension incorrect"
            assert est.effect(X_test, T0=0,
                              T1=2).shape == effect_shape, "Effect dimension incorrect"
            assert est.effect(X_test, T0=1,
                              T1=2).shape == effect_shape, "Effect dimension incorrect"
            lb, _ = est.effect_interval(X_test, T0=1, T1=2)
            assert lb.shape == effect_shape, "Effect interval dimension incorrect"
            lb, _ = est.effect_inference(X_test, T0=1, T1=2).conf_int()
            assert lb.shape == effect_shape, "Effect interval dimension incorrect"
            lb, _ = est.const_marginal_effect_interval(X_test)
            assert lb.shape == cme_shape, "Const Marginal Effect interval dimension incorrect"
            lb, _ = est.const_marginal_effect_inference(X_test).conf_int()
            assert lb.shape == cme_shape, "Const Marginal Effect interval dimension incorrect"
            lb, _ = est.marginal_effect_interval(1, X_test)
            assert lb.shape == cme_shape, "Marginal Effect interval dimension incorrect"
            lb, _ = est.marginal_effect_inference(1, X_test).conf_int()
            assert lb.shape == cme_shape, "Marginal Effect interval dimension incorrect"

        # (outcome, treatment, expected cme shape, expected effect shape) for a
        # discrete treatment: 1-D and 2-D outcome, two non-baseline arms.
        discrete_cases = [
            (y, T, (3, 2), (3, )),
            (y_2d, T, (3, 1, 2), (3, 1)),
        ]
        # Same for a continuous treatment: 2-D/2-D, 2-D/1-D and 1-D/1-D.
        continuous_cases = [
            (y_2d, T_2d, (3, 1, 1), (3, 1)),
            (y_2d, T, (3, 1), (3, 1)),
            (y, T, (3, ), (3, )),
        ]

        # Test DROrthoForest API (discrete treatment)
        est = DROrthoForest(n_trees=10,
                            model_Y=DummyRegressor(strategy='mean'),
                            propensity_model=DummyClassifier(strategy='prior'),
                            n_jobs=1)
        for y_fit, T_fit, cme_shape, effect_shape in discrete_cases:
            est.fit(y_fit, T_fit, X=X)
            check_shapes(est, cme_shape, effect_shape)

        # Test Causal Forest API (discrete treatment)
        est = CausalForest(n_trees=10,
                           model_Y=DummyRegressor(strategy='mean'),
                           model_T=DummyClassifier(strategy='prior'),
                           discrete_treatment=True,
                           n_jobs=1)
        for y_fit, T_fit, cme_shape, effect_shape in discrete_cases:
            est.fit(y_fit, T_fit, X=X)
            check_shapes(est, cme_shape, effect_shape)

        # Test DMLOrthoForest API (continuous treatment)
        for global_residualization in [False, True]:
            est = DMLOrthoForest(n_trees=10,
                                 model_Y=DummyRegressor(strategy='mean'),
                                 model_T=DummyRegressor(strategy='mean'),
                                 global_residualization=global_residualization,
                                 n_jobs=1)
            for y_fit, T_fit, cme_shape, effect_shape in continuous_cases:
                est.fit(y_fit, T_fit, X=X)
                check_shapes(est, cme_shape, effect_shape)

        # Test Causal Forest API (continuous treatment)
        est = CausalForest(n_trees=10,
                           model_Y=DummyRegressor(strategy='mean'),
                           model_T=DummyRegressor(strategy='mean'),
                           n_jobs=1)
        for y_fit, T_fit, cme_shape, effect_shape in continuous_cases:
            est.fit(y_fit, T_fit, X=X)
            check_shapes(est, cme_shape, effect_shape)
示例#3
0
    def test_binary_treatments(self):
        """Check DROrthoForest and CausalForest on binary (0/1) treatments.

        For each estimator this checks:
        - input handling: plain lists, (n, 1) arrays, and ValueError on mismatched
          lengths, 2-column treatments and non-numeric treatments;
        - the shape of ``const_marginal_effect``;
        - treatment-effect and BLB confidence-interval accuracy against
          ``TestOrthoForest.expected_exp_te``, with and without controls.
        """
        n = TestOrthoForest.n
        X = TestOrthoForest.X
        W = TestOrthoForest.W
        W_support = W[:, TestOrthoForest.support]

        def draw_binary_treatment(log_odds):
            """Draw one Bernoulli treatment per sample with p = sigmoid(log_odds)."""
            T_sigmoid = 1 / (1 + np.exp(-log_odds))
            return np.array([np.random.binomial(1, p) for p in T_sigmoid])

        def simulate_with_controls():
            """Simulate (T, Y, TE) where treatment and outcome depend on controls W."""
            log_odds = np.dot(W_support, TestOrthoForest.coefs_T) + \
                TestOrthoForest.eta_sample(n)
            T = draw_binary_treatment(log_odds)
            TE = np.array([self._exp_te(x) for x in X])
            Y = np.dot(W_support, TestOrthoForest.coefs_Y) + \
                T * TE + TestOrthoForest.epsilon_sample(n)
            return T, Y, TE

        def simulate_without_controls(TE):
            """Simulate (T, Y) with no dependence on controls, reusing effects TE."""
            log_odds = TestOrthoForest.eta_sample(n)
            T = draw_binary_treatment(log_odds)
            Y = T * TE + TestOrthoForest.epsilon_sample(n)
            return T, Y

        def check_input_handling(est, Y, T):
            """Check accepted/rejected input formats for est.fit and the output shape."""
            # --> Check that one can pass in regular lists
            est.fit(list(Y), list(T), X=list(X), W=list(W))
            # --> Check that it fails correctly if lists of different shape are passed in.
            # X and W go by keyword, like every other fit call, so that a keyword-only
            # signature cannot turn this into a TypeError instead of a ValueError.
            self.assertRaises(ValueError, est.fit, Y[:n // 2], T[:n // 2], X=X, W=W)
            # --> Check that it works when T, Y have shape (n, 1)
            est.fit(Y.reshape(-1, 1), T.reshape(-1, 1), X=X, W=W)
            # --> Check that it fails correctly when T has shape (n, 2)
            self.assertRaises(ValueError, est.fit, Y, np.ones((n, 2)), X=X, W=W)
            # --> Check that it fails correctly when the treatments are not numeric
            self.assertRaises(ValueError, est.fit, Y, np.array(["a"] * n), X=X, W=W)
            # Check that outputs have the correct shape
            out_te = est.const_marginal_effect(TestOrthoForest.x_test)
            self.assertSequenceEqual((TestOrthoForest.x_test.shape[0], 1, 1),
                                     out_te.shape)

        def check_accuracy(est, te_tol):
            """Compare effect estimates (tol=te_tol) and CIs (tol=1.5) to the truth."""
            self._test_te(est,
                          TestOrthoForest.expected_exp_te,
                          tol=te_tol,
                          treatment_type='discrete')
            self._test_ci(est,
                          TestOrthoForest.expected_exp_te,
                          tol=1.5,
                          treatment_type='discrete')

        # ---------------- DROrthoForest ----------------
        np.random.seed(123)
        T, Y, TE = simulate_with_controls()
        # Small forest to exercise input handling. Using n_jobs=1 since code
        # coverage does not work well with parallelism.
        est = DROrthoForest(n_trees=10,
                            n_jobs=1,
                            propensity_model=LogisticRegression(),
                            model_Y=Lasso(),
                            propensity_model_final=LogisticRegressionCV(
                                penalty='l1', solver='saga'),
                            model_Y_final=WeightedLassoCVWrapper())
        check_input_handling(est, Y, T)
        # Test binary treatments with controls
        est = DROrthoForest(n_trees=100,
                            min_leaf_size=10,
                            max_depth=30,
                            subsample_ratio=0.30,
                            bootstrap=False,
                            n_jobs=1,
                            propensity_model=LogisticRegression(C=1 / 0.024,
                                                                penalty='l1',
                                                                solver='saga'),
                            model_Y=Lasso(alpha=0.024),
                            propensity_model_final=LogisticRegressionCV(
                                penalty='l1', solver='saga'),
                            model_Y_final=WeightedLassoCVWrapper())
        est.fit(Y, T, X=X, W=W, inference="blb")
        check_accuracy(est, te_tol=0.7)
        # Test binary treatments without controls
        T, Y = simulate_without_controls(TE)
        est.fit(Y, T, X=X, inference="blb")
        check_accuracy(est, te_tol=0.5)

        # ---------------- CausalForest ----------------
        np.random.seed(123)
        T, Y, TE = simulate_with_controls()
        # Small forest to exercise input handling.
        # NOTE(review): unlike the DROrthoForest checks above, the CausalForest
        # checks use n_jobs=-1 (parallel), even though parallelism is reported
        # not to work well with code coverage.
        est = CausalForest(n_trees=10,
                           n_jobs=-1,
                           model_Y=Lasso(),
                           model_T=LogisticRegressionCV(penalty='l1',
                                                        solver='saga'))
        check_input_handling(est, Y, T)
        # Test binary treatments with controls
        est = CausalForest(n_trees=100,
                           min_leaf_size=10,
                           max_depth=30,
                           subsample_ratio=0.30,
                           n_jobs=-1,
                           model_Y=Lasso(),
                           model_T=LogisticRegressionCV(penalty='l1',
                                                        solver='saga'),
                           discrete_treatment=True,
                           cv=5)
        est.fit(Y, T, X=X, W=W, inference="blb")
        check_accuracy(est, te_tol=0.7)
        # Test binary treatments without controls
        T, Y = simulate_without_controls(TE)
        est.fit(Y, T, X=X, inference="blb")
        check_accuracy(est, te_tol=0.5)