def main():
    """Run compression for the model described by MODEL_TEMPLATE_FILENAME.

    Parses the compression command line for the template, converts the
    arguments for the trainer, lets the compression arg transformer adjust
    them, and then runs the trainer built from the modules config.

    Raises:
        RuntimeError: if the modules config has no usable ``compression``
            entry (missing, null or empty), or if the template does not
            enable any compression flags.
    """
    logging.basicConfig(level=logging.INFO)

    modules = load_config(MODULES_CONFIG_FILENAME)

    arg_parser = build_arg_parser(modules['arg_parser'])
    ote_args = vars(arg_parser.get_compression_parser(MODEL_TEMPLATE_FILENAME).parse_args())

    # Truthiness check: a present-but-empty/null ``compression`` entry is
    # rejected here with a clear message instead of being handed to
    # build_compression_arg_transformer below.
    if not modules.get('compression'):
        raise RuntimeError('Cannot make compression for the template that'
                           ' does not have "compression" field in its modules'
                           f' file {MODULES_CONFIG_FILENAME}')
    if not is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME):
        raise RuntimeError('Cannot make compression for the template that'
                           ' does not enable any of compression flags')

    arg_converter = build_arg_converter(modules['arg_converter_map'])
    compress_args = arg_converter.convert_compress_args(ote_args)

    compression_arg_transformer = build_compression_arg_transformer(modules['compression'])
    compress_args, is_optimisation_enabled = compression_arg_transformer.process_args(
        MODEL_TEMPLATE_FILENAME, compress_args)
    # The transformer reports whether any optimisation is actually requested;
    # if not, there is nothing to do.
    if not is_optimisation_enabled:
        logging.warning('Optimization flags are not set -- compression is not made')
        return

    # Note that compression in this tool will be made by the same trainer,
    # as in the tool train.py
    # The difference is only in the argparser and in the NNCFConfigTransformer used to
    # transform the configuration file.
    trainer = build_trainer(modules['trainer'])
    trainer(**compress_args)
def main():
    """Evaluate the model described by MODEL_TEMPLATE_FILENAME.

    Parses the test command line for the template, converts it into
    evaluator keyword arguments and, when the modules config has a
    ``compression`` entry and the template enables optimisation, lets the
    compression arg transformer adjust those arguments before the
    evaluator is invoked.
    """
    module_cfg = load_config(MODULES_CONFIG_FILENAME)

    test_parser = build_arg_parser(module_cfg['arg_parser']).get_test_parser(MODEL_TEMPLATE_FILENAME)
    parsed_cli = vars(test_parser.parse_args())

    converter = build_arg_converter(module_cfg['arg_converter_map'])
    evaluation_kwargs = converter.convert_test_args(parsed_cli)

    # Short-circuits: the template is only inspected when a compression
    # module is configured.
    wants_compression = (module_cfg.get('compression')
                         and is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME))
    if wants_compression:
        transformer = build_compression_arg_transformer(module_cfg['compression'])
        evaluation_kwargs, _unused_flag = transformer.process_args(
            MODEL_TEMPLATE_FILENAME, evaluation_kwargs)

    run_evaluation = build_evaluator(module_cfg['evaluator'])
    run_evaluation(**evaluation_kwargs)
def main():
    """Export the model described by MODEL_TEMPLATE_FILENAME.

    Parses the export command line for the template. When the modules
    config has a ``compression`` entry and the template enables
    optimisation, the parsed arguments are passed through the compression
    arg transformer. The resulting argument dict is handed to the exporter
    as a single positional argument.
    """
    logging.basicConfig(level=logging.INFO)

    cfg = load_config(MODULES_CONFIG_FILENAME)

    export_parser = build_arg_parser(cfg['arg_parser']).get_export_parser(MODEL_TEMPLATE_FILENAME)
    export_args = vars(export_parser.parse_args())

    compression_requested = (cfg.get('compression')
                             and is_optimisation_enabled_in_template(MODEL_TEMPLATE_FILENAME))
    if compression_requested:
        transformer = build_compression_arg_transformer(cfg['compression'])
        export_args, _unused_flag = transformer.process_args(MODEL_TEMPLATE_FILENAME, export_args)

    run_export = build_exporter(cfg['exporter'])
    run_export(export_args)
def apply_update_dict_params_to_template_file(template_file, template_update_dict,
                                              compression_cfg_update_dict):
    """Patch a model template file (and optionally its compression config) in place.

    A backup of the template is written to ``<template_file>.backup.yaml``
    before anything is changed. If ``compression_cfg_update_dict`` is
    non-empty, the compression config referenced by the template is
    resolved relative to the template's directory. It is backed up to
    ``<compression_cfg_path>.BACKUP_FROM_TEST.json``, merged with the
    update dict, and written back. Finally ``template_update_dict`` (if
    non-empty) is merged into the template and ``template_file`` is
    overwritten.

    Args:
        template_file: path to the template file to patch.
        template_update_dict: options merged into the template via
            ``mmcv.Config.merge_from_dict``; skipped when empty or None.
        compression_cfg_update_dict: options merged into the compression
            config; skipped when empty or None.

    Raises:
        AssertionError: if the template does not enable optimisation.
    """
    template_data = mmcv.Config.fromfile(template_file)
    template_data.dump(template_file + '.backup.yaml')

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is kept for callers that expect AssertionError.
    # NOTE(review): here the helper receives a loaded Config, whereas other
    # callers pass a filename -- presumably it accepts both; confirm.
    if not is_optimisation_enabled_in_template(template_data):
        raise AssertionError(f'Template {template_file} does not contain optimisation part')

    if compression_cfg_update_dict:
        compression_cfg_rel_path = get_optimisation_config_from_template(template_data)
        compression_cfg_path = os.path.join(os.path.dirname(template_file), compression_cfg_rel_path)
        backup_compression_cfg_path = compression_cfg_path + '.BACKUP_FROM_TEST.json'

        compression_cfg = mmcv.Config.fromfile(compression_cfg_path)
        compression_cfg.dump(backup_compression_cfg_path)
        compression_cfg.merge_from_dict(compression_cfg_update_dict)
        compression_cfg.dump(compression_cfg_path)

    # Same guard as for the compression update: a None/empty update is a no-op
    # (merging None would fail), but the template is still re-dumped as before.
    if template_update_dict:
        template_data.merge_from_dict(template_update_dict)
    template_data.dump(template_file)
def _is_optimisation_enabled_in_template(*args, **kwargs):
    """Forward all arguments to ote.modules.compression.is_optimisation_enabled_in_template.

    The import happens at call time rather than at module load, so the
    compression module is only loaded when this helper is actually used.
    """
    from ote.modules.compression import is_optimisation_enabled_in_template as _impl
    return _impl(*args, **kwargs)