Ejemplo n.º 1
0
 def test_can_use_interpreters(self):
     """Both tree interpreters work on a discrete-treatment LinearDML fit.

     For every combination of 1-D / column-vector treatment and outcome, each
     interpreter must refuse to plot before ``interpret`` and must support
     ``plot``, ``render`` and ``export_graphviz`` afterwards.

     The rendered file goes to a per-test temporary directory rather than the
     current working directory. Runs therefore leave no artifacts behind, and
     concurrent runs do not overwrite each other's output.
     """
     import os
     import tempfile

     tmp_dir = tempfile.TemporaryDirectory()
     # Removed once the test (including tearDown) has finished.
     self.addCleanup(tmp_dir.cleanup)
     render_path = os.path.join(tmp_dir.name, 'tmp.pdf')

     n = 100
     for t_shape in [(n, ), (n, 1)]:
         for y_shape in [(n, ), (n, 1)]:
             X = np.random.normal(size=(n, 4))
             T = np.random.binomial(1, 0.5, size=t_shape)
             # Treatment effect is +1 where X[:, 0] > 0 and -1 elsewhere.
             Y = (T.flatten() * (2 * (X[:, 0] > 0) - 1)).reshape(y_shape)
             est = LinearDML(model_y=LinearRegression(),
                             model_t=LogisticRegression(),
                             discrete_treatment=True)
             est.fit(Y, T, X=X)
             for intrp in [
                     SingleTreeCateInterpreter(),
                     SingleTreePolicyInterpreter()
             ]:
                 with self.subTest(t_shape=t_shape,
                                   y_shape=y_shape,
                                   intrp=intrp):
                     with self.assertRaises(Exception):
                         # prior to calling interpret, can't plot, render, etc.
                         intrp.plot()
                     intrp.interpret(est, X)
                     intrp.plot()
                     intrp.render(render_path, view=False)
                     intrp.export_graphviz()
Ejemplo n.º 2
0
    def test_can_assign_treatment(self):
        """The policy interpreter refuses ``treat`` until ``interpret`` has run.

        After ``interpret``, ``treat`` returns one recommendation per row of
        X, shaped like the observed treatment vector.
        """
        n_samples = 100
        features = np.random.normal(size=(n_samples, 4))
        treatment = np.random.binomial(1, 0.5, size=(n_samples,))
        outcome = np.random.normal(size=(n_samples,))
        estimator = LinearDML(discrete_treatment=True)
        estimator.fit(outcome, treatment, X=features)

        # Interpret without uncertainty: the interpreter uses its default settings.
        interpreter = SingleTreePolicyInterpreter()

        # The interpreter has not been fitted yet, so treating must fail.
        with self.assertRaises(Exception):
            interpreter.treat(features)

        interpreter.interpret(estimator, features)
        recommended = interpreter.treat(features)
        self.assertEqual(treatment.shape, recommended.shape)
Ejemplo n.º 3
0
 def test_can_use_interpreters(self):
     """Both tree interpreters work on a LinearDMLCateEstimator fit.

     For every combination of 1-D / column-vector treatment and outcome, each
     interpreter must refuse to plot before ``interpret`` and must support
     ``plot``, ``render`` and ``export_graphviz`` afterwards.

     The rendered file goes to a per-test temporary directory rather than the
     current working directory. Runs therefore leave no artifacts behind, and
     concurrent runs do not overwrite each other's output.
     """
     import os
     import tempfile

     tmp_dir = tempfile.TemporaryDirectory()
     # Removed once the test (including tearDown) has finished.
     self.addCleanup(tmp_dir.cleanup)
     render_path = os.path.join(tmp_dir.name, 'tmp.pdf')

     n = 100
     for t_shape in [(n, ), (n, 1)]:
         for y_shape in [(n, ), (n, 1)]:
             X = np.random.normal(size=(n, 4))
             T = np.random.binomial(1, 0.5, size=t_shape)
             # Pure-noise outcome: only the API surface is exercised here.
             Y = np.random.normal(size=y_shape)
             est = LinearDMLCateEstimator(discrete_treatment=True)
             est.fit(Y, T, X)
             for intrp in [
                     SingleTreeCateInterpreter(),
                     SingleTreePolicyInterpreter()
             ]:
                 with self.subTest(t_shape=t_shape,
                                   y_shape=y_shape,
                                   intrp=intrp):
                     with self.assertRaises(Exception):
                         # prior to calling interpret, can't plot, render, etc.
                         intrp.plot()
                     intrp.interpret(est, X)
                     intrp.plot()
                     intrp.render(render_path, view=False)
                     intrp.export_graphviz()
Ejemplo n.º 4
0
    def test_can_assign_treatment(self):
        """Check ``treat`` and ``sample_treatment_costs`` on the policy interpreter.

        ``treat`` must fail before ``interpret``. Afterwards it must return one
        treatment per sample, both without costs and with a single per-sample
        cost column. A two-column cost matrix must raise ``ValueError``.
        """
        n_samples = 100
        features = np.random.normal(size=(n_samples, 4))
        treatment = np.random.binomial(1, 0.5, size=(n_samples, ))
        # Treating yields +1 where the first feature is positive, -1 elsewhere.
        outcome = (2 * (features[:, 0] > 0) - 1) * treatment.flatten()
        estimator = LinearDML(model_y=LinearRegression(),
                              model_t=LogisticRegression(),
                              discrete_treatment=True)
        estimator.fit(outcome, treatment, X=features)

        # Interpret without uncertainty: the interpreter uses its default settings.
        interpreter = SingleTreePolicyInterpreter()
        # The interpreter has not been fitted yet, so treating must fail.
        with self.assertRaises(Exception):
            interpreter.treat(features)

        n_rows = treatment.shape[0]
        # First without costs, then with a single column of unit costs.
        cost_settings = ({}, {'sample_treatment_costs': np.ones((n_rows, 1))})
        for extra in cost_settings:
            interpreter.interpret(estimator, features, **extra)
            recommended = interpreter.treat(features)
            self.assertEqual(treatment.shape, recommended.shape)

        # With this single binary treatment, a 2-column cost matrix is rejected.
        with self.assertRaises(ValueError):
            interpreter.interpret(estimator,
                                  features,
                                  sample_treatment_costs=np.ones((n_rows, 2)))
Ejemplo n.º 5
0
    def test_random_cate_settings(self):
        """Verify that we can call methods on the CATE interpreter with various combinations of inputs

        Each of the 100 iterations draws a random configuration, then fits a
        ``LinearDML`` estimator and exercises both
        ``SingleTreeCateInterpreter`` and ``SingleTreePolicyInterpreter``.
        The configuration covers:

        - treatment and outcome shapes;
        - discrete vs. continuous treatment;
        - interpreter constructor options;
        - ``interpret`` options;
        - plot/render/export options.

        Files produced by ``render`` and ``export_graphviz`` are written to a
        per-test temporary directory instead of the current working directory.
        """
        import os
        import tempfile

        tmp_dir = tempfile.TemporaryDirectory()
        # Removed once the test (including tearDown) has finished, so the
        # files outlive the whole loop.
        self.addCleanup(tmp_dir.cleanup)
        render_path = os.path.join(tmp_dir.name, 'outfile')
        export_path = os.path.join(tmp_dir.name, 'out')

        n = 100
        for _ in range(100):
            t_shape = (n, ) if self.coinflip() else (n, 1)
            y_shape = (n, ) if self.coinflip() else (n, 1)
            discrete_t = self.coinflip()
            X = np.random.normal(size=(n, 4))
            # Separate 10-row sample to interpret on. The per-sample treatment
            # costs drawn below have 10 rows to match it.
            X2 = np.random.normal(size=(10, 4))
            # A discrete treatment takes values in {0, 1, 2}.
            T = np.random.binomial(
                2, 0.5, size=t_shape) if discrete_t else np.random.normal(
                    size=t_shape)
            # Level 1 helps iff X[:, 0] > 0 and level 2 helps iff X[:, 1] > 0.
            # For a normally distributed T the equalities essentially never
            # hold, so Y is (almost surely) all zeros in the continuous case.
            Y = ((T.flatten() == 1) * (2 * (X[:, 0] > 0) - 1) +
                 (T.flatten() == 2) * (2 * (X[:, 1] > 0) - 1)).reshape(y_shape)

            if self.coinflip():
                # Two identical outcome columns.
                y_shape = (n, 2)
                Y = np.tile(Y.reshape((-1, 1)), (1, 2))

            est = LinearDML(model_y=LinearRegression(),
                            model_t=LogisticRegression()
                            if discrete_t else LinearRegression(),
                            discrete_treatment=discrete_t)

            fit_kwargs = {}
            cate_init_kwargs = {}
            policy_init_kwargs = {}
            intrp_kwargs = {}
            policy_intrp_kwargs = {}
            common_kwargs = {}
            plot_kwargs = {}
            render_kwargs = {}
            export_kwargs = {}

            if self.coinflip():
                cate_init_kwargs.update(include_model_uncertainty=True)
                policy_init_kwargs.update(risk_level=0.1)
            else:
                fit_kwargs.update(inference=None)

            if self.coinflip():
                cate_init_kwargs.update(uncertainty_level=0.01)

            if self.coinflip():
                policy_init_kwargs.update(risk_seeking=True)

            if self.coinflip(1 / 3):
                policy_intrp_kwargs.update(sample_treatment_costs=0.1)
            elif self.coinflip():
                if discrete_t:
                    # One cost column per non-control treatment level (1, 2).
                    policy_intrp_kwargs.update(
                        sample_treatment_costs=np.random.normal(size=(10, 2)))
                else:
                    if self.coinflip():
                        policy_intrp_kwargs.update(
                            sample_treatment_costs=np.random.normal(size=(10,
                                                                          1)))
                    else:
                        policy_intrp_kwargs.update(
                            sample_treatment_costs=np.random.normal(
                                size=(10, )))

            if self.coinflip():
                common_kwargs.update(feature_names=['A', 'B', 'C', 'D'])

            if self.coinflip():
                common_kwargs.update(filled=False)

            if self.coinflip():
                common_kwargs.update(rounded=False)

            if self.coinflip():
                common_kwargs.update(precision=1)

            if self.coinflip():
                render_kwargs.update(rotate=True)
                export_kwargs.update(rotate=True)

            if self.coinflip():
                render_kwargs.update(leaves_parallel=False)
                export_kwargs.update(leaves_parallel=False)
                if discrete_t:
                    render_kwargs.update(treatment_names=[
                        'control gp', 'treated gp', 'more gp'
                    ])
                    export_kwargs.update(treatment_names=[
                        'control gp', 'treated gp', 'more gp'
                    ])
                else:
                    render_kwargs.update(
                        treatment_names=['control gp', 'treated gp'])
                    export_kwargs.update(
                        treatment_names=['control gp', 'treated gp'])

            if self.coinflip():
                render_kwargs.update(format='png')

            if self.coinflip():
                export_kwargs.update(out_file=export_path)

            # NOTE(review): this draw happens once per iteration. Assuming
            # coinflip(p) returns True with probability p, at least one of the
            # 100 iterations omits view=False in ~99% of runs
            # (1 - 0.95**100). Omitting it presumably launches a file viewer.
            # Confirm this is intended, e.g. on headless CI.
            if self.coinflip(0.95):  # don't launch files most of the time
                render_kwargs.update(view=False)

            with self.subTest(t_shape=t_shape,
                              y_shape=y_shape,
                              discrete_t=discrete_t,
                              fit_kwargs=fit_kwargs,
                              cate_init_kwargs=cate_init_kwargs,
                              policy_init_kwargs=policy_init_kwargs,
                              policy_intrp_kwargs=policy_intrp_kwargs,
                              intrp_kwargs=intrp_kwargs,
                              common_kwargs=common_kwargs,
                              plot_kwargs=plot_kwargs,
                              render_kwargs=render_kwargs,
                              export_kwargs=export_kwargs):
                plot_kwargs.update(common_kwargs)
                render_kwargs.update(common_kwargs)
                export_kwargs.update(common_kwargs)
                policy_intrp_kwargs.update(intrp_kwargs)

                est.fit(Y, T, X=X, **fit_kwargs)

                intrp = SingleTreeCateInterpreter(**cate_init_kwargs)
                intrp.interpret(est, X2, **intrp_kwargs)
                intrp.plot(**plot_kwargs)
                intrp.render(render_path, **render_kwargs)
                intrp.export_graphviz(**export_kwargs)

                intrp = SingleTreePolicyInterpreter(**policy_init_kwargs)
                try:
                    intrp.interpret(est, X2, **policy_intrp_kwargs)
                    intrp.plot(**plot_kwargs)
                    intrp.render(render_path, **render_kwargs)
                    intrp.export_graphviz(**export_kwargs)
                except AttributeError as e:
                    # Some random configurations are expected to be rejected
                    # with this message. Any other AttributeError fails.
                    self.assertIn("samples should", str(e))