Exemplo n.º 1
0
def load_model(model: Any, **kwargs) -> ModelProto:
    """
    Resolves the given model reference into a loaded ONNX ModelProto.

    A string starting with "zoo:" is first resolved through
    Zoo.load_model_from_stub. Model and File objects are reduced to their
    downloaded local path, a ModelProto is returned unchanged, and any
    resulting string is treated as a path and read with onnx.load.

    :param model: the model; a "zoo:" stub, a file path, a Model, a File,
        or an already loaded ModelProto
    :param kwargs: additional arguments to pass if loading from a stub
    :return: the model loaded as a ModelProto
    :raises ValueError: if model is falsy, is of an unsupported type,
        or is a path that does not exist
    """
    if not model:
        raise ValueError("Model must not be None type")

    resolved = model

    # stubs are swapped for whatever object the zoo returns for them
    if isinstance(resolved, str) and resolved.startswith("zoo:"):
        resolved = Zoo.load_model_from_stub(resolved, **kwargs)

    if isinstance(resolved, Model):
        # the model's main onnx file is the one that gets loaded
        resolved = resolved.onnx_file.downloaded_path()
    elif isinstance(resolved, File):
        # reduce the file object to its downloaded local path
        resolved = resolved.downloaded_path()
    elif isinstance(resolved, ModelProto):
        # already loaded, nothing left to do
        return resolved

    # from here on only a filesystem path is acceptable
    if not isinstance(resolved, str):
        message = "unsupported type for model: {}".format(type(resolved))
        raise ValueError(message)

    if not os.path.exists(resolved):
        message = "model path must exist: given {}".format(resolved)
        raise ValueError(message)

    return onnx.load(resolved)
Exemplo n.º 2
0
def download_and_verify(model: str, other_args: Optional[Dict] = None):
    """
    Downloads the model for the given stub, validates the downloaded files,
    and removes the model's directory afterwards.

    :param model: the stub passed to Zoo.load_model_from_stub
    :param other_args: keyword arguments forwarded to Zoo.load_model_from_stub
        and passed to validate_downloaded_model as check_other_args; when None,
        defaults to overriding the parent path with a "test_download" folder
        under CACHE_DIR
    """
    if other_args is None:
        other_args = {
            "override_parent_path": os.path.join(CACHE_DIR, "test_download"),
        }
    model = Zoo.load_model_from_stub(model, **other_args)
    try:
        model.download(overwrite=True)
        validate_downloaded_model(model, check_other_args=other_args)
    except BaseException:
        # Clean up even when the download or validation fails so the files
        # are not left behind; ignore_errors keeps a cleanup problem (e.g. the
        # directory was never created) from masking the original error.
        shutil.rmtree(model.dir_path, ignore_errors=True)
        raise
    shutil.rmtree(model.dir_path)
Exemplo n.º 3
0
def load_data(
    data: Any,
    model: Any = None,
    batch_size: int = 1,
    total_iterations: int = 0,
    **kwargs,
) -> Iterable[Tuple[Dict[str, Any], Any]]:
    """
    Creates an iterable data loader for the given data.

    Acceptable types for data are:
    - a folder path containing numpy files
    - a list of file paths
    - a SparseZoo stub (a string starting with "zoo:")
    - a SparseML DataLoader
    - a SparseZoo DataLoader
    - a dictionary whose values are all shape tuples (random data is generated)
    - an iterable of shape tuples (random data is generated)
    - any other iterable
    - None type (or another falsy value), in which case model must be passed

    :param data: data to use for benchmarking
    :param model: model to use for generating random data when data is not given
    :param batch_size: batch size
    :param total_iterations: total number of iterations
    :param kwargs: additional arguments to pass to the DataLoader; not used
        when random data is generated from the model
    :return: an iterable of data and labels
    :raises ValueError: if neither data nor model is provided
    """
    # local import: only needed to re-attach an element peeked from an iterator
    import itertools

    # Creates random data from model input shapes if data is not provided
    if not data:
        if not model:
            raise ValueError("must provide model or data")
        model = load_model(model)
        return DataLoader.from_model_random(
            model, batch_size, iter_steps=total_iterations
        )

    # If data is a SparseZoo stub, downloads model data
    if isinstance(data, str) and data.startswith("zoo:"):
        model_from_zoo = Zoo.load_model_from_stub(data)
        data = model_from_zoo.data_inputs.loader(
            batch_size, total_iterations, batch_as_list=False
        )

    # Immediately return the data if it is already a DataLoader
    if isinstance(data, DataLoader):
        return data

    # If data is a SparseZoo DataLoader, unbatches the dataloader and creates
    # DataLoader from it
    elif isinstance(data, SparseZooDataLoader):
        datasets = [
            SparseZooDataset(name, dataset) for name, dataset in data.datasets.items()
        ]
        # reload with a batch size of 1, then drop that leading dimension
        # from every value so each entry is a single unbatched sample
        data = SparseZooDataLoader(*datasets, batch_size=1, batch_as_list=False)
        data = [
            OrderedDict(
                [
                    (element, value.reshape(value.shape[1:]))
                    for element, value in entry.items()
                ]
            )
            for entry in data
        ]

    # If data is a dictionary of data shapes, creates DataLoader from random data
    elif isinstance(data, dict):
        if all(isinstance(value, tuple) for value in data.values()):
            return DataLoader.from_random(
                data,
                None,
                batch_size=batch_size,
                iter_steps=total_iterations,
                **kwargs,
            )

    # If data is a list of data shapes, creates DataLoader from random data
    elif isinstance(data, Iterable):
        source = iter(data)
        missing = object()
        # peek at the first element; an empty iterable yields the sentinel
        # instead of leaking StopIteration and falls through to DataLoader
        first = next(source, missing)
        if source is data and first is not missing:
            # data is a one-shot iterator, so peeking consumed its first
            # element; chain it back so nothing is lost downstream
            data = itertools.chain([first], source)
        if isinstance(first, tuple):
            data_shapes = OrderedDict(
                (f"{index:04}", shape) for index, shape in enumerate(data)
            )
            return DataLoader.from_random(
                data_shapes,
                None,
                batch_size=batch_size,
                iter_steps=total_iterations,
                **kwargs,
            )
    return DataLoader(
        data, None, batch_size=batch_size, iter_steps=total_iterations, **kwargs
    )