Example 1
0
def download_model_and_recipe(root_dir: str):
    """
    Download a pretrained ResNet20-v1 Keras model from the SparseZoo and
    return its local path together with a pruning recipe stub.

    :param root_dir: directory under which the model files are stored
    :return: tuple of (path to the downloaded .h5 model file, recipe stub)
    :raises RuntimeError: if the expected .h5 model file is not found
    """
    model_dir = os.path.join(root_dir, "resnet20_v1")

    # Fetch the un-optimized base model that will be pruned
    base_zoo_model = Zoo.load_model(
        domain="cv",
        sub_domain="classification",
        architecture="resnet_v1",
        sub_architecture=20,
        framework="keras",
        repo="sparseml",
        dataset="cifar_10",
        training_scheme=None,
        optim_name="base",
        optim_category="none",
        optim_target=None,
        override_parent_path=model_dir,
    )
    base_zoo_model.download()

    model_file_path = base_zoo_model.framework_files[0].downloaded_path()
    have_h5 = os.path.exists(model_file_path) and model_file_path.endswith(".h5")
    if not have_h5:
        raise RuntimeError("Model file not found: {}".format(model_file_path))

    # The recipe is referenced by its SparseZoo stub rather than downloaded
    recipe_file_path = (
        "zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative"
    )

    return model_file_path, recipe_file_path
Example 2
0
def load_model(model: Any, **kwargs) -> ModelProto:
    """
    Resolve the given model reference into an onnx ModelProto.

    Accepts a SparseZoo stub string, a SparseZoo Model or File object,
    a local file path, or an already-loaded ModelProto.

    :param model: the model reference to resolve
    :param kwargs: extra arguments forwarded when loading from a zoo stub
    :return: the model loaded as a ModelProto
    :raises ValueError: if the model is falsy, of an unsupported type, or
        the resolved path does not exist on the local system
    """
    if not model:
        raise ValueError("Model must not be None type")

    # A zoo stub resolves to a SparseZoo Model object first
    if isinstance(model, str) and model.startswith("zoo:"):
        model = Zoo.load_model_from_stub(model, **kwargs)

    if isinstance(model, ModelProto):
        # already loaded; nothing further to do
        return model

    if isinstance(model, Model):
        # default to the main onnx file for the model
        model = model.onnx_file.downloaded_path()
    elif isinstance(model, File):
        # downloaded_path auto-downloads when not on the local system
        model = model.downloaded_path()

    if not isinstance(model, str):
        raise ValueError("unsupported type for model: {}".format(type(model)))

    if not os.path.exists(model):
        raise ValueError("model path must exist: given {}".format(model))

    return onnx.load(model)
Example 3
0
File: main.py Project: PIlotcnc/new
def main():
    """
    Entry point for the CLI: parse arguments and dispatch to the requested
    sub-command (model download or model search).
    """
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    if args.command == DOWNLOAD_COMMAND:
        LOGGER.info("Downloading files from model...")
        downloaded = Zoo.download_model(
            domain=args.domain,
            sub_domain=args.sub_domain,
            architecture=args.architecture,
            sub_architecture=args.sub_architecture,
            framework=args.framework,
            repo=args.repo,
            dataset=args.dataset,
            training_scheme=args.training_scheme,
            sparse_name=args.sparse_name,
            sparse_category=args.sparse_category,
            sparse_target=args.sparse_target,
            release_version=args.release_version,
            override_parent_path=args.save_dir,
            overwrite=args.overwrite,
        )

        # Report where the files landed
        for line in ("Download results", "====================", ""):
            print(line)
        print(f"downloaded to {downloaded.dir_path}")
    elif args.command == SEARCH_COMMAND:
        search(args)
Example 4
0
def download_and_verify(model: str, other_args: Optional[Dict] = None):
    """
    Download the model referenced by the given stub, validate the download,
    then remove the downloaded files.

    :param model: SparseZoo stub of the model to download
    :param other_args: extra keyword arguments for Zoo.load_model_from_stub;
        defaults to caching under CACHE_DIR/test_download
    """
    if other_args is None:
        default_parent = os.path.join(CACHE_DIR, "test_download")
        other_args = {"override_parent_path": default_parent}

    zoo_model = Zoo.load_model_from_stub(model, **other_args)
    zoo_model.download(overwrite=True)
    validate_downloaded_model(zoo_model, check_other_args=other_args)
    # clean up the downloaded files once validated
    shutil.rmtree(zoo_model.dir_path)
Example 5
0
def _get_models(domain, sub_domain) -> List[str]:
    """
    Collect the stubs of every model in the given domain/sub-domain by
    paging through Zoo.search_models until an empty page is returned.

    :param domain: model domain to search
    :param sub_domain: model sub-domain to search
    :return: list of model stubs across all result pages
    """
    found = []
    page_number = 1
    while True:
        page_results = Zoo.search_models(domain, sub_domain, page=page_number)
        if not page_results:
            break
        found.extend(page_results)
        page_number += 1
    return [result.stub for result in found]
Example 6
0
File: main.py Project: PIlotcnc/new
def search(args):
    """
    Search the SparseZoo for models matching the given CLI arguments and
    print a formatted listing of the results.

    :param args: parsed CLI namespace carrying the search filters and paging
    """
    LOGGER.info("loading available models...")
    models = Zoo.search_models(
        domain=args.domain,
        sub_domain=args.sub_domain,
        architecture=args.architecture,
        sub_architecture=args.sub_architecture,
        framework=args.framework,
        repo=args.repo,
        dataset=args.dataset,
        training_scheme=args.training_scheme,
        sparse_name=args.sparse_name,
        sparse_category=args.sparse_category,
        sparse_target=args.sparse_target,
        release_version=args.release_version,
        page=args.page,
        page_length=args.page_length,
    )

    # Absolute index of the entry just before this page's first result
    offset = args.page_length * (args.page - 1)

    print("Search results")
    print("====================")
    print(f"Showing results {offset + 1} - {offset + args.page_length}")
    print("")

    for position, model in enumerate(models, start=1):
        print(f"{offset + position}) {model.display_name}")
        print("-------------------------")
        print(f"Model Description: {model.display_description}")
        print("")
        print(f"Framework: {model.framework}")
        print("")
        print(f"Repository: {model.repo}")
        print("")
        joined_tags = ", ".join(tag.display_name for tag in model.tags)
        print(f"Tags: {joined_tags}")
        print("")
        print(f"Download Command: {_get_command_from_model(model)}")
        print("")
        print("")
Example 7
0
def main():
    """
    CLI entry point: download every file for a model identified by a
    SparseZoo stub and report where the files were saved.

    :raises ValueError: if the provided model stub is not a string or does
        not start with the 'zoo:' scheme prefix
    """
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    LOGGER.info("Downloading files from model '{}'".format(args.model_stub))

    if not isinstance(args.model_stub, str):
        # fix: original message read "Model stub be a string"
        raise ValueError("Model stub must be a string")

    if not args.model_stub.startswith("zoo:"):
        raise ValueError("Model stub must start with 'zoo:'")

    model = Zoo.download_model_from_stub(
        stub=args.model_stub,
        override_parent_path=args.save_dir,
        overwrite=args.overwrite,
    )

    print("Download results")
    print("====================")
    print()
    print(f"{model.display_name} downloaded to {model.dir_path}")
Example 8
0
def load_data(
    data: Any,
    model: Any = None,
    batch_size: int = 1,
    total_iterations: int = 0,
    **kwargs,
) -> Iterable[Tuple[Dict[str, Any], Any]]:
    """
    Create an iterable data loader for the given data.

    Acceptable types for data are:
    - a folder path containing numpy files
    - a list of file paths
    - a SparseML DataLoader
    - a SparseZoo DataLoader
    - an iterable
    - None type, in which case model must be passed

    :param data: data to use for benchmarking
    :param model: model to use for generating data when data is not provided
    :param batch_size: batch size
    :param total_iterations: total number of iterations
    :param kwargs: additional arguments to pass to the DataLoader
    :return: an iterable of data and labels
    """
    # Creates random data from model input shapes if data is not provided
    if not data:
        if not model:
            raise ValueError("must provide model or data")
        model = load_model(model)
        return DataLoader.from_model_random(
            model, batch_size, iter_steps=total_iterations
        )

    # If data is a SparseZoo stub, downloads model data; the resulting
    # loader then falls through to the SparseZooDataLoader branch below
    if isinstance(data, str) and data.startswith("zoo:"):
        model_from_zoo = Zoo.load_model_from_stub(data)
        data = model_from_zoo.data_inputs.loader(
            batch_size, total_iterations, batch_as_list=False
        )

    # Immediately return the data if it is already a DataLoader
    if isinstance(data, DataLoader):
        return data

    # If data is a SparseZoo DataLoader, unbatches the dataloader and creates
    # DataLoader from it (re-loaded with batch_size=1, then each entry is
    # squeezed down to a single sample by dropping the leading batch dim)
    elif isinstance(data, SparseZooDataLoader):
        datasets = [
            SparseZooDataset(name, dataset) for name, dataset in data.datasets.items()
        ]
        data = SparseZooDataLoader(*datasets, batch_size=1, batch_as_list=False)
        data = [
            OrderedDict(
                [
                    (element, value.reshape(value.shape[1:]))
                    for element, value in entry.items()
                ]
            )
            for entry in data
        ]

    # If data is a dictionary of data shapes, creates DataLoader from random
    # data; a dict whose values are not all tuples falls through to the
    # final DataLoader construction at the bottom
    elif isinstance(data, dict):
        is_dict_of_shapes = True
        for _, value in data.items():
            is_dict_of_shapes = is_dict_of_shapes and isinstance(value, tuple)
        if is_dict_of_shapes:
            return DataLoader.from_random(
                data,
                None,
                batch_size=batch_size,
                iter_steps=total_iterations,
                **kwargs,
            )

    # If data is a list of data shapes, creates DataLoader from random data.
    # NOTE(review): next(iter(data)) consumes the first element when data is
    # a one-shot iterator (e.g. a generator) -- that element would then be
    # missing from the DataLoader built below; confirm callers pass
    # re-iterable sequences here.
    elif isinstance(data, Iterable):
        element = next(iter(data))
        if isinstance(element, tuple):
            data_shapes = OrderedDict(
                (f"{index:04}", shape) for index, shape in enumerate(data)
            )
            return DataLoader.from_random(
                data_shapes,
                None,
                batch_size=batch_size,
                iter_steps=total_iterations,
                **kwargs,
            )
    # Fallback: hand whatever data now is (paths, unbatched samples, ...)
    # directly to the DataLoader constructor
    return DataLoader(
        data, None, batch_size=batch_size, iter_steps=total_iterations, **kwargs
    )