コード例 #1
0
    def __call__(self, template_path, kwargs_nncf):
        """Build the NNCF config part for the optimisation flags set in ``kwargs_nncf``.

        Args:
            template_path: path to the model template file.
            kwargs_nncf: mapping whose keys are checked against
                ``POSSIBLE_NNCF_PARTS``; every part with a truthy value is selected.

        Returns:
            ``{}`` if the template does not enable optimisation (a warning is
            logged) or if no part is selected in ``kwargs_nncf``; otherwise the
            result of ``self._merge_nncf_optimisation_parts`` for the template's
            optimisation config file and the selected parts.
        """
        if not is_optimisation_enabled_in_template(template_path):
            logging.warning(
                'WARNING: optimisation class is called for a template that does not enable optimisation.'
                ' This must not be happened in OTE.')
            return {}
        template = load_config(template_path)
        optimisation_config_path = get_optimisation_config_from_template(
            template)

        # NOTE(review): the selection order follows the iteration order of
        # POSSIBLE_NNCF_PARTS; if that is a set and the merge is order-sensitive,
        # confirm the merge helper imposes its own ordering.
        optimisation_parts_to_choose = [
            part for part in POSSIBLE_NNCF_PARTS if kwargs_nncf.get(part)
        ]
        if not optimisation_parts_to_choose:
            return {}

        return self._merge_nncf_optimisation_parts(
            optimisation_config_path, optimisation_parts_to_choose)
コード例 #2
0
def main():
    """Entry point of the compression tool.

    Parses the compression command line for MODEL_TEMPLATE_FILENAME, converts
    the arguments for the trainer, applies the compression arg transformer and
    runs the trainer with the result. Returns early (with a warning) when the
    transformer reports that optimisation is not enabled.

    Raises:
        RuntimeError: if the modules file has no "compression" entry or the
            template does not enable optimisation.
    """
    logging.basicConfig(level=logging.INFO)
    modules = load_config(MODULES_CONFIG_FILENAME)

    arg_parser = build_arg_parser(modules['arg_parser'])
    ote_args = vars(arg_parser.get_compression_parser(MODEL_TEMPLATE_FILENAME).parse_args())

    if 'compression' not in modules:
        raise RuntimeError('Cannot make compression for the template that'
                           ' does not have "compression" field in its modules'
                           f' file {MODULES_CONFIG_FILENAME}')
    if not is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        raise RuntimeError('Cannot make compression for the template that'
                           ' does not enable any of compression flags')

    arg_converter = build_arg_converter(modules['arg_converter_map'])
    compress_args = arg_converter.convert_compress_args(ote_args)

    compression_arg_transformer = build_compression_arg_transformer(modules['compression'])
    compress_args, is_optimisation_enabled = (
        compression_arg_transformer.process_args(MODEL_TEMPLATE_FILENAME, compress_args))

    if not is_optimisation_enabled:
        logging.warning('Optimization flags are not set -- compression is not made')
        return

    # Compression is performed by the same trainer as in train.py; only the
    # argparser and the NNCF config transformation of the arguments differ.
    trainer = build_trainer(modules['trainer'])
    trainer(**compress_args)
コード例 #3
0
def is_optimisation_enabled_in_template(template):
    """Return whether a model template has a valid, non-empty optimisation section.

    Args:
        template: either a template file path (``str``, loaded with
            ``load_config``) or an already loaded template dict.

    Returns:
        False if the ``OPTIMISATION_PART_NAME`` section is missing or empty;
        True if it is present and passes validation.

    Raises:
        RuntimeError: if the section is not a dict, has keys outside
            ``POSSIBLE_NNCF_PARTS``, yields no config via
            ``_get_optimisation_configs_from_template``, or yields more than one.
    """
    if isinstance(template, str):
        template = load_config(template)
    optimisation_template = template.get(OPTIMISATION_PART_NAME)
    if not optimisation_template:
        return False
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let a malformed section reach ``.keys()``.
    if not isinstance(optimisation_template, dict):
        raise RuntimeError(
            f'Error: optimisation part of template is not a dict: template["optimisation"]={optimisation_template}'
        )
    unknown_keys = set(optimisation_template.keys()) - POSSIBLE_NNCF_PARTS
    if unknown_keys:
        raise RuntimeError(
            f'Optimisation parameters contain unknown keys: {list(unknown_keys)}'
        )
    optimisation_configs = _get_optimisation_configs_from_template(template)
    if not optimisation_configs:
        raise RuntimeError(
            f'Optimisation parameters do not contain the field "{COMPRESSION_CONFIG_KEY}"'
        )
    if len(optimisation_configs) > 1:
        raise RuntimeError(
            f'Wrong config: the optimisation config contains different config files: {optimisation_configs}'
        )
    return True
コード例 #4
0
ファイル: train.py プロジェクト: zk886/training_extensions
def main():
    """Train the model and, if the template enables it, compress it afterwards.

    Training runs with the converted train arguments. When the modules file
    has a truthy "compression" entry and the template enables compression,
    the latest snapshot is then loaded and trained again with compression args.

    Raises:
        RuntimeError: if compression is requested but no snapshot was produced.
    """
    logging.basicConfig(level=logging.INFO)
    modules = load_config(MODULES_CONFIG_FILENAME)

    train_parser = build_arg_parser(modules['arg_parser']).get_train_parser(MODEL_TEMPLATE_FILENAME)
    ote_args = vars(train_parser.parse_args())

    converter = build_arg_converter(modules['arg_converter'])
    train_args = converter.convert_train_args(MODEL_TEMPLATE_FILENAME, ote_args)

    # The compression args transformer is deliberately not applied to the
    # plain training run: compression (if enabled) happens after it finishes.
    trainer = build_trainer(modules['trainer'])
    trainer(**train_args)

    # TODO: think on the case if compression is enabled in template.yaml, but modules does not contain 'compression'
    compression_module = modules.get('compression')
    if not compression_module or not is_compression_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        return

    snapshot = trainer.get_latest_snapshot()
    if not snapshot:
        raise RuntimeError('Cannot find latest snapshot to make compression after training')

    compress_args = converter.convert_train_args_to_compress_args(MODEL_TEMPLATE_FILENAME, ote_args)
    converter.update_converted_args_to_load_from_snapshot(compress_args, snapshot)

    transformer = build_compression_arg_transformer(compression_module)
    compress_args = transformer.process_args(MODEL_TEMPLATE_FILENAME, compress_args)

    compress_trainer = build_trainer(modules['trainer'])
    compress_trainer(**compress_args)
コード例 #5
0
def is_compression_enabled_in_template(template_path):
    """Return whether the template's "compression" section enables any NNCF part.

    Args:
        template_path: path to a template file (loaded with ``load_config``).
            An already loaded template dict is also accepted, mirroring
            ``is_optimisation_enabled_in_template``.

    Returns:
        False if the "compression" section is missing or empty; otherwise True
        iff at least one key from ``POSSIBLE_NNCF_PARTS`` has a truthy value.

    Raises:
        RuntimeError: if the section is not a dict, contains unknown keys, or
            lacks ``COMPRESSION_CONFIG_KEY``.
    """
    if isinstance(template_path, dict):
        template = template_path
    else:
        template = load_config(template_path)
    compression_template = template.get('compression')
    if not compression_template:
        return False
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let a malformed section reach ``.keys()``.
    if not isinstance(compression_template, dict):
        raise RuntimeError(
            f'Error: compression part of template is not a dict: template["compression"]={compression_template}')
    possible_keys = POSSIBLE_NNCF_PARTS | {COMPRESSION_CONFIG_KEY}
    unknown_keys = set(compression_template.keys()) - possible_keys
    if unknown_keys:
        raise RuntimeError(f'Compression parameters contain unknown keys: {list(unknown_keys)}')
    if COMPRESSION_CONFIG_KEY not in compression_template:
        raise RuntimeError(f'Compression parameters do not contain the field "{COMPRESSION_CONFIG_KEY}"')
    return any(compression_template.get(key) for key in POSSIBLE_NNCF_PARTS)
コード例 #6
0
def main():
    """Evaluate a model, applying optimisation args when the template enables them."""
    modules = load_config(MODULES_CONFIG_FILENAME)

    test_parser = build_arg_parser(modules['arg_parser']).get_test_parser(MODEL_TEMPLATE_FILENAME)
    ote_args = vars(test_parser.parse_args())

    converter = build_arg_converter(modules['arg_converter_map'])
    eval_args = converter.convert_test_args(ote_args)

    # Only transform the arguments when the modules file has a truthy
    # "compression" entry AND the template enables optimisation.
    compression_module = modules.get('compression')
    if compression_module and is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        transformer = build_compression_arg_transformer(compression_module)
        eval_args, _is_enabled = transformer.process_args(MODEL_TEMPLATE_FILENAME, eval_args)

    run_evaluation = build_evaluator(modules['evaluator'])
    run_evaluation(**eval_args)
コード例 #7
0
def main():
    """Parse training arguments for the model template and run the trainer."""
    logging.basicConfig(level=logging.INFO)
    modules = load_config(MODULES_CONFIG_FILENAME)

    train_parser = build_arg_parser(modules['arg_parser']).get_train_parser(MODEL_TEMPLATE_FILENAME)
    ote_args = vars(train_parser.parse_args())

    converter = build_arg_converter(modules['arg_converter_map'])
    train_args = converter.convert_train_args(ote_args)

    # The compression args transformer is deliberately not applied here;
    # NNCF compression (if enabled) is meant to run after training finishes,
    # outside of this function.
    run_training = build_trainer(modules['trainer'])
    run_training(**train_args)
コード例 #8
0
ファイル: export.py プロジェクト: zk886/training_extensions
def main():
    """Export a model, applying compression args when the template enables them."""
    logging.basicConfig(level=logging.INFO)
    modules = load_config(MODULES_CONFIG_FILENAME)

    export_parser = build_arg_parser(modules['arg_parser']).get_export_parser(MODEL_TEMPLATE_FILENAME)
    ote_args = vars(export_parser.parse_args())

    # Transform the arguments only when the modules file has a truthy
    # "compression" entry AND the template enables compression.
    compression_module = modules.get('compression')
    if compression_module and is_compression_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        transformer = build_compression_arg_transformer(compression_module)
        ote_args = transformer.process_args(MODEL_TEMPLATE_FILENAME, ote_args)

    run_export = build_exporter(modules['exporter'])
    run_export(ote_args)
コード例 #9
0
    def __call__(self, template_path):
        """Return the merged NNCF config part for the compression parts enabled in the template.

        Returns ``{}`` when no compression part in the template's "compression"
        section has a truthy value; otherwise the result of
        ``self._merge_nncf_compression_parts`` for the section's config file
        and the enabled part names.
        """
        assert is_compression_enabled_in_template(template_path), (
                'Error: compression class is called for a template that does not enable compression.'
                ' This must not be happened in OTE.')
        # Shallow copy so that popping the config key does not mutate the loaded template.
        section = copy(load_config(template_path)['compression'])
        config_path = section.pop(COMPRESSION_CONFIG_KEY)

        # Remaining keys are compression part names mapped to on/off flags.
        enabled_parts = [part for part, flag in section.items() if flag]
        if not enabled_parts:
            return {}
        return self._merge_nncf_compression_parts(config_path, enabled_parts)