示例#1
0
    def _write_model_to_db(self, class_path, parameters, feature_names,
                           model_hash, trained_model, model_group_id,
                           misc_db_parameters):
        """Write model metadata and feature importances to the database.

        If a model sharing ``model_hash`` is already stored, nothing is
        written (existing data is NOT overwritten) and the stored model id
        is returned.

        Args:
            class_path (string) A full classpath to the model class
            parameters (dict) hyperparameters to give to the model constructor
            feature_names (list) feature names in order given to model
            model_hash (string) a unique id for the model
            trained_model (object) a trained model object
            model_group_id id of the model group, stored on the Model row
            misc_db_parameters (dict) params to pass through to the database

        Returns:
            The model id of the newly written or previously stored model.
        """
        saved_model_id = retrieve_model_id_from_hash(self.db_engine,
                                                     model_hash)
        if saved_model_id:
            logging.warning('model meta data already stored %s',
                            saved_model_id)
            return saved_model_id

        session = self.sessionmaker()
        # try/finally so the session is released even if building rows or
        # committing raises.
        try:
            model = Model(model_hash=model_hash,
                          model_type=class_path,
                          model_parameters=parameters,
                          model_group_id=model_group_id,
                          experiment_hash=self.experiment_hash,
                          **misc_db_parameters)
            session.add(model)

            importances = get_feature_importances(trained_model)
            # get_feature_importances returns None for models without
            # importances (e.g. dummy classifiers, SVCs without a linear
            # kernel). Building a DataFrame from None raises, so only the
            # model row is stored in that case.
            if importances is not None:
                importance_df = pandas.DataFrame(
                    {'feature_importance': importances})
                column = importance_df['feature_importance']
                # Dense ranks: the largest importance gets rank 1.
                rankings_abs = column.rank(method='dense', ascending=False)
                rankings_pct = column.rank(method='dense', ascending=False,
                                           pct=True)
                for feature_index, importance, rank_abs, rank_pct in zip(
                        importance_df.index.tolist(), importances,
                        rankings_abs, rankings_pct):
                    session.add(FeatureImportance(
                        model=model,
                        feature_importance=round(float(importance), 10),
                        feature=feature_names[feature_index],
                        rank_abs=int(rank_abs),
                        rank_pct=round(float(rank_pct), 10)))
            session.commit()
            # Read the id while the session is still open.
            model_id = model.model_id
        finally:
            session.close()
        return model_id
示例#2
0
    def _write_model_to_db(
        self,
        class_path,
        parameters,
        feature_names,
        model_hash,
        trained_model,
        model_group_id,
        misc_db_parameters
    ):
        """Write model metadata and feature importances to the database.

        If a model sharing ``model_hash`` already exists and ``self.replace``
        is not set, the existing model id is returned and nothing is written.

        Otherwise a Model row is written: when a model with this hash exists
        it is merged under the existing id (updating its non-unique
        attributes), else a new row is added. Feature importances are then
        written via ``self._save_feature_importances``. Presumably this
        replaces any existing rows for the model; confirm in that method.

        Args:
            class_path (string) A full classpath to the model class
            parameters (dict) hyperparameters to give to the model constructor
            feature_names (list) feature names in order given to model
            model_hash (string) a unique id for the model
            trained_model (object) a trained model object
            model_group_id id of the model group, stored on the Model row
            misc_db_parameters (dict) params to pass through to the database

        Returns:
            The model id of the reused, updated, or newly added model.
        """
        model_id = retrieve_model_id_from_hash(self.db_engine, model_hash)
        if model_id and not self.replace:
            logging.info(
                'Metadata for model_id %s found in database. Reusing model metadata.',
                model_id
            )
            return model_id

        model = Model(
            model_hash=model_hash,
            model_type=class_path,
            model_parameters=parameters,
            model_group_id=model_group_id,
            experiment_hash=self.experiment_hash,
            **misc_db_parameters
        )
        session = self.sessionmaker()
        # try/finally so the session is released even if merge/add/commit
        # raises.
        try:
            if model_id:
                logging.info('Found model id %s, updating non-unique attributes', model_id)
                model.model_id = model_id
                session.merge(model)
                session.commit()
            else:
                session.add(model)
                session.commit()
                # The id is only known after the commit.
                model_id = model.model_id
                logging.info('Added new model id %s', model_id)
        finally:
            session.close()

        logging.info('Saving feature importances for model_id %s', model_id)
        self._save_feature_importances(
            model_id,
            get_feature_importances(trained_model),
            feature_names
        )
        logging.info('Done saving feature importances for model_id %s', model_id)
        return model_id
def test_correct_feature_importances_for_rf(trained_models):
    """Importances for the 'RF' model come back with shape (30,)."""
    assert get_feature_importances(trained_models['RF']).shape == (30, )
def test_correct_feature_importances_for_lr(trained_models):
    """Importances for the 'LR' model come back with shape (30,)."""
    lr_importances_shape = get_feature_importances(trained_models['LR']).shape
    # NOTE(review): a previous comment claimed the intercept is included in
    # the result; confirm whether the expected length of 30 counts it.
    assert lr_importances_shape == (30, )
def test_throwing_warning_if_lr(trained_models):
    """Requesting importances for the 'LR' model emits a UserWarning."""
    lr_model = trained_models['LR']
    with pytest.warns(UserWarning):
        get_feature_importances(lr_model)
示例#6
0
def test_correct_feature_importances_for_dummy(trained_models):
    """get_feature_importances returns None for the 'Dummy' model."""
    dummy_model = trained_models['Dummy']
    assert get_feature_importances(dummy_model) is None
示例#7
0
def test_correct_feature_importances_for_svc_wo_linear_kernel(trained_models):
    """get_feature_importances returns None for an SVC without a linear kernel."""
    svc_model = trained_models['SVC_wo_linear_kernel']
    assert get_feature_importances(svc_model) is None
示例#8
0
def test_throwing_warning_if_SVC_wo_linear_kernel(trained_models):
    """A UserWarning is emitted for an SVC without a linear kernel."""
    svc_model = trained_models['SVC_wo_linear_kernel']
    with pytest.warns(UserWarning):
        get_feature_importances(svc_model)
示例#9
0
def test_throwing_warning_if_dummyclassifier(trained_models):
    """A UserWarning is emitted for the 'Dummy' model."""
    dummy_model = trained_models['Dummy']
    with pytest.warns(UserWarning):
        get_feature_importances(dummy_model)