예제 #1
0
파일: views.py 프로젝트: surlybot/xcessiv
def confirm_base_learner_origin(id):
    """Re-verify a base learner origin on its recorded dataset and mark it final.

    The database is chosen by ``functions.get_path_from_query_string(request)``.
    On a ``GET`` request the origin's estimator is re-run through
    ``functions.verify_estimator_class`` using the dataset stored under
    ``validation_results['dataset']``. The fresh metrics and hyperparameters
    are saved, ``final`` is set to True, the session is committed, and the
    serialized origin is returned as JSON.

    Args:
        id: Primary key of the ``BaseLearnerOrigin``. The name is kept as
            ``id`` (presumably bound by name from the URL rule) even though
            it shadows the builtin.

    Returns:
        ``jsonify(base_learner_origin.serialize)`` for ``GET``. For any other
        request method the function falls through and returns ``None``.

    Raises:
        exceptions.UserError: raised with code 404 when the origin does not
            exist. Also raised when the origin is already final, or when it
            has no recorded verification dataset.
    """
    path = functions.get_path_from_query_string(request)

    with functions.DBContextManager(path) as session:
        base_learner_origin = session.query(models.BaseLearnerOrigin).filter_by(id=id).first()
        if base_learner_origin is None:
            raise exceptions.UserError('Base learner origin {} not found'.format(id), 404)

        if request.method == 'GET':
            if base_learner_origin.final:
                raise exceptions.UserError('Base learner origin {} '
                                           'is already final'.format(id))
            previous_results = base_learner_origin.validation_results
            # Results lacking a 'dataset' entry count as "not verified". Without
            # this check the lookup below would raise an unhandled KeyError.
            if not previous_results or 'dataset' not in previous_results:
                raise exceptions.UserError('Base learner origin {} has not yet been '
                                           'verified on a dataset'.format(id))
            dataset = previous_results['dataset']

            base_learner = base_learner_origin.return_estimator()
            validation_results, hyperparameters = functions.verify_estimator_class(
                base_learner,
                base_learner_origin.meta_feature_generator,
                base_learner_origin.metric_generators,
                dataset
            )
            # Keep the same {'dataset', 'metrics'} layout that was read above.
            base_learner_origin.validation_results = {
                'dataset': dataset,
                'metrics': validation_results
            }
            base_learner_origin.hyperparameters = hyperparameters
            base_learner_origin.final = True
            session.add(base_learner_origin)
            session.commit()
            return jsonify(base_learner_origin.serialize)
예제 #2
0
 def test_verify_estimator_class(self):
     """Verifying a default RandomForestClassifier yields known accuracy and params.

     NumPy's global RNG is seeded first, presumably so the forest (built with
     ``random_state=None``) produces a reproducible score. The test then checks
     the rounded ``Accuracy`` metric and the exact hyperparameter dict returned
     by ``functions.verify_estimator_class``.
     """
     np.random.seed(8)
     metrics, params = functions.verify_estimator_class(
         RandomForestClassifier(),
         'predict_proba',
         {'Accuracy': self.source},
         self.dataset_properties,
     )
     # NOTE(review): these defaults (n_estimators=10, max_features='auto',
     # min_impurity_split present) look tied to one older scikit-learn
     # release -- confirm the pinned version before relying on this test.
     expected_params = dict(
         bootstrap=True,
         class_weight=None,
         criterion='gini',
         max_depth=None,
         max_features='auto',
         max_leaf_nodes=None,
         min_impurity_decrease=0.0,
         min_impurity_split=None,
         min_samples_leaf=1,
         min_samples_split=2,
         min_weight_fraction_leaf=0.0,
         n_estimators=10,
         n_jobs=1,
         oob_score=False,
         random_state=None,
         verbose=0,
         warm_start=False,
     )
     assert round(metrics['Accuracy'], 3) == 0.8
     assert params == expected_params
예제 #3
0
파일: views.py 프로젝트: BigRLab/xcessiv
def verify_base_learner_origin(id):
    """Verify a base learner origin against the dataset named in the request body.

    The database is chosen by ``functions.get_path_from_query_string(request)``.
    On a ``POST`` request the JSON body must be an object with a ``'dataset'``
    field. The origin's estimator is run through
    ``functions.verify_estimator_class`` on that dataset. The resulting metrics
    and hyperparameters are saved, the session is committed, and the serialized
    origin is returned as JSON.

    Args:
        id: Primary key of the ``BaseLearnerOrigin``. The name is kept as
            ``id`` (presumably bound by name from the URL rule) even though
            it shadows the builtin.

    Returns:
        ``jsonify(base_learner_origin.serialize)`` for ``POST``. For any other
        request method the function falls through and returns ``None``.

    Raises:
        exceptions.UserError: raised with code 404 when the origin does not
            exist. Also raised when the origin is already final, or when the
            request body is not a JSON object containing ``'dataset'``.
    """
    path = functions.get_path_from_query_string(request)

    with functions.DBContextManager(path) as session:
        base_learner_origin = session.query(
            models.BaseLearnerOrigin).filter_by(id=id).first()
        if base_learner_origin is None:
            raise exceptions.UserError(
                'Base learner origin {} not found'.format(id), 404)

        if request.method == 'POST':
            req_body = request.get_json()
            if base_learner_origin.final:
                raise exceptions.UserError('Base learner origin {} '
                                           'is already final'.format(id))
            # The body comes from the client. Reject a missing/non-object body
            # or an absent 'dataset' field with a UserError; otherwise these
            # surface as TypeError/KeyError.
            if not isinstance(req_body, dict) or 'dataset' not in req_body:
                raise exceptions.UserError('Request body must be a JSON object '
                                           'with a "dataset" field')
            dataset = req_body['dataset']

            base_learner = base_learner_origin.return_estimator()
            validation_results, hyperparameters = functions.verify_estimator_class(
                base_learner, base_learner_origin.meta_feature_generator,
                base_learner_origin.metric_generators, dataset)
            # NOTE(review): results are keyed by the dataset value itself, so it
            # must be hashable (e.g. a string). Code that reads
            # validation_results['dataset'] will not find it under this layout.
            # Confirm the intended schema.
            base_learner_origin.validation_results = {
                dataset: validation_results
            }
            base_learner_origin.hyperparameters = hyperparameters
            session.add(base_learner_origin)
            session.commit()
            return jsonify(base_learner_origin.serialize)
예제 #4
0
 def test_non_serializable_parameters(self):
     """Hyperparameters reported for a Pipeline must be valid JSON.

     A PCA -> RandomForestClassifier pipeline is verified. The hyperparameters
     returned by ``functions.verify_estimator_class`` are then checked with
     ``functions.is_valid_json``. Presumably the pipeline's raw params include
     non-serializable step objects, so this guards the sanitizing behavior.
     """
     steps = [('pca', PCA()), ('rf', RandomForestClassifier())]
     _, hyperparameters = functions.verify_estimator_class(
         Pipeline(steps),
         'predict_proba',
         {'Accuracy': self.source},
         self.dataset_properties,
     )
     assert functions.is_valid_json(hyperparameters)