Esempio n. 1
0
    def test_Pipegraph__filter_nodes_predict(self):
        """Predict-time filtering honours an explicit ``predict_connections``.

        The graph is fitted using the full ``self.connections`` mapping, but
        its predict-time mapping wires only the ``'Regressor'`` node. The
        test expects ``_filter_predict_nodes`` to yield exactly that one node.
        """
        regressor_only = {'Regressor': dict(X='X', y='y')}

        graph = PipeGraph(
            steps=self.steps,
            fit_connections=self.connections,
            predict_connections=regressor_only,
        )
        graph.fit(self.X, self.y)

        # Materialise the generator/iterable so it can be compared to a list.
        self.assertEqual(list(graph._filter_predict_nodes()), ['Regressor'])
Esempio n. 2
0
 def test_Pipegraph__predict_connections(self):
     """Predict-time filtering with no ``predict_connections`` argument.

     The graph is built from ``self.steps`` and ``self.connections`` (passed
     positionally) without any predict-time mapping. The test expects
     ``_filter_predict_nodes`` to yield these five nodes, in any order.
     """
     expected_nodes = [
         'Concatenate_Xy',
         'Gaussian_Mixture',
         'Dbscan',
         'Combine_Clustering',
         'Regressor',
     ]

     graph = PipeGraph(self.steps, self.connections)
     graph.fit(self.X, self.y)
     actual_nodes = list(graph._filter_predict_nodes())

     # Sort both sides so the comparison ignores iteration order.
     self.assertListEqual(sorted(actual_nodes), sorted(expected_nodes))
Esempio n. 3
0
    def test_Pipegraph__some_predict_connections(self):
        """Predict-time filtering with a partial ``predict_connections`` mapping.

        The graph is fitted using the full ``self.connections`` mapping, but
        its predict-time mapping wires only ``Concatenate_Xy`` and the two
        clustering nodes fed from its ``'predict'`` output. The test expects
        ``_filter_predict_nodes`` to yield exactly those three nodes, in any
        order.
        """
        partial_wiring = {
            'Concatenate_Xy': dict(df1='X', df2='y'),
            'Gaussian_Mixture': dict(X=('Concatenate_Xy', 'predict')),
            'Dbscan': dict(X=('Concatenate_Xy', 'predict')),
        }
        expected_nodes = ['Concatenate_Xy', 'Gaussian_Mixture', 'Dbscan']

        graph = PipeGraph(
            steps=self.steps,
            fit_connections=self.connections,
            predict_connections=partial_wiring,
        )
        graph.fit(self.X, self.y)
        actual_nodes = list(graph._filter_predict_nodes())

        # Sort both sides so the comparison ignores iteration order.
        self.assertEqual(sorted(actual_nodes), sorted(expected_nodes))