Esempio n. 1
0
def import_scorer(scoring):
    '''Import a scoring function or find it in METRICS

    Parameters:
        scoring: a key of METRICS, a spec that import_callable can
                 resolve, or an object that has a "fit" attribute
                 (such an object is returned unchanged)
    Returns:
        tuple of (scoring, requires_y) where scoring is the resolved
        object and requires_y is:
            * True when scoring was looked up in METRICS
            * whether "y_true" is among the required arguments when
              scoring was imported directly
            * False when scoring has a "fit" attribute
    '''
    # Default the flag up front so that the pass-through case below
    # returns a complete tuple instead of hitting an unbound local.
    requires_y = False
    if hasattr(scoring, 'fit'):
        return (scoring, requires_y)
    if scoring in METRICS:
        scoring = import_callable(METRICS[scoring])
        requires_y = True
    else:
        scoring = import_callable(scoring)
        # Only the required positional names matter here
        required_args, _, _ = get_args_kwargs_defaults(scoring)
        requires_y = 'y_true' in required_args
    return (scoring, requires_y)
Esempio n. 2
0
    def _validate_one_data_source(self, name, ds):
        '''Validate one data source within "data_sources"
        section of config

        Parameters:
            name: key of the data source; must be a non-empty string
            ds:   dict of settings for that data source; it is
                  updated in place and stored in self.data_sources
        Raises:
            ElmConfigError: if name is empty or not a string, if
                "reader" is neither in self.readers nor one of the
                built-in reader words, or if a non-list "args_list"
                cannot be imported
        '''
        if not name or not isinstance(name, string_types):
            raise ElmConfigError('Expected a "name" key in {}'.format(ds))

        sampler = ds.get('sampler')
        if sampler:
            self._validate_custom_callable(sampler,
                                           True,
                                           'train:{} sampler'.format(name))

        if 'layer_specs' in ds:
            ds['layer_specs'] = self._validate_layer_specs(ds['layer_specs'], name)

        reader = ds.get('reader')
        reader_words = ('hdf4', 'hdf5', 'tif', 'netcdf')
        if not reader:
            # Normalize a missing / empty reader to None
            ds['reader'] = None
        elif reader not in self.readers and reader not in reader_words:
            # Arguments follow the placeholders: config dict, reader, tuple
            raise ElmConfigError('Data source config dict {} '
                                 'refers to a "reader" {} that is not defined in '
                                 '"readers" or in the tuple {}'.format(ds, reader, reader_words))

        args_list = ds.get('args_list')
        if args_list and isinstance(args_list, (list, tuple)):
            # A literal list/tuple is wrapped so that "args_list" is
            # always a callable returning the arguments
            def anon(*args, **kwargs):
                return args_list
            ds['args_list'] = anon
        elif args_list:
            try:
                ds['args_list'] = import_callable(args_list)
            except Exception:
                # Any import failure is reported as a config error
                raise ElmConfigError('data_source:{} uses a args_list {} that '
                                     'is neither importable nor in '
                                     '"args_list" dict'.format(name, args_list))

        self._validate_selection_kwargs(ds, name)
        self.data_sources[name] = ds
Esempio n. 3
0
 def _validate_custom_callable(self, func_or_not, required, context):
     '''Validate a callable given like "numpy:mean" can be imported

     Parameters:
         func_or_not: a callable (returned as is) or a string spec
                      handed to import_callable
         required:    whether a value must be given; forwarded to
                      import_callable
         context:     text describing where in the config the value
                      came from, used in error messages
     Returns:
         func_or_not itself when it is already callable, otherwise
         whatever import_callable returns
     Raises:
         ElmConfigError: if a value is given (or required) and it is
             not a string
     '''
     if callable(func_or_not):
         return func_or_not
     # A non-string is an error both when something was given and
     # when nothing was given but a value is required.
     if (func_or_not or required) and not isinstance(func_or_not, string_types):
         # context fills the "In {}" slot, the bad value the second
         raise ElmConfigError('In {} expected {} to be a '
                              'string'.format(context, func_or_not))
     return import_callable(func_or_not, required=required, context=context)
Esempio n. 4
0
    def _validate_one_train_entry(self, name, t):
        '''Validate one dict within "train" or "transform" section of config

        Parameters:
            name: key of the entry within the section
            t:    dict of settings for that entry; "model_selection"
                  is replaced in place by its imported callable and
                  the dict is stored at self.config['train'][name]
        Raises:
            ElmConfigError: if "model_scoring" is given but is not a
                key of self.model_scoring
        '''
        # The three results are not used below; the call is kept for
        # the checks it performs on the entry.
        has_fit_func, requires_y, no_selection = self._validate_train_or_transform_funcs(name, t)

        # Every "*_kwargs" field must be a dict
        for key in tuple(field for field in t if field.endswith('_kwargs')):
            self._validate_type(t[key], 'train:{}'.format(key), dict)

        model_selection = t.get('model_selection')
        if model_selection:
            t['model_selection'] = import_callable(model_selection)

        sort_fitness = t.get('sort_fitness')
        if sort_fitness is not None:
            context = 'train:{} (sort_fitness)'.format(repr(sort_fitness))
            self._validate_custom_callable(sort_fitness, True, context)

        scoring_key = t.get('model_scoring')
        if scoring_key is not None:
            allowed_types = (str, numbers.Number, tuple)
            self._validate_type(scoring_key,
                                'train:{} (model_scoring)'.format(name),
                                allowed_types)
            if scoring_key not in self.model_scoring:
                raise ElmConfigError('train:{}\'s model_scoring: {} is not a '
                                     'key in config\'s model_scoring '
                                     'dict'.format(name, scoring_key))

        self.config['train'][name] = t
Esempio n. 5
0
from elm.config.util import import_callable
from elm.model_selection.util import get_args_kwargs_defaults
# "module:ClassName" specs; presumably the estimators that support
# partial_fit, going by the constant's name -- TODO confirm
PARTIAL_FIT_MODEL_STR = (
    'sklearn.naive_bayes:MultinomialNB',
    'sklearn.naive_bayes:BernoulliNB',
    'sklearn.linear_model:Perceptron',
    'sklearn.linear_model:SGDClassifier',
    'sklearn.linear_model:PassiveAggressiveClassifier',
    'sklearn.linear_model:SGDRegressor',
    'sklearn.linear_model:PassiveAggressiveRegressor',
    'sklearn.cluster:MiniBatchKMeans',
    'sklearn.cluster:Birch',
    'sklearn.neural_network:BernoulliRBM',
)

# Each spec above mapped to what import_callable resolves it to;
# the imports run when this module is loaded.
PARTIAL_FIT_MODEL_DICT = {
    spec: import_callable(spec) for spec in PARTIAL_FIT_MODEL_STR
}

# Kept as a list of "module:ClassName" specs for sklearn.tree models
FIT_PREDICT_MODELS_STR = [
    'sklearn.tree:DecisionTreeClassifier',
    'sklearn.tree:DecisionTreeRegressor',
    'sklearn.tree:ExtraTreeClassifier',
    'sklearn.tree:ExtraTreeRegressor',
]
linear_models_predict = [
    'ARDRegression', 'BayesianRidge', 'ElasticNet', 'ElasticNetCV', 'Lars',
    'LarsCV', 'Lasso', 'LassoCV', 'LassoLars', 'LassoLarsCV', 'LassoLarsIC',
    'LinearRegression', 'LogisticRegression', 'LogisticRegressionCV',
    'MultiTaskElasticNet', 'MultiTaskLassoCV', 'MultiTaskElasticNetCV',
    'MultiTaskLasso', 'OrthogonalMatchingPursuit',
    'OrthogonalMatchingPursuitCV', 'PassiveAggressiveClassifier',
    'PassiveAggressiveRegressor', 'Perceptron', 'RidgeClassifier', 'Ridge',
    'RidgeClassifierCV', 'RidgeCV', 'SGDClassifier', 'SGDRegressor',
    'TheilSenRegressor'