def download_model_and_recipe(root_dir: str):
    """
    Fetch the base (dense) ResNet-20 Keras model for CIFAR-10 from the
    SparseZoo and pair it with a pruning recipe stub.

    :param root_dir: directory under which a ``resnet20_v1`` subdirectory is
        used as the download location for the model files
    :return: tuple ``(model_file_path, recipe_file_path)`` where the first
        element is the local path of the downloaded ``.h5`` model file and the
        second is a ``zoo:`` recipe stub string (not a local file)
    :raises RuntimeError: if the first framework file of the downloaded model
        does not exist on disk or is not a ``.h5`` file
    """
    download_dir = os.path.join(root_dir, "resnet20_v1")

    # Identify the unpruned ("base") ResNet-20 checkpoint that will be pruned.
    zoo_query = dict(
        domain="cv",
        sub_domain="classification",
        architecture="resnet_v1",
        sub_architecture=20,
        framework="keras",
        repo="sparseml",
        dataset="cifar_10",
        training_scheme=None,
        optim_name="base",
        optim_category="none",
        optim_target=None,
        override_parent_path=download_dir,
    )
    dense_model = Zoo.load_model(**zoo_query)
    dense_model.download()

    # Only the first framework file is considered; it must be a Keras .h5 file.
    h5_path = dense_model.framework_files[0].downloaded_path()
    if not (os.path.exists(h5_path) and h5_path.endswith(".h5")):
        raise RuntimeError("Model file not found: {}".format(h5_path))

    # No local recipe file is fetched; the zoo stub is handed back as-is.
    recipe_stub = (
        "zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative"
    )
    return h5_path, recipe_stub
def main():
    """
    Command-line entry point.

    Parses the CLI arguments and configures INFO-level logging. Then either
    downloads the zoo model described by the arguments and prints a short
    report, or delegates to ``search``, depending on ``args.command``. Any
    other command value is silently ignored.
    """
    cli_args = parse_args()
    logging.basicConfig(level=logging.INFO)
    command = cli_args.command

    if not command == DOWNLOAD_COMMAND:
        if command == SEARCH_COMMAND:
            search(cli_args)
        return

    LOGGER.info("Downloading files from model...")

    # Every one of these CLI attributes is forwarded to Zoo.load_model under
    # the same keyword name; only the save location is renamed.
    zoo_fields = (
        "domain",
        "sub_domain",
        "architecture",
        "sub_architecture",
        "framework",
        "repo",
        "dataset",
        "training_scheme",
        "optim_name",
        "optim_category",
        "optim_target",
        "release_version",
    )
    load_kwargs = {field: getattr(cli_args, field) for field in zoo_fields}
    load_kwargs["override_parent_path"] = cli_args.save_dir

    zoo_model = Zoo.load_model(**load_kwargs)
    zoo_model.download()

    print("Download results")
    print("====================")
    print("")
    print(zoo_model.stub)
    print(f"downloaded to {zoo_model.dir_path}")