Exemplo n.º 1
0
def test_validate_preprocess():
    """A valid preprocess section is accepted and merged with the defaults."""
    plugins = _DictPluginManager()
    plugins.set('vergeml.operation', 'augment', AugmentOperation)
    validators = {'data': ValidateData(plugins=plugins)}

    preprocess = [{'op': 'augment', 'variants': 4}]
    apply_config({'data': {'preprocess': preprocess}}, validators)

    # the validator fills in cache/input/output defaults around the
    # preprocess pipeline we supplied
    expected = {
        'data': {
            'cache': '*auto*',
            'input': {'type': None},
            'output': {'type': None},
            'preprocess': [{'op': 'augment', 'variants': 4}],
        }
    }
    assert validators['data'].values == expected
Exemplo n.º 2
0
    def set_defaults(self, cmd, args, plugins=PLUGINS):
        """Apply model defaults for *cmd*, then re-validate the device and
        data sections and fold the validated values back into the env."""
        if self.model:
            self.model.set_defaults(cmd, args, self)

        validators = {
            'device': ValidateDevice('device', plugins),
            'data': ValidateData('data', plugins),
        }
        config = {'device': self.get('device'), 'data': self.get('data')}
        apply_config(config, validators)

        # update env from validators
        for validator in validators.values():
            self._config.update(validator.values)
Exemplo n.º 3
0
def test_validate_preprocess_invalid():
    """A misspelled preprocess option produces a did-you-mean suggestion."""
    plugins = _DictPluginManager()
    plugins.set('vergeml.operation', 'augment', AugmentOperation)
    validators = {'data': ValidateData(plugins=plugins)}

    # 'variantz' is a typo for 'variants'
    bad_config = {'data': {'preprocess': [{'op': 'augment', 'variantz': 4}]}}
    with pytest.raises(VergeMLError, match=r".*Did you mean 'variants'.*"):
        apply_config(bad_config, validators)
Exemplo n.º 4
0
def _load_and_configure(file, label, validators):
    """Load a YAML file, validate it, and enrich validation errors.

    :param file:       Path to the YAML file to load.
    :param label:      Human readable label for the file (used in errors).
    :param validators: Mapping of section name to validator, as expected
                       by apply_config.
    :return:           The validated configuration document.
    :raises VergeMLError: When validation fails; if the offending key can
                          be located in the file, the error message is
                          rewritten to point at the exact line and column.
    """
    doc = load_yaml_file(file, label)
    try:
        doc = apply_config(doc, validators)
        # random-seed is validated here because it has no section validator
        if 'random-seed' in doc and not isinstance(doc['random-seed'], int):
            raise VergeMLError('Invalid value option random-seed.',
                               'random-seed must be an integer value.',
                               hint_type='value',
                               hint_key='random-seed')
    except VergeMLError as e:
        # When the error points at a config key, try to locate it in the
        # file so the message can show the offending line.
        if e.hint_key:
            with open(file) as f:
                definition = yaml_find_definition(f, e.hint_key, e.hint_type)
            if definition:
                line, column, length = definition
                e.message = display_err_in_file(file, line, column, str(e),
                                                length)
                # clear suggestion because it is already contained in the formatted error message.
                e.suggestion = None
        # bare raise preserves the original traceback in every case
        raise
    return doc
Exemplo n.º 5
0
def test_input_output():
    """Input and output types from the config end up in the validator values."""
    plugins = _DictPluginManager()
    plugins.set('vergeml.io', 'image', ImageSource)
    validators = {'data': ValidateData('image', plugins=plugins)}

    config = {
        'data': {
            'input': {'type': 'image'},
            'output': {'type': 'image'},
        }
    }
    apply_config(config, validators=validators)

    values = validators['data'].values['data']
    assert values['input']['type'] == 'image'
    assert values['output']['type'] == 'image'
Exemplo n.º 6
0
def test_apply_empty_config():
    """An empty config leaves nothing over and yields the device defaults."""
    validators = {'device': ValidateDevice()}
    assert apply_config({}, validators) == {}

    defaults = {'id': 'auto', 'memory': 'auto', 'grow-memory': False}
    assert validators['device'].values == {'device': defaults}
Exemplo n.º 7
0
def test_apply_config_image_invalid():
    """An unknown input option for the image source raises a VergeMLError."""
    plugins = _DictPluginManager()
    plugins.set('vergeml.io', 'image', ImageSource)
    validators = {'data': ValidateData(plugins=plugins)}

    # 'input-patternz' is a typo for 'input-patterns'
    bad_config = {
        'data': {
            'input': {'type': 'image', 'input-patternz': '*.jpg'},
        }
    }
    with pytest.raises(VergeMLError):
        assert apply_config(bad_config, validators) == {}
Exemplo n.º 8
0
def test_apply_config():
    """Known sections are consumed; unknown keys pass through unchanged."""
    validators = {'device': ValidateDevice()}

    config = {'device': 'gpu', 'model': 'inception-v3'}
    # 'device' is picked up by its validator, 'model' is returned as-is
    assert apply_config(config, validators) == {'model': 'inception-v3'}

    # the shorthand 'gpu' is expanded to an explicit device id
    expected = {'id': 'gpu:0', 'memory': 'auto', 'grow-memory': False}
    assert validators['device'].values == {'device': expected}
Exemplo n.º 9
0
def test_apply_config_image():
    """A valid image input section validates and is merged with the defaults."""
    plugins = _DictPluginManager()
    plugins.set('vergeml.io', 'image', ImageSource)
    validators = {'data': ValidateData(plugins=plugins)}

    config = {
        'data': {
            'input': {'type': 'image', 'input-patterns': '*.jpg'},
        }
    }
    assert apply_config(config, validators) == {}

    expected = {
        'data': {
            'input': {'type': 'image', 'input-patterns': '*.jpg'},
            'output': {'type': None},
            'cache': '*auto*',
            'preprocess': [],
        }
    }
    assert validators['data'].values == expected
Exemplo n.º 10
0
    def __init__(self,
                 model=None,
                 project_file=None,
                 samples_dir=None,
                 test_split=None,
                 val_split=None,
                 cache_dir=None,
                 random_seed=None,
                 trainings_dir=None,
                 project_dir=None,
                 AI=None,
                 is_global_instance=False,
                 config=None,
                 plugins=PLUGINS,
                 display=DISPLAY):
        """Configure, train and save the results.

        :param model:              Name of the model plugin.
        :param project_file:       Optional path to the project file.
        :param samples_dir:        The directory where samples can be found. [default: samples]
        :param test_split:         The test split. [default: 10%]
        :param val_split:          The val split. [default: 10%]
        :param cache_dir:          The directory used for caching [default: .cache]
        :param random_seed:        Random seed. [default 2204]
        :param trainings_dir:      The directory to save training results to. [default: trainings]
        :param project_dir:        The directory of the project. [default: current directory]
        :param AI:                 Optional name of a trained AI.
        :param is_global_instance: If true, this env can be accessed under the global var env.ENV. [default: false]
        :param config:             Additional configuration to pass to env, i.e. if not using a project file.
        """

        super().__init__()

        # avoid the shared mutable-default pitfall: None means "no extra config"
        config = {} if config is None else config

        # when called from the command line, we need to have a global instance
        if is_global_instance:
            global ENV
            ENV = self

        # setup the display
        self.display = display
        # set the name of the AI if given
        self.AI = AI
        # this holds the model object (not the name of the model)
        self.model = None
        # the results class (responsible for updating data.yaml with the latest results during training)
        self.results = None
        # when a training is started, this holds the object responsible for coordinating the training
        self.training = None
        # hold a proxy to the data loader
        self._data = None

        self.plugins = plugins

        # set up the base options from constructor arguments
        self._config = {}
        self._config['samples-dir'] = samples_dir
        self._config['test-split'] = test_split
        self._config['val-split'] = val_split
        self._config['cache-dir'] = cache_dir
        self._config['random-seed'] = random_seed
        self._config['trainings-dir'] = trainings_dir
        self._config['model'] = model

        validators = {}
        # add validators for commands
        for k, v in plugins.all('vergeml.cmd').items():
            cmd = Command.discover(v)
            validators[cmd.name] = ValidateOptions(cmd.options,
                                                   k,
                                                   plugins=plugins)
        # now it gets a bit tricky - we need to peek at the model name
        # to find the right validators to create for model commands.
        peek_model_name = model
        peek_trainings_dir = trainings_dir
        # to do this, we have to first have a look at the project file
        try:
            project_doc = load_yaml_file(project_file) if project_file else {}
            # only update model name if empty (project file does not override command line)
            peek_model_name = peek_model_name or project_doc.get('model', None)
            # pick up trainings-dir in the same way
            peek_trainings_dir = peek_trainings_dir or project_doc.get(
                'trainings-dir', None)
            # if we don't have a trainings dir yet, set to default
            peek_trainings_dir = peek_trainings_dir or os.path.join(
                project_dir or "", "trainings")
            # now, try to load the data.yaml file and see if we have a model definition there
            data_doc = load_yaml_file(peek_trainings_dir, AI,
                                      "data.yaml") if AI else {}
            # if we do, this overrides everything, also the one from the command line
            peek_model_name = data_doc.get('model', peek_model_name)
            # finally, if we have a model name, set up validators
            if peek_model_name:
                for fn in Command.find_functions(plugins.get(
                        "vergeml.model", peek_model_name),
                                                 plugins=plugins):
                    cmd = Command.discover(fn)
                    validators[cmd.name] = ValidateOptions(
                        cmd.options, cmd.name, plugins)
        except Exception:
            # in this case we don't care if something went wrong - the error
            # will be reported later
            pass
        # finally, validators for device and data sections
        validators['device'] = ValidateDevice('device', plugins)
        validators['data'] = ValidateData('data', plugins)

        # merge project file
        if project_file:
            doc = _load_and_configure(project_file, 'project file', validators)
            # the project file DOES NOT override values passed to the environment
            # TODO reserved: hyperparameters and results
            for k, v in doc.items():
                if k not in self._config or self._config[k] is None:
                    self._config[k] = v

        # after the project file is loaded, fill missing values
        project_dir = project_dir or ''
        defaults = {
            'samples-dir': os.path.join(project_dir, "samples"),
            'test-split': '10%',
            'val-split': '10%',
            'cache-dir': os.path.join(project_dir, ".cache"),
            'random-seed': 2204,
            'trainings-dir': os.path.join(project_dir, "trainings"),
        }
        for k, v in defaults.items():
            if self._config[k] is None:
                self._config[k] = v

        # verify split values
        for split in ('val-split', 'test-split'):
            spltype, splval = parse_split(self._config[split])
            if spltype == 'dir':
                path = os.path.join(project_dir, splval)
                if not os.path.exists(path):
                    raise VergeMLError(
                        f"Invalid value for option {split} - no such directory: {splval}",
                        f"Please set {split} to a percentage, number or directory.",
                        hint_key=split,
                        hint_type='value',
                        help_topic='split')
                self._config[split] = path

        # need to have data_file variable in outer scope for later when reporting errors
        data_file = None
        if self.AI:
            ai_path = os.path.join(self._config['trainings-dir'], self.AI)
            if not os.path.exists(ai_path):
                raise VergeMLError("AI not found: {}".format(self.AI))
            # merge data.yaml
            data_file = os.path.join(self._config['trainings-dir'], self.AI,
                                     'data.yaml')
            if not os.path.exists(data_file):
                raise VergeMLError(
                    "data.yaml file not found for AI {}: {}".format(
                        self.AI, data_file))
            doc = load_yaml_file(data_file, 'data file')
            self._config['hyperparameters'] = doc.get('hyperparameters', {})
            self._config['results'] = doc.get('results', {})
            self._config['model'] = doc.get('model')
            self.results = _Results(self, data_file)

        try:
            # merge device and data config
            self._config.update(apply_config(config, validators))
        except VergeMLError as e:
            # improve the error message when this runs on the command line
            if is_global_instance and e.hint_key:
                key = e.hint_key
                e.message = f"Option --{key}: " + e.message
            # bare raise preserves the original traceback
            raise

        if self._config['model']:
            # load the model plugin
            modelname = self._config['model']
            self.model = plugins.get("vergeml.model", modelname)

            if not self.model:
                message = f"Unknown model name '{modelname}'"
                suggestion = did_you_mean(plugins.keys('vergeml.model'),
                                          modelname) or "See 'ml help models'."

                # if model was passed in via --model
                if model and is_global_instance:
                    message = f"Invalid value for option --model: {message}"
                else:
                    res = None
                    if not res and data_file:
                        # first check if model was defined in the data file
                        res = _check_definition(data_file, 'model', 'value')
                    if not res and project_file:
                        # next check the project file
                        res = _check_definition(project_file, 'model', 'value')
                    if res:
                        filename, definition = res
                        line, column, length = definition
                        # display a nice error message
                        message = display_err_in_file(
                            filename, line, column, f"{message} {suggestion}",
                            length)
                        # set suggestion to None since it is now contained in message
                        suggestion = None
                raise VergeMLError(message, suggestion)
            else:
                # instantiate the model plugin
                self.model = self.model(modelname, plugins)

        # update env from validators
        for _, plugin in validators.items():
            for k, v in plugin.values.items():
                self._config[k] = v

        # always set up numpy and python
        self.configure('python')
        self.configure('numpy')
Exemplo n.º 11
0
def test_config_invalid():
    """An unknown key inside the device section is rejected."""
    validators = {'device': ValidateDevice()}
    bad_config = {'device': {'id': 'cpu', 'invalid': 'true'}}
    with pytest.raises(VergeMLError):
        apply_config(bad_config, validators)
Exemplo n.º 12
0
def test_config_dict():
    """A dict-valued device section validates and is stored on the validator."""
    validators = {'device': ValidateDevice()}
    leftover = apply_config({'device': {'id': 'cpu'}}, validators)
    assert leftover == {}
    assert validators['device'].values['device']['id'] == 'cpu'