Example #1 — file: test_model.py, project: dssg/UPSG
    def test_grid_search(self):
        """Run a UPSG GridSearch pipeline and check the chosen parameters.

        Simulates behavior of example in:
        http://scikit-learn.org/stable/modules/generated/sklearn.grid_search.GridSearchCV.html#sklearn.grid_search.GridSearchCV

        Builds a pipeline that splits the iris data, grid-searches an SVC
        over the parameter grid below, writes the best parameters to CSV,
        and compares them against a known-good configuration.
        """
        folds = 2

        # Parameter grid searched by both the pipeline and the sklearn control.
        parameters = {
            'kernel': (
                'rbf',
                'linear'),
            'C': [
                1,
                10,
                100],
            'random_state': [0]}
        iris = datasets.load_iris()
        iris_data = iris.data
        iris_target = iris.target

        p = Pipeline()

        node_data = p.add(NumpyRead(iris_data))
        node_target = p.add(NumpyRead(iris_target))
        node_split = p.add(SplitTrainTest(2, random_state=1))
        node_search = p.add(GridSearch(
            wrap(SVC),
            parameters,
            'score',
            cv_stage_kwargs={'n_folds': folds}))
        node_params_out = p.add(CSVWrite(self._tmp_files.get('out.csv')))

        # Wire: data/target -> split -> grid search -> CSV of best params.
        node_data['output'] > node_split['input0']
        node_target['output'] > node_split['input1']
        node_split['train0'] > node_search['X_train']
        node_split['train1'] > node_search['y_train']
        node_split['test0'] > node_search['X_test']
        node_split['test1'] > node_search['y_test']
        node_search['params_out'] > node_params_out['input']

        self.run_pipeline(p)

        result = self._tmp_files.csv_read('out.csv')

        # Run the equivalent sklearn search directly as a sanity control.
        ctrl_X_train, _, ctrl_y_train, _ = train_test_split(
            iris_data, iris_target, random_state=1)
        ctrl_cv = SKKFold(ctrl_y_train.size, folds)
        ctrl_search = grid_search.GridSearchCV(SVC(), parameters, cv=ctrl_cv)
        ctrl_search.fit(ctrl_X_train, ctrl_y_train)

        # TODO a number of configurations tie here, and sklearn picks a
        # different best configuration than upsg does (although they have the
        # same score), so ctrl_search.best_params_ is not compared directly;
        # ideally, we want to find some parameters where there is a clear
        # winner. Until then, compare against the known upsg choice.
        control = {'C': 10, 'kernel': 'linear', 'random_state': 0}

        self.assertEqual(np_sa_to_dict(np.array([result])), control)
Example #2 — file: cleanup.py, project: macressler/UPSG
def run(path='.'):
    """Remove .upsg files, and the temporary SQL tables they reference,
    from the given path.

    Parameters
    ----------
    path : str
        Directory to scan for '*.upsg' files (default: current directory).
    """
    # 'fname' instead of 'file' so the builtin is not shadowed.
    for fname in glob.iglob(os.path.join(path, '*.upsg')):
        hfile = tables.open_file(fname, mode='r')
        try:
            storage_method = hfile.get_node_attr('/upsg_inf',
                                                 'storage_method')
            if storage_method == 'sql':
                sql_group = hfile.root.sql
                pipeline_generated = hfile.get_node_attr(
                    sql_group, 'pipeline_generated')
                if pipeline_generated:
                    db_url = hfile.get_node_attr(sql_group, 'db_url')
                    tbl_name = hfile.get_node_attr(sql_group, 'tbl_name')
                    conn_params = np_sa_to_dict(
                        hfile.root.sql.conn_params.read())
                    engine = sqlalchemy.create_engine(db_url)
                    conn = engine.connect(**conn_params)
                    try:
                        md = sqlalchemy.MetaData()
                        md.reflect(conn)
                        md.tables[tbl_name].drop(conn)
                    finally:
                        # Always release the DB connection, even if the
                        # reflect/drop fails.
                        conn.close()
        finally:
            # Close the HDF5 handle even when attribute reads raise,
            # so the file can still be removed on the next run.
            hfile.close()
        os.remove(fname)
Example #3 — file: cleanup.py, project: dssg/UPSG (identical to example #2)
def run(path='.'):
    """Remove .upsg files, and the temporary SQL tables they reference,
    from the given path.

    Parameters
    ----------
    path : str
        Directory to scan for '*.upsg' files (default: current directory).
    """
    # Renamed loop variable from 'file' to avoid shadowing the builtin.
    for upsg_file in glob.iglob(os.path.join(path, '*.upsg')):
        hfile = tables.open_file(upsg_file, mode='r')
        try:
            storage_method = hfile.get_node_attr('/upsg_inf',
                                                 'storage_method')
            if storage_method == 'sql':
                sql_group = hfile.root.sql
                pipeline_generated = hfile.get_node_attr(
                    sql_group, 'pipeline_generated')
                if pipeline_generated:
                    db_url = hfile.get_node_attr(sql_group, 'db_url')
                    tbl_name = hfile.get_node_attr(sql_group, 'tbl_name')
                    conn_params = np_sa_to_dict(
                        hfile.root.sql.conn_params.read())
                    engine = sqlalchemy.create_engine(db_url)
                    conn = engine.connect(**conn_params)
                    try:
                        md = sqlalchemy.MetaData()
                        md.reflect(conn)
                        md.tables[tbl_name].drop(conn)
                    finally:
                        # Release the DB connection on every path.
                        conn.close()
        finally:
            # Guarantee the HDF5 handle is closed even if reading
            # attributes or dropping the table raises.
            hfile.close()
        os.remove(upsg_file)