예제 #1
0
def main():
    """Run model compression for the template at ``MODEL_TEMPLATE_FILENAME``.

    Flow:
    - Load the module configuration from ``MODULES_CONFIG_FILENAME``.
    - Parse the compression command line and convert it into trainer keyword
      arguments.
    - Let the compression argument transformer adjust those arguments.
    - If the transformer reports that no optimisation is enabled, log a
      warning and return without training.
    - Otherwise, invoke the configured trainer with the resulting arguments.

    Raises:
        RuntimeError: if the modules config has no ``compression`` entry, or if
            ``is_optimisation_enabled_in_template`` is falsy for the template.
    """
    logging.basicConfig(level=logging.INFO)
    modules = load_config(MODULES_CONFIG_FILENAME)

    arg_parser = build_arg_parser(modules['arg_parser'])
    ote_args = vars(arg_parser.get_compression_parser(MODEL_TEMPLATE_FILENAME).parse_args())

    if 'compression' not in modules:
        raise RuntimeError('Cannot make compression for the template that'
                           ' does not have "compression" field in its modules'
                           f' file {MODULES_CONFIG_FILENAME}')
    if not is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        raise RuntimeError('Cannot make compression for the template that'
                           ' does not enable any of compression flags')

    arg_converter = build_arg_converter(modules['arg_converter_map'])
    compress_args = arg_converter.convert_compress_args(ote_args)

    compression_arg_transformer = build_compression_arg_transformer(modules['compression'])
    compress_args, is_optimisation_enabled = \
            compression_arg_transformer.process_args(MODEL_TEMPLATE_FILENAME, compress_args)

    # The template-level check above has already passed. The transformer's
    # own flag is still honoured: when it is falsy, nothing is trained.
    if not is_optimisation_enabled:
        logging.warning('Optimization flags are not set -- compression is not made')
        return

    # Compression is performed by the same trainer as the train.py tool.
    # Only the argument parser and the NNCFConfigTransformer (which
    # transforms the configuration file) differ.
    trainer = build_trainer(modules['trainer'])
    trainer(**compress_args)
예제 #2
0
def main():
    """Evaluate a model built from the current template.

    Reads the module configuration from ``MODULES_CONFIG_FILENAME``, parses the
    test command line for ``MODEL_TEMPLATE_FILENAME`` and converts the parsed
    values into keyword arguments for the configured evaluator.

    The compression argument transformer rewrites those arguments first when
    both of these hold:
    - the modules config has a truthy ``compression`` entry;
    - ``is_optimisation_enabled_in_template`` is truthy for the template.
    """
    modules = load_config(MODULES_CONFIG_FILENAME)

    test_parser = build_arg_parser(modules['arg_parser']).get_test_parser(MODEL_TEMPLATE_FILENAME)
    parsed = vars(test_parser.parse_args())

    converter = build_arg_converter(modules['arg_converter_map'])
    eval_args = converter.convert_test_args(parsed)

    # Short-circuit: the template is only inspected when a compression entry exists.
    use_compression = modules.get('compression') and is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME)
    if use_compression:
        transformer = build_compression_arg_transformer(modules['compression'])
        # The second value returned by process_args is not needed here.
        eval_args, _unused = transformer.process_args(MODEL_TEMPLATE_FILENAME, eval_args)

    run_evaluation = build_evaluator(modules['evaluator'])
    run_evaluation(**eval_args)
예제 #3
0
def main():
    """Export a model built from the current template.

    Steps:
    - Configure INFO-level logging.
    - Parse the export command line for ``MODEL_TEMPLATE_FILENAME``.
    - Pass the resulting argument dict to the configured exporter. It is
      passed as a single positional argument, not as keyword arguments.

    The compression argument transformer rewrites the arguments first when
    both of these hold:
    - the modules config has a truthy ``compression`` entry;
    - ``is_optimisation_enabled_in_template`` is truthy for the template.
    """
    logging.basicConfig(level=logging.INFO)
    modules = load_config(MODULES_CONFIG_FILENAME)

    export_parser = build_arg_parser(modules['arg_parser']).get_export_parser(MODEL_TEMPLATE_FILENAME)
    ote_args = vars(export_parser.parse_args())

    if modules.get('compression') and is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        transformer = build_compression_arg_transformer(modules['compression'])
        # Only the rewritten arguments are used; the second value is discarded.
        ote_args, _unused = transformer.process_args(MODEL_TEMPLATE_FILENAME, ote_args)

    run_export = build_exporter(modules['exporter'])
    run_export(ote_args)
예제 #4
0
        def apply_update_dict_params_to_template_file(template_file, template_update_dict, compression_cfg_update_dict):
            """Merge update dicts into a template file and, optionally, its compression config.

            Steps:
            - The template is loaded and immediately dumped to
              ``template_file + '.backup.yaml'``. This happens before the
              optimisation assertion, so the backup is written even if that
              assertion fails.
            - When ``compression_cfg_update_dict`` is truthy, the compression
              config path is read from the template and resolved against the
              template's directory. That config is dumped to
              ``<path>.BACKUP_FROM_TEST.json``, merged with the update dict
              and written back in place.
            - Finally ``template_update_dict`` is merged into the template,
              which is written back to ``template_file``.

            Only the backups are written here; restoring them is presumably
            the caller's responsibility (TODO confirm).

            Args:
                template_file: path of the template; combined with ``+`` and
                    ``os.path``, so a ``str`` path is expected.
                template_update_dict: passed to ``merge_from_dict`` on the template config.
                compression_cfg_update_dict: merged into the compression config;
                    a falsy value skips the compression config entirely.

            Raises:
                AssertionError: if ``is_optimisation_enabled_in_template`` is falsy
                    for the loaded template.
            """
            template_cfg = mmcv.Config.fromfile(template_file)
            template_cfg.dump(template_file + '.backup.yaml')

            assert is_optimisation_enabled_in_template(template_cfg), \
                f'Template {template_file} does not contain optimisation part'

            if compression_cfg_update_dict:
                rel_path = get_optimisation_config_from_template(template_cfg)
                cfg_path = os.path.join(os.path.dirname(template_file), rel_path)

                compression_cfg = mmcv.Config.fromfile(cfg_path)
                # Snapshot the untouched compression config before overwriting it.
                compression_cfg.dump(cfg_path + '.BACKUP_FROM_TEST.json')
                compression_cfg.merge_from_dict(compression_cfg_update_dict)
                compression_cfg.dump(cfg_path)

            template_cfg.merge_from_dict(template_update_dict)
            template_cfg.dump(template_file)
예제 #5
0
def _is_optimisation_enabled_in_template(*args, **kwargs):
    from ote.modules.compression import is_optimisation_enabled_in_template
    return is_optimisation_enabled_in_template(*args, **kwargs)