示例#1
0
def download_model_and_recipe(root_dir: str):
    """
    Download a pretrained ResNet-20 (v1) Keras model for CIFAR-10 from the
    SparseZoo and pair it with a pruning recipe stub.

    :param root_dir: directory under which a ``resnet20_v1`` sub-directory is
        passed to the zoo as the parent path for the downloaded model files
    :return: a tuple ``(model_file_path, recipe_file_path)`` where the first
        element is the local path of the downloaded ``.h5`` model file and the
        second is a SparseZoo recipe stub string (not a local file path)
    :raises RuntimeError: if the downloaded model exposes no framework files,
        if the first framework file does not exist on disk, or if it is not
        an ``.h5`` file
    """
    model_dir = os.path.join(root_dir, "resnet20_v1")

    # Load the base (unoptimized: optim_name="base", optim_category="none")
    # model that the recipe will prune
    base_zoo_model = Zoo.load_model(
        domain="cv",
        sub_domain="classification",
        architecture="resnet_v1",
        sub_architecture=20,
        framework="keras",
        repo="sparseml",
        dataset="cifar_10",
        training_scheme=None,
        optim_name="base",
        optim_category="none",
        optim_target=None,
        override_parent_path=model_dir,
    )
    base_zoo_model.download()

    framework_files = base_zoo_model.framework_files
    if not framework_files:
        # Without this guard, indexing below would surface as an IndexError
        # rather than the RuntimeError this function raises for other failures
        raise RuntimeError(
            "No framework files found for model downloaded under: {}".format(
                model_dir
            )
        )

    model_file_path = framework_files[0].downloaded_path()
    if not os.path.exists(model_file_path):
        raise RuntimeError("Model file not found: {}".format(model_file_path))
    if not model_file_path.endswith(".h5"):
        # The file exists but is not a Keras .h5 model; report that precisely
        # instead of claiming it was not found
        raise RuntimeError(
            "Expected a Keras .h5 model file, got: {}".format(model_file_path)
        )

    # The recipe is referenced by its zoo stub rather than downloaded here
    recipe_file_path = (
        "zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative"
    )

    return model_file_path, recipe_file_path
示例#2
0
def main():
    """
    Command-line entry point.

    Parses the CLI arguments, configures INFO-level logging, then dispatches
    on ``args.command``:

    * ``DOWNLOAD_COMMAND``: loads the model described by the CLI options via
      ``Zoo.load_model`` (passing ``args.save_dir`` as
      ``override_parent_path``), downloads it, and prints a short summary
      with the model stub and its local directory.
    * ``SEARCH_COMMAND``: delegates to ``search(args)``.

    Any other command value falls through and does nothing.
    """
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    # Guard clause: everything except the download command is handled here
    if args.command != DOWNLOAD_COMMAND:
        if args.command == SEARCH_COMMAND:
            search(args)
        return

    LOGGER.info("Downloading files from model...")

    # CLI options whose names match Zoo.load_model keyword arguments exactly
    passthrough_options = (
        "domain",
        "sub_domain",
        "architecture",
        "sub_architecture",
        "framework",
        "repo",
        "dataset",
        "training_scheme",
        "optim_name",
        "optim_category",
        "optim_target",
        "release_version",
    )
    load_kwargs = {option: getattr(args, option) for option in passthrough_options}
    # save_dir is the only option whose name differs from the keyword it feeds
    load_kwargs["override_parent_path"] = args.save_dir

    model = Zoo.load_model(**load_kwargs)
    model.download()

    for header_line in ("Download results", "====================", ""):
        print(header_line)
    print(model.stub)
    print(f"downloaded to {model.dir_path}")