Exemplo n.º 1
0
 def test_pipeline(self):
     """Fit and transform a four-stage estimator/transformer Pipeline.

     The fit-time param map sets ``fake`` to 0 on the first estimator and 1
     on the first transformer. After ``fit``:
     - the first model and first transformer have recorded dataset indices
       0 and 1;
     - the dataset index is 2;
     - the last model and last transformer have recorded nothing.

     After ``transform``, the fitted stages record indices 2 through 5 in
     stage order, and the returned dataset has index 6.
     """
     data = MockDataset()
     first_est = MockEstimator()
     first_tf = MockTransformer()
     second_est = MockEstimator()
     second_tf = MockTransformer()
     pipeline = Pipeline(stages=[first_est, first_tf, second_est, second_tf])
     fit_params = {first_est.fake: 0, first_tf.fake: 1}
     fitted = pipeline.fit(data, fit_params)
     # Distinct names for the fitted stages (the model's stage list).
     model_a, tf_a, model_b, tf_b = fitted.stages

     # State observed right after fit().
     self.assertEqual(0, model_a.dataset_index)
     self.assertEqual(0, model_a.getFake())
     self.assertEqual(1, tf_a.dataset_index)
     self.assertEqual(1, tf_a.getFake())
     self.assertEqual(2, data.index)
     self.assertIsNone(model_b.dataset_index,
                       "The last model shouldn't be called in fit.")
     self.assertIsNone(tf_b.dataset_index,
                       "The last transformer shouldn't be called in fit.")

     # After transform(), each fitted stage records the next dataset index.
     transformed = fitted.transform(data)
     for expected, stage in zip((2, 3, 4, 5), (model_a, tf_a, model_b, tf_b)):
         self.assertEqual(expected, stage.dataset_index)
     self.assertEqual(6, transformed.index)
Exemplo n.º 2
0
 def testDefaultFitMultiple(self):
     """Check ``fitMultiple`` yields one correctly-configured model per param map.

     Each yielded ``(index, model)`` pair must have ``model.getFake()`` equal
     to ``index``. Every index in ``range(4)`` must appear exactly once.
     """
     num_models = 4
     dataset = MockDataset()
     estimator = MockEstimator()
     param_maps = [{estimator.fake: i} for i in range(num_models)]
     seen_indices = []
     for idx, fitted in estimator.fitMultiple(dataset, param_maps):
         self.assertEqual(fitted.getFake(), idx)
         seen_indices.append(idx)
     # Order of results is not asserted; only the set of indices is.
     self.assertEqual(sorted(seen_indices), list(range(num_models)))
Exemplo n.º 3
0
    def test_identity_pipeline(self):
        """Check that an empty Pipeline leaves the dataset index unchanged.

        Also check that a Pipeline whose ``stages`` param was never set
        raises KeyError.
        """
        dataset = MockDataset()

        def fit_then_transform(pipe):
            # Fit on the shared dataset, then transform that same dataset.
            return pipe.fit(dataset).transform(dataset)

        # An empty stage list must not perform any transformation.
        index_before = dataset.index
        result = fit_then_transform(Pipeline(stages=[]))
        self.assertEqual(index_before, result.index)
        # Leaving the stages param unset raises KeyError for the missing param.
        with self.assertRaises(KeyError):
            fit_then_transform(Pipeline())
Exemplo n.º 4
0
 def setUp(self):
     """Build the MockEstimator and MockDataset fixtures used by the tests."""
     # Constructed left to right: estimator first, then dataset.
     self.estimator, self.data = MockEstimator(), MockDataset()
Exemplo n.º 5
0
 def test_transform_invalid_type(self):
     """Check that ``transform`` raises TypeError for an empty-string second argument."""
     transformer = MockTransformer()
     data = MockDataset()
     with self.assertRaises(TypeError):
         transformer.transform(data, "")