예제 #1
0
 def test_pipeline(self):
     """Fit a four-stage pipeline, then transform, checking recorded indices.

     The pipeline alternates estimator / transformer / estimator /
     transformer. ``fit`` is given overrides for the ``fake`` param of the
     first two stages. The assertions check the ``dataset_index`` recorded
     on each resulting stage and the ``index`` of the dataset, first after
     ``fit`` and again after ``transform``.
     """
     data = MockDataset()
     first_estimator = MockEstimator()
     first_transformer = MockTransformer()
     second_estimator = MockEstimator()
     second_transformer = MockTransformer()
     stage_list = [
         first_estimator,
         first_transformer,
         second_estimator,
         second_transformer,
     ]
     pipeline = Pipeline(stages=stage_list)
     fit_overrides = {first_estimator.fake: 0, first_transformer.fake: 1}
     fitted = pipeline.fit(data, fit_overrides)
     fitted_stages = fitted.stages
     stage0, stage1, stage2, stage3 = fitted_stages

     # After fit: the first two stages carry the overridden ``fake`` values
     # and have a recorded dataset index.
     self.assertEqual(0, stage0.dataset_index)
     self.assertEqual(0, stage0.getFake())
     self.assertEqual(1, stage1.dataset_index)
     self.assertEqual(1, stage1.getFake())
     self.assertEqual(2, data.index)
     # After fit: the last two stages have no dataset index recorded yet.
     self.assertIsNone(stage2.dataset_index,
                       "The last model shouldn't be called in fit.")
     self.assertIsNone(stage3.dataset_index,
                       "The last transformer shouldn't be called in fit.")

     transformed = fitted.transform(data)
     # After transform: stages report consecutive indices starting at 2.
     for expected_index, stage in enumerate((stage0, stage1, stage2, stage3), 2):
         self.assertEqual(expected_index, stage.dataset_index)
     self.assertEqual(6, transformed.index)
예제 #2
0
 def testDefaultFitMultiple(self):
     """Check ``fitMultiple`` yields an (index, model) pair per param map.

     Each param map sets the estimator's ``fake`` param to its own position
     in the list, so every yielded model must report ``getFake()`` equal to
     the index it was yielded with. Pairs may arrive in any order; the
     collected indices are sorted before the final comparison.
     """
     num_maps = 4
     dataset = MockDataset()
     est = MockEstimator()
     param_maps = [{est.fake: value} for value in range(num_maps)]
     seen = []
     for idx, fitted in est.fitMultiple(dataset, param_maps):
         self.assertEqual(fitted.getFake(), idx)
         seen.append(idx)
     self.assertEqual(sorted(seen), list(range(num_maps)))
예제 #3
0
파일: test_base.py 프로젝트: Brett-A/spark
 def testDefaultFitMultiple(self):
     """Verify the (index, model) pairs produced by ``fitMultiple``.

     Four param maps are built, the i-th one setting ``fake`` to ``i``.
     For every pair yielded, the model's ``getFake()`` must equal the
     accompanying index, and taken together the indices must be exactly
     0..3 once sorted.
     """
     count = 4
     expected_indices = list(range(count))
     mock_data = MockDataset()
     mock_estimator = MockEstimator()
     param_maps = [{mock_estimator.fake: i} for i in range(count)]
     results = mock_estimator.fitMultiple(mock_data, param_maps)
     collected = []
     for position, fitted_model in results:
         self.assertEqual(fitted_model.getFake(), position)
         collected.append(position)
     self.assertEqual(sorted(collected), expected_indices)
예제 #4
0
파일: test_base.py 프로젝트: zoelin7/spark
class EstimatorTest(unittest.TestCase):
    """Tests of ``fit`` argument checking and ``fitMultiple`` on a mock estimator."""

    def setUp(self):
        """Create the estimator and dataset fixtures used by the tests."""
        self.estimator = MockEstimator()
        self.data = MockDataset()

    def test_fit_invalid_params(self):
        """``fit`` must raise ``TypeError`` when params is a string."""
        bad_params = ""
        with self.assertRaises(TypeError):
            self.estimator.fit(self.data, bad_params)

    def testDefaultFitMultiple(self):
        """``fitMultiple`` must yield an (index, model) pair per param map.

        The i-th param map sets ``fake`` to ``i``; each yielded model must
        report that value for the index it arrives with. The yield order is
        not checked, only the sorted set of indices.
        """
        num_maps = 4
        param_maps = [{self.estimator.fake: value} for value in range(num_maps)]
        seen = []
        for idx, fitted in self.estimator.fitMultiple(self.data, param_maps):
            self.assertEqual(fitted.getFake(), idx)
            seen.append(idx)
        self.assertEqual(sorted(seen), list(range(num_maps)))
예제 #5
0
 def setUp(self):
     """Create the mock estimator and mock dataset fixtures on ``self``."""
     self.estimator, self.data = MockEstimator(), MockDataset()