def test_can_custom_splitter(self):
    """Check that ``n_splits`` accepts a splitter object or an explicit list of splits.

    Two variants are exercised with identical data: a ``KFold`` instance and a
    literal list holding a single ``(train_indices, test_indices)`` pair. For each
    variant the estimator must be able to ``fit`` and then ``score`` without error.
    """
    split_variants = [
        KFold(),                      # a cross-validator object
        [([0, 1, 2], [3, 4, 5])],     # a hand-written train/test iterable
    ]
    for split_spec in split_variants:
        estimator = LinearDMLCateEstimator(LinearRegression(),
                                           LogisticRegression(C=1000),
                                           discrete_treatment=True,
                                           n_splits=split_spec)
        outcome = np.array([1, 2, 3, 1, 2, 3])
        treatment = np.array([1, 2, 3, 1, 2, 3])
        features = np.ones((6, 1))
        estimator.fit(outcome, treatment, features)
        estimator.score(outcome, treatment, features)
def test_can_use_vectors(self):
    """Test that we can pass vectors for T and Y (not only 2-dimensional arrays).

    The outcome and treatment are identical 1-D arrays, so the fitted (single)
    coefficient is expected to be 1 and the score is expected to be 0.
    """
    outcome = np.array([1, 2, 3, 1, 2, 3])
    treatment = np.array([1, 2, 3, 1, 2, 3])
    features = np.ones((6, 1))

    estimator = LinearDMLCateEstimator(LinearRegression(), LinearRegression(),
                                       featurizer=FunctionTransformer())
    estimator.fit(outcome, treatment, features)

    # reshape(()) requires coef_ to hold exactly one element; [()] extracts it.
    sole_coefficient = estimator.coef_.reshape(())[()]
    self.assertAlmostEqual(sole_coefficient, 1)

    self.assertAlmostEqual(estimator.score(outcome, treatment, features), 0)
def test_discrete_treatments(self):
    """Test that we can use discrete treatments.

    Fits on a tiny three-level treatment dataset and checks the estimated effect
    for every (T0, T1) pair of levels, then ensures ``score`` runs.
    """
    estimator = LinearDMLCateEstimator(LinearRegression(),
                                       LogisticRegression(C=1000),
                                       featurizer=FunctionTransformer(),
                                       discrete_treatment=True)

    # Artificial setup in which moving between treatment levels changes Y by:
    #   1 -> 2 : +2
    #   1 -> 3 : +1
    #   2 -> 3 : -1 (necessarily, by composing the two effects above)
    # Class counts are uneven and treatments appear in non-lexicographic
    # order, which should rule out some basic issues.
    outcome = np.array([2, 3, 1, 3, 2, 1, 1, 1])
    treatment = np.array([3, 2, 1, 2, 3, 1, 1, 1])
    features = np.ones((8, 1))
    estimator.fit(outcome, treatment, features)

    # Every ordered pair of levels: T0 = 1,1,1,2,2,2,3,3,3 and T1 = 1,2,3 repeated.
    levels = [1, 2, 3]
    baseline = np.repeat(levels, 3)
    target = np.tile(levels, 3)
    np.testing.assert_almost_equal(
        estimator.effect(np.ones((9, 1)), baseline, target),
        [0, 2, 1, -2, 0, -1, -1, 1, 0])

    estimator.score(outcome, treatment, features)