Example #1
def predict_with_model(config_path):
    config = read_json(config_path)
    set_deeppavlov_root(config)

    reader_config = config['dataset_reader']
    reader = get_model(reader_config['name'])()
    data_path = expand_path(reader_config.get('data_path', ''))
    read_params = {k: v for k, v in reader_config.items() if k not in ['name', 'data_path']}
    data = reader.read(data_path, **read_params)

    iterator_config = config['dataset_iterator']
    iterator: MorphoTaggerDatasetIterator =\
        from_params(iterator_config, data=data)

    model = build_model_from_config(config, load_trained=True)
    answers = [None] * len(iterator.test)
    batch_size = config['predict'].get("batch_size", -1)
    for indexes, (x, _) in iterator.gen_batches(
            batch_size=batch_size, data_type="test", shuffle=False, return_indexes=True):
        y = model(x)
        for i, elem in zip(indexes, y):
            answers[i] = elem
    outfile = config['predict'].get("outfile")
    if outfile is not None:
        outfile = Path(outfile)
        if not outfile.exists():
            outfile.parent.mkdir(parents=True, exist_ok=True)
        with open(outfile, "w", encoding="utf8") as fout:
            for elem in answers:
                fout.write(elem + "\n")
    return answers
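A minimal usage sketch for the helper above; the config path is purely illustrative, not a file implied by this snippet:

predictions = predict_with_model('path/to/morpho_tagger_config.json')  # hypothetical path
for analysis in predictions[:3]:
    print(analysis)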
Example #2
def read_data_by_config(config: dict):
    """Read data by dataset_reader from specified config."""
    dataset_config = config.get('dataset', None)

    if dataset_config:
        config.pop('dataset')
        ds_type = dataset_config['type']
        if ds_type == 'classification':
            reader = {'class_name': 'basic_classification_reader'}
            iterator = {'class_name': 'basic_classification_iterator'}
            config['dataset_reader'] = {**dataset_config, **reader}
            config['dataset_iterator'] = {**dataset_config, **iterator}
        else:
            raise Exception("Unsupported dataset type: {}".format(ds_type))

    try:
        reader_config = dict(config['dataset_reader'])
    except KeyError:
        raise ConfigError("No dataset reader is provided in the JSON config.")

    reader = get_model(reader_config.pop('class_name'))()
    data_path = reader_config.pop('data_path', '')
    if isinstance(data_path, list):
        data_path = [expand_path(x) for x in data_path]
    else:
        data_path = expand_path(data_path)

    return reader.read(data_path, **reader_config)
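For reference, a minimal config this function could consume might look like the sketch below; the field values are assumptions for illustration, and any key other than class_name and data_path is forwarded to reader.read():

config = {
    'dataset_reader': {
        'class_name': 'basic_classification_reader',  # resolved via get_model()
        'data_path': 'my_data/',                      # expanded by expand_path()
        'train': 'train.csv',                         # extra kwarg passed to reader.read()
    }
}
data = read_data_by_config(config)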
Example #3
def from_params(params: Dict, mode: str = 'infer', **kwargs) -> Component:
    """Builds and returns the Component from corresponding dictionary of parameters."""
    # what is passed in json:
    config_params = {k: _resolve(v) for k, v in params.items()}

    # get component by reference (if any)
    if 'ref' in config_params:
        try:
            return _refs[config_params['ref']]
        except KeyError:
            e = ConfigError(
                'Component with id "{id}" was referenced but not initialized'.
                format(id=config_params['ref']))
            log.exception(e)
            raise e

    elif 'config_path' in config_params:
        from deeppavlov.core.commands.infer import build_model_from_config
        deeppavlov_root = get_deeppavlov_root()
        refs = _refs.copy()
        _refs.clear()
        config = read_json(expand_path(config_params['config_path']))
        model = build_model_from_config(config)
        set_deeppavlov_root({'deeppavlov_root': deeppavlov_root})
        _refs.clear()
        _refs.update(refs)
        return model

    elif 'class' in config_params:
        cls = cls_from_str(config_params.pop('class'))
    else:
        cls_name = config_params.pop('name', None)
        if not cls_name:
            e = ConfigError(
                'Component config has no `name` nor `ref` or `class` fields')
            log.exception(e)
            raise e
        cls = get_model(cls_name)

    # find the submodels params recursively
    config_params = {k: _init_param(v, mode) for k, v in config_params.items()}

    try:
        spec = inspect.getfullargspec(cls)
        if 'mode' in spec.args + spec.kwonlyargs or spec.varkw is not None:
            kwargs['mode'] = mode

        component = cls(**dict(config_params, **kwargs))
        try:
            _refs[config_params['id']] = component
        except KeyError:
            pass
    except Exception:
        log.exception("Exception in {}".format(cls))
        raise

    return component
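The branches above resolve a component from, in order: a previously built instance ('ref'), a nested config file ('config_path'), or a class looked up by 'class' or 'name'. A hedged sketch of the last path, reusing a registered name from the surrounding examples and assuming its constructor needs no required arguments:

params = {'name': 'basic_classification_reader'}  # registered name, used here as an assumption
reader = from_params(params, mode='infer')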
Example #4
def from_params(params: Dict, mode: str = 'infer', serialized: Any = None, **kwargs) -> Component:
    """Builds and returns the Component from corresponding dictionary of parameters."""
    # what is passed in json:
    config_params = {k: _resolve(v) for k, v in params.items()}

    # get component by reference (if any)
    if 'ref' in config_params:
        try:
            component = _refs[config_params['ref']]
            if serialized is not None:
                component.deserialize(serialized)
            return component
        except KeyError:
            e = ConfigError('Component with id "{id}" was referenced but not initialized'
                            .format(id=config_params['ref']))
            log.exception(e)
            raise e

    elif 'config_path' in config_params:
        from deeppavlov.core.commands.infer import build_model
        refs = _refs.copy()
        _refs.clear()
        config = parse_config(expand_path(config_params['config_path']))
        model = build_model(config, serialized=serialized)
        _refs.clear()
        _refs.update(refs)
        return model

    cls_name = config_params.pop('class_name', None)
    if not cls_name:
        e = ConfigError('Component config has no `class_name` nor `ref` fields')
        log.exception(e)
        raise e
    cls = get_model(cls_name)

    # find the submodels params recursively
    config_params = {k: _init_param(v, mode) for k, v in config_params.items()}

    try:
        spec = inspect.getfullargspec(cls)
        if 'mode' in spec.args+spec.kwonlyargs or spec.varkw is not None:
            kwargs['mode'] = mode

        component = cls(**dict(config_params, **kwargs))
        try:
            _refs[config_params['id']] = component
        except KeyError:
            pass
    except Exception:
        log.exception("Exception in {}".format(cls))
        raise

    if serialized is not None:
        component.deserialize(serialized)
    return component
Example #5
    def __init__(self,
                 actions2slots_path: Union[str, Path],
                 api_call_action: str,
                 data_path: Union[str, Path],
                 dataset_reader_class="dstc2_reader",
                 debug=False):
        self.debug = debug

        if self.debug:
            log.debug(f"BEFORE {self.__class__.__name__} init(): "
                      f"actions2slots_path={actions2slots_path}, "
                      f"api_call_action={api_call_action}, debug={debug}")

        self._dataset_reader = get_model(dataset_reader_class)

        individual_actions2slots = self._load_actions2slots_mapping(
            actions2slots_path)
        possible_actions_combinations_tuples = sorted(
            set(actions_combination_tuple for actions_combination_tuple in
                self._extract_actions_combinations(data_path)),
            key=lambda x: '+'.join(x))

        self.action_tuples2ids = {
            action_tuple: action_tuple_idx
            for action_tuple_idx, action_tuple in enumerate(
                possible_actions_combinations_tuples)
        }  # todo: typehint tuples somehow
        self.ids2action_tuples = {
            v: k
            for k, v in self.action_tuples2ids.items()
        }

        self.action_tuples_ids2slots = {}  # todo: typehint tuples somehow
        for actions_combination_tuple in possible_actions_combinations_tuples:
            actions_combination_slots = set(
                slot for action in actions_combination_tuple
                for slot in individual_actions2slots.get(action, []))
            actions_combination_tuple_id = self.action_tuples2ids[
                actions_combination_tuple]
            self.action_tuples_ids2slots[
                actions_combination_tuple_id] = actions_combination_slots

        self._api_call_id = -1
        if api_call_action is not None:
            api_call_action_as_tuple = (api_call_action, )
            self._api_call_id = self.action_tuples2ids[
                api_call_action_as_tuple]

        if self.debug:
            log.debug(f"AFTER {self.__class__.__name__} init(): "
                      f"actions2slots_path={actions2slots_path}, "
                      f"api_call_action={api_call_action}, debug={debug}")
Example #6
def predict_with_model(config_path: Union[Path, str]) -> List[Optional[List[str]]]:
    """Returns predictions of the morphotagging model defined by the config at ``config_path``.

    Args:
        config_path: a path to config

    Returns:
        a list of morphological analyses for each sentence. Each analysis is either a list of tags
        or a list of full CONLL-U descriptions.

    """
    config = parse_config(config_path)

    reader_config = config['dataset_reader']
    reader = get_model(reader_config['class_name'])()
    data_path = expand_path(reader_config.get('data_path', ''))
    read_params = {
        k: v
        for k, v in reader_config.items()
        if k not in ['class_name', 'data_path']
    }
    data: Dict = reader.read(data_path, **read_params)

    iterator_config = config['dataset_iterator']
    iterator: MorphoTaggerDatasetIterator = from_params(iterator_config,
                                                        data=data)

    model = build_model(config, load_trained=True)
    answers = [None] * len(iterator.test)
    batch_size = config['predict'].get("batch_size", -1)
    for indexes, (x, _) in iterator.gen_batches(batch_size=batch_size,
                                                data_type="test",
                                                shuffle=False,
                                                return_indexes=True):
        y = model(x)
        for i, elem in zip(indexes, y):
            answers[i] = elem
    outfile = config['predict'].get("outfile")
    if outfile is not None:
        outfile = Path(outfile)
        if not outfile.exists():
            outfile.parent.mkdir(parents=True, exist_ok=True)
        with open(outfile, "w", encoding="utf8") as fout:
            for elem in answers:
                fout.write(elem + "\n")
    return answers
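The function reads two keys from a top-level 'predict' block of the config; a minimal illustrative shape (the values here are assumptions):

config['predict'] = {
    'batch_size': 32,                      # -1 is used above when the key is absent
    'outfile': 'results/predictions.txt',  # omit to skip writing predictions to disk
}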
Example #7
def read_data_by_config(config: dict):
    """Read data by dataset_reader from specified config."""
    dataset_config = config.get('dataset', None)

    if dataset_config:
        config.pop('dataset')
        ds_type = dataset_config['type']
        if ds_type == 'classification':
            reader = {'name': 'basic_classification_reader'}
            iterator = {'name': 'basic_classification_iterator'}
            config['dataset_reader'] = {**dataset_config, **reader}
            config['dataset_iterator'] = {**dataset_config, **iterator}
        else:
            raise Exception("Unsupported dataset type: {}".format(ds_type))

    data = []
    reader_config = config.get('dataset_reader', None)

    if reader_config:
        reader_config = config['dataset_reader']
        if 'class' in reader_config:
            c = reader_config.pop('class')
            try:
                module_name, cls_name = c.split(':')
                reader = getattr(importlib.import_module(module_name),
                                 cls_name)()
            except ValueError:
                e = ConfigError(
                    'Expected class description in a `module.submodules:ClassName` form, but got `{}`'
                    .format(c))
                log.exception(e)
                raise e
        else:
            reader = get_model(reader_config.pop('name'))()
        data_path = reader_config.pop('data_path', '')
        if isinstance(data_path, list):
            data_path = [expand_path(x) for x in data_path]
        else:
            data_path = expand_path(data_path)
        data = reader.read(data_path, **reader_config)
    else:
        log.warning("No dataset reader is provided in the JSON config.")

    return data
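The 'dataset' shorthand handled at the top is rewritten in place into a reader/iterator pair; e.g. (illustrative values):

config = {'dataset': {'type': 'classification', 'data_path': 'my_data/'}}
data = read_data_by_config(config)
# config now contains 'dataset_reader' and 'dataset_iterator' entries instead of 'dataset'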
Example #8
def predict_with_model(config_path: Union[Path, str]) -> List[Optional[List[str]]]:
    """Returns predictions of the morphotagging model defined by the config at ``config_path``.

    Args:
        config_path: a path to config

    Returns:
        a list of morphological analyses for each sentence. Each analysis is either a list of tags
        or a list of full CONLL-U descriptions.

    """
    config = parse_config(config_path)

    reader_config = config['dataset_reader']
    reader = get_model(reader_config['class_name'])()
    data_path = expand_path(reader_config.get('data_path', ''))
    read_params = {k: v for k, v in reader_config.items() if k not in ['class_name', 'data_path']}
    data: Dict = reader.read(data_path, **read_params)

    iterator_config = config['dataset_iterator']
    iterator: MorphoTaggerDatasetIterator = from_params(iterator_config, data=data)

    model = build_model(config, load_trained=True)
    answers = [None] * len(iterator.test)
    batch_size = config['predict'].get("batch_size", -1)
    for indexes, (x, _) in iterator.gen_batches(
            batch_size=batch_size, data_type="test", shuffle=False, return_indexes=True):
        y = model(x)
        for i, elem in zip(indexes, y):
            answers[i] = elem
    outfile = config['predict'].get("outfile")
    if outfile is not None:
        outfile = Path(outfile)
        if not outfile.exists():
            outfile.parent.mkdir(parents=True, exist_ok=True)
        with open(outfile, "w", encoding="utf8") as fout:
            for elem in answers:
                fout.write(elem + "\n")
    return answers
Example #9
def check_model_registry(model_name):
    try:
        get_model(model_name)
    except BaseException:
        return False
    return True
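A usage sketch for this helper; the first name reuses a registered reader from the examples above, the second is deliberately bogus:

check_model_registry('basic_classification_reader')  # True if registered
check_model_registry('no_such_model')                # False: get_model raises, which is caught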
Example #10
def train_evaluate_model_from_config(config: Union[str, Path, dict],
                                     to_train=True,
                                     to_validate=True) -> None:
    if isinstance(config, (str, Path)):
        config = read_json(config)
    set_deeppavlov_root(config)

    import_packages(config.get('metadata', {}).get('imports', []))

    dataset_config = config.get('dataset', None)

    if dataset_config:
        config.pop('dataset')
        ds_type = dataset_config['type']
        if ds_type == 'classification':
            reader = {'name': 'basic_classification_reader'}
            iterator = {'name': 'basic_classification_iterator'}
            config['dataset_reader'] = {**dataset_config, **reader}
            config['dataset_iterator'] = {**dataset_config, **iterator}
        else:
            raise Exception("Unsupported dataset type: {}".format(ds_type))

    data = []
    reader_config = config.get('dataset_reader', None)

    if reader_config:
        reader_config = config['dataset_reader']
        if 'class' in reader_config:
            c = reader_config.pop('class')
            try:
                module_name, cls_name = c.split(':')
                reader = getattr(importlib.import_module(module_name),
                                 cls_name)()
            except ValueError:
                e = ConfigError(
                    'Expected class description in a `module.submodules:ClassName` form, but got `{}`'
                    .format(c))
                log.exception(e)
                raise e
        else:
            reader = get_model(reader_config.pop('name'))()
        data_path = expand_path(reader_config.pop('data_path', ''))
        data = reader.read(data_path, **reader_config)
    else:
        log.warning("No dataset reader is provided in the JSON config.")

    iterator_config = config['dataset_iterator']
    iterator: Union[DataLearningIterator,
                    DataFittingIterator] = from_params(iterator_config,
                                                       data=data)

    train_config = {
        'metrics': ['accuracy'],
        'validate_best': to_validate,
        'test_best': True
    }

    try:
        train_config.update(config['train'])
    except KeyError:
        log.warning('Train config is missing. Populating with default values')

    metrics_functions = list(
        zip(train_config['metrics'],
            get_metrics_by_names(train_config['metrics'])))

    if to_train:
        model = fit_chainer(config, iterator)

        if callable(getattr(model, 'train_on_batch', None)):
            _train_batches(model, iterator, train_config, metrics_functions)
        elif callable(getattr(model, 'fit_batches', None)):
            _fit_batches(model, iterator, train_config)
        elif callable(getattr(model, 'fit', None)):
            _fit(model, iterator, train_config)
        elif not isinstance(model, Chainer):
            log.warning('Nothing to train')

    if train_config['validate_best'] or train_config['test_best']:
        # try:
        #     model_config['load_path'] = model_config['save_path']
        # except KeyError:
        #     log.warning('No "save_path" parameter for the model, so "load_path" will not be renewed')
        model = build_model_from_config(config, load_trained=True)
        log.info('Testing the best saved model')

        if train_config['validate_best']:
            report = {
                'valid':
                _test_model(model, metrics_functions, iterator,
                            train_config.get('batch_size', -1), 'valid')
            }

            print(json.dumps(report, ensure_ascii=False))

        if train_config['test_best']:
            report = {
                'test':
                _test_model(model, metrics_functions, iterator,
                            train_config.get('batch_size', -1), 'test')
            }

            print(json.dumps(report, ensure_ascii=False))
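A hedged invocation sketch for the function above (the config path is hypothetical):

train_evaluate_model_from_config('path/to/config.json', to_train=True, to_validate=True)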
Example #11
def train_evaluate_model_from_config(
        config: Union[str, Path, dict],
        iterator: Union[DataLearningIterator, DataFittingIterator] = None,
        *,
        to_train: bool = True,
        evaluation_targets: Optional[Iterable[str]] = None,
        to_validate: Optional[bool] = None,
        download: bool = False,
        start_epoch_num: Optional[int] = None,
        recursive: bool = False) -> Dict[str, Dict[str, float]]:
    """Make training and evaluation of the model described in corresponding configuration file."""
    config = parse_config(config)

    if download:
        deep_download(config)

    if to_train and recursive:
        for subconfig in get_all_elems_from_json(config['chainer'],
                                                 'config_path'):
            log.info(f'Training "{subconfig}"')
            train_evaluate_model_from_config(subconfig,
                                             download=False,
                                             recursive=True)

    import_packages(config.get('metadata', {}).get('imports', []))

    if iterator is None:
        try:
            data = read_data_by_config(config)
        except ConfigError as e:
            to_train = False
            log.warning(f'Skipping training. {e.message}')
        else:
            iterator = get_iterator_from_config(config, data)

    if 'train' not in config:
        log.warning('Train config is missing. Populating with default values')
    train_config = config.get('train')

    if start_epoch_num is not None:
        train_config['start_epoch_num'] = start_epoch_num

    if 'evaluation_targets' not in train_config and (
            'validate_best' in train_config or 'test_best' in train_config):
        log.warning(
            '"validate_best" and "test_best" parameters are deprecated.'
            ' Please, use "evaluation_targets" list instead')

        train_config['evaluation_targets'] = []
        if train_config.pop('validate_best', True):
            train_config['evaluation_targets'].append('valid')
        if train_config.pop('test_best', True):
            train_config['evaluation_targets'].append('test')

    trainer_class = get_model(train_config.pop('class_name', 'nn_trainer'))
    trainer = trainer_class(config['chainer'], **train_config)

    if to_train:
        trainer.train(iterator)

    res = {}

    if iterator is not None:
        if to_validate is not None:
            if evaluation_targets is None:
                log.warning(
                    '"to_validate" parameter is deprecated and will be removed in future versions.'
                    ' Please, use "evaluation_targets" list instead')
                evaluation_targets = ['test']
                if to_validate:
                    evaluation_targets.append('valid')
            else:
                log.warning(
                    'Both "evaluation_targets" and "to_validate" parameters are specified.'
                    ' "to_validate" is deprecated and will be ignored')

        res = trainer.evaluate(iterator,
                               evaluation_targets,
                               print_reports=True)
        trainer.get_chainer().destroy()

    res = {k: v['metrics'] for k, v in res.items()}

    return res
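The deprecation shim above maps the two old boolean flags onto the new list; under it, the two 'train' sections sketched below are equivalent:

old_style = {'validate_best': True, 'test_best': True}
new_style = {'evaluation_targets': ['valid', 'test']}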
Example #12
def train_evaluate_model_from_config(config: Union[str, Path, dict], to_train: bool = True, to_validate: bool = True) -> None:
    """Train and evaluate the model described in the corresponding configuration file."""
    if isinstance(config, (str, Path)):
        config = read_json(config)
    set_deeppavlov_root(config)

    import_packages(config.get('metadata', {}).get('imports', []))

    dataset_config = config.get('dataset', None)

    if dataset_config:
        config.pop('dataset')
        ds_type = dataset_config['type']
        if ds_type == 'classification':
            reader = {'name': 'basic_classification_reader'}
            iterator = {'name': 'basic_classification_iterator'}
            config['dataset_reader'] = {**dataset_config, **reader}
            config['dataset_iterator'] = {**dataset_config, **iterator}
        else:
            raise Exception("Unsupported dataset type: {}".format(ds_type))

    data = []
    reader_config = config.get('dataset_reader', None)

    if reader_config:
        reader_config = config['dataset_reader']
        if 'class' in reader_config:
            c = reader_config.pop('class')
            try:
                module_name, cls_name = c.split(':')
                reader = getattr(importlib.import_module(module_name), cls_name)()
            except ValueError:
                e = ConfigError('Expected class description in a `module.submodules:ClassName` form, but got `{}`'
                                .format(c))
                log.exception(e)
                raise e
        else:
            reader = get_model(reader_config.pop('name'))()
        data_path = reader_config.pop('data_path', '')
        if isinstance(data_path, list):
            data_path = [expand_path(x) for x in data_path]
        else:
            data_path = expand_path(data_path)
        data = reader.read(data_path, **reader_config)
    else:
        log.warning("No dataset reader is provided in the JSON config.")

    iterator_config = config['dataset_iterator']
    iterator: Union[DataLearningIterator, DataFittingIterator] = from_params(iterator_config,
                                                                             data=data)

    train_config = {
        'metrics': ['accuracy'],
        'validate_best': to_validate,
        'test_best': True,
        'show_examples': False
    }

    try:
        train_config.update(config['train'])
    except KeyError:
        log.warning('Train config is missing. Populating with default values')

    metrics_functions = list(zip(train_config['metrics'], get_metrics_by_names(train_config['metrics'])))

    if to_train:
        model = fit_chainer(config, iterator)

        if callable(getattr(model, 'train_on_batch', None)):
            _train_batches(model, iterator, train_config, metrics_functions)
        elif callable(getattr(model, 'fit_batches', None)):
            _fit_batches(model, iterator, train_config)
        elif callable(getattr(model, 'fit', None)):
            _fit(model, iterator, train_config)
        elif not isinstance(model, Chainer):
            log.warning('Nothing to train')

    if train_config['validate_best'] or train_config['test_best']:
        # try:
        #     model_config['load_path'] = model_config['save_path']
        # except KeyError:
        #     log.warning('No "save_path" parameter for the model, so "load_path" will not be renewed')
        model = build_model_from_config(config, load_trained=True)
        log.info('Testing the best saved model')

        if train_config['validate_best']:
            report = {
                'valid': _test_model(model, metrics_functions, iterator,
                                     train_config.get('batch_size', -1), 'valid',
                                     show_examples=train_config['show_examples'])
            }

            print(json.dumps(report, ensure_ascii=False))

        if train_config['test_best']:
            report = {
                'test': _test_model(model, metrics_functions, iterator,
                                    train_config.get('batch_size', -1), 'test',
                                    show_examples=train_config['show_examples'])
            }

            print(json.dumps(report, ensure_ascii=False))
Example #13
def from_params(params: Dict,
                mode: str = 'infer',
                serialized: Any = None,
                **kwargs) -> Union[Component, FunctionType]:
    """Builds and returns the Component from corresponding dictionary of parameters."""
    # what is passed in json:
    config_params = {k: _resolve(v) for k, v in params.items()}

    # get component by reference (if any)
    if 'ref' in config_params:
        try:
            component = _refs[config_params['ref']]
            if serialized is not None:
                component.deserialize(serialized)
            return component
        except KeyError:
            e = ConfigError(
                'Component with id "{id}" was referenced but not initialized'.
                format(id=config_params['ref']))
            log.exception(e)
            raise e

    elif 'config_path' in config_params:
        from deeppavlov.core.commands.infer import build_model
        refs = _refs.copy()
        _refs.clear()
        config = parse_config(expand_path(config_params['config_path']))
        model = build_model(config, serialized=serialized)
        _refs.clear()
        _refs.update(refs)
        try:
            _refs[config_params['id']] = model
        except KeyError:
            pass
        return model

    cls_name = config_params.pop('class_name', None)
    if not cls_name:
        e = ConfigError(
            'Component config has no `class_name` nor `ref` fields')
        log.exception(e)
        raise e
    obj = get_model(cls_name)

    if inspect.isclass(obj):
        # find the submodels params recursively
        config_params = {
            k: _init_param(v, mode)
            for k, v in config_params.items()
        }

        try:
            spec = inspect.getfullargspec(obj)
            if 'mode' in spec.args + spec.kwonlyargs or spec.varkw is not None:
                kwargs['mode'] = mode

            component = obj(**dict(config_params, **kwargs))
            try:
                _refs[config_params['id']] = component
            except KeyError:
                pass
        except Exception:
            log.exception("Exception in {}".format(obj))
            raise

        if serialized is not None:
            component.deserialize(serialized)
    else:
        component = obj

    return component
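Unlike the earlier variants, this one also handles registered plain functions: when get_model returns something that is not a class, it is returned untouched rather than instantiated. A sketch under that assumption:

def lowercase(batch):
    # hypothetical function registered under the name 'lowercase'
    return [s.lower() for s in batch]

# if 'lowercase' were registered, from_params({'class_name': 'lowercase'})
# would return the lowercase function itself, not an instance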
Example #14
def train_model_from_config(config_path: str):
    config = read_json(config_path)
    set_deeppavlov_root(config)

    reader_config = config['dataset_reader']
    reader = get_model(reader_config['name'])()
    data_path = expand_path(reader_config.get('data_path', ''))
    data = reader.read(data_path)

    dataset_config = config['dataset']
    dataset: Dataset = from_params(dataset_config, data=data)

    if 'chainer' in config:
        model = fit_chainer(config, dataset)
    else:
        vocabs = {}
        for vocab_param_name, vocab_config in config.get('vocabs', {}).items():
            v: Estimator = from_params(vocab_config, mode='train')
            vocabs[vocab_param_name] = _fit(v, dataset)

        model_config = config['model']
        model = from_params(model_config, vocabs=vocabs, mode='train')

    train_config = {
        'metrics': ['accuracy'],
        'validate_best': True,
        'test_best': True
    }

    try:
        train_config.update(config['train'])
    except KeyError:
        log.warning('Train config is missing. Populating with default values')

    metrics_functions = list(
        zip(train_config['metrics'],
            get_metrics_by_names(train_config['metrics'])))

    if callable(getattr(model, 'train_on_batch', None)):
        _train_batches(model, dataset, train_config, metrics_functions)
    elif callable(getattr(model, 'fit', None)):
        _fit(model, dataset, train_config)
    elif not isinstance(model, Chainer):
        log.warning('Nothing to train')

    if train_config['validate_best'] or train_config['test_best']:
        # try:
        #     model_config['load_path'] = model_config['save_path']
        # except KeyError:
        #     log.warning('No "save_path" parameter for the model, so "load_path" will not be renewed')
        model = build_model_from_config(config, load_trained=True)
        log.info('Testing the best saved model')

        if train_config['validate_best']:
            report = {
                'valid':
                _test_model(model, metrics_functions, dataset,
                            train_config.get('batch_size', -1), 'valid')
            }

            print(json.dumps(report, ensure_ascii=False))

        if train_config['test_best']:
            report = {
                'test':
                _test_model(model, metrics_functions, dataset,
                            train_config.get('batch_size', -1), 'test')
            }

            print(json.dumps(report, ensure_ascii=False))
Example #15
def train_model_from_config(config_path: str) -> None:
    config = read_json(config_path)
    set_deeppavlov_root(config)

    dataset_config = config.get('dataset', None)

    if dataset_config:
        config.pop('dataset')
        ds_type = dataset_config['type']
        if ds_type == 'classification':
            reader = {'name': 'basic_classification_reader'}
            iterator = {'name': 'basic_classification_iterator'}
            config['dataset_reader'] = {**dataset_config, **reader}
            config['dataset_iterator'] = {**dataset_config, **iterator}
        else:
            raise Exception("Unsupported dataset type: {}".format(ds_type))

    reader_config = config['dataset_reader']
    reader = get_model(reader_config['name'])()
    data_path = expand_path(reader_config.get('data_path', ''))
    kwargs = {
        k: v
        for k, v in reader_config.items() if k not in ['name', 'data_path']
    }
    data = reader.read(data_path, **kwargs)

    iterator_config = config['dataset_iterator']
    iterator: BasicDatasetIterator = from_params(iterator_config, data=data)

    if 'chainer' in config:
        model = fit_chainer(config, iterator)
    else:
        vocabs = config.get('vocabs', {})
        for vocab_param_name, vocab_config in vocabs.items():
            v: Estimator = from_params(vocab_config, mode='train')
            vocabs[vocab_param_name] = _fit(v, iterator)

        model_config = config['model']
        model = from_params(model_config, vocabs=vocabs, mode='train')

    train_config = {
        'metrics': ['accuracy'],
        'validate_best': True,
        'test_best': True
    }

    try:
        train_config.update(config['train'])
    except KeyError:
        log.warning('Train config is missing. Populating with default values')

    metrics_functions = list(
        zip(train_config['metrics'],
            get_metrics_by_names(train_config['metrics'])))

    if callable(getattr(model, 'train_on_batch', None)):
        _train_batches(model, iterator, train_config, metrics_functions)
    elif callable(getattr(model, 'fit', None)):
        _fit(model, iterator, train_config)
    elif not isinstance(model, Chainer):
        log.warning('Nothing to train')

    if train_config['validate_best'] or train_config['test_best']:
        # try:
        #     model_config['load_path'] = model_config['save_path']
        # except KeyError:
        #     log.warning('No "save_path" parameter for the model, so "load_path" will not be renewed')
        model = build_model_from_config(config, load_trained=True)
        log.info('Testing the best saved model')

        if train_config['validate_best']:
            report = {
                'valid':
                _test_model(model, metrics_functions, iterator,
                            train_config.get('batch_size', -1), 'valid')
            }

            print(json.dumps(report, ensure_ascii=False))

        if train_config['test_best']:
            report = {
                'test':
                _test_model(model, metrics_functions, iterator,
                            train_config.get('batch_size', -1), 'test')
            }

            print(json.dumps(report, ensure_ascii=False))