Beispiel #1
0
def dataloader_models(request) -> DataloaderModelFixture:
    """
    Build a DataloaderModelFixture for one parametrized zoo model.

    :param request: parametrized fixture request whose ``param`` is a 4-tuple of
        (model_args, input_shapes, output_shapes, data_types)
    :return: fixture bundling the downloaded ONNX file path with the expected
        input shapes, output shapes and data types passed through unchanged
    """
    zoo_args, in_shapes, out_shapes, dtypes = request.param
    onnx_path = Zoo.load_model(**zoo_args).onnx_file.downloaded_path()
    return DataloaderModelFixture(onnx_path, in_shapes, out_shapes, dtypes)
Beispiel #2
0
def test_search_similar_models(model_args, other_args):
    """
    Check that Zoo.search_similar_models returns at least one model and that
    every result matches the query model on domain, sub-domain, architecture
    and sub-architecture.
    """
    query = Zoo.load_model(**model_args, **other_args)
    results = Zoo.search_similar_models(query)
    assert len(results) > 0

    # Attributes every "similar" model must share with the query model,
    # compared in this order for each result.
    shared_attrs = ("domain", "sub_domain", "architecture", "sub_architecture")
    for result in results:
        assert result
        for attr in shared_attrs:
            assert getattr(result, attr) == getattr(query, attr)
Beispiel #3
0
def onnx_models_with_data(request) -> OnnxModelDataFixture:
    """
    Download a zoo model plus its data files and locate the sample I/O files.

    :param request: parametrized fixture request whose ``param`` is the kwargs
        dict passed to Zoo.load_model
    :return: fixture holding the ONNX path and the paths of the data files whose
        path contains "sample-inputs" / "sample-outputs" (None if absent; the
        last match wins if several paths match)
    """
    model = Zoo.load_model(**request.param)
    model_path = model.onnx_file.downloaded_path()
    # Download every data file up front, then classify the resulting paths.
    downloaded = [entry.downloaded_path() for entry in model.data.values()]

    inputs_paths = outputs_paths = None
    for candidate in downloaded:
        if "sample-inputs" in candidate:
            inputs_paths = candidate
            continue
        if "sample-outputs" in candidate:
            outputs_paths = candidate

    return OnnxModelDataFixture(model_path, inputs_paths, outputs_paths)
Beispiel #4
0
def onnx_repo_models(request) -> OnnxRepoModelFixture:
    """
    Download a named zoo model plus its data files and locate the sample I/O files.

    :param request: parametrized fixture request whose ``param`` is a tuple of
        (model_args, model_name); model_args are the Zoo.load_model kwargs
    :return: fixture holding the ONNX path, the model name, and the paths of the
        data files whose path contains "sample-inputs" / "sample-outputs"
        (None if absent; the last match wins if several paths match)
    """
    load_kwargs, model_name = request.param
    model = Zoo.load_model(**load_kwargs)
    model_path = model.onnx_file.downloaded_path()
    # Download every data file up front, then classify the resulting paths.
    downloaded = [entry.downloaded_path() for entry in model.data.values()]

    input_paths = output_paths = None
    for candidate in downloaded:
        if "sample-inputs" in candidate:
            input_paths = candidate
            continue
        if "sample-outputs" in candidate:
            output_paths = candidate

    return OnnxRepoModelFixture(model_path, model_name, input_paths, output_paths)
Beispiel #5
0
def analyzer_models_repo(request):
    """
    Load a zoo model together with the expected ModelAnalyzer JSON output for it.

    When GENERATE_TEST_FILES is truthy, the expected JSON is first regenerated
    from the model via ModelAnalyzer.save_json before being read back.

    :param request: parametrized fixture request whose ``param`` is a tuple of
        (model_args, output_path); model_args are the Zoo.load_model kwargs and
        output_path is relative to RELATIVE_PATH/test_analyzer_model_data
    :return: tuple of (downloaded ONNX model path, dict loaded from the JSON file)
    """
    model_args, output_name = request.param
    output_path = os.path.join(RELATIVE_PATH, "test_analyzer_model_data", output_name)
    model = Zoo.load_model(**model_args)
    model_path = model.onnx_file.downloaded_path()

    if GENERATE_TEST_FILES:
        analyzer = ModelAnalyzer(model_path)
        analyzer.save_json(output_path)

    # Be explicit about the encoding so reading the expected-output JSON does
    # not depend on the platform's locale default.
    with open(output_path, encoding="utf-8") as output_file:
        output = dict(json.load(output_file))

    return model_path, output
Beispiel #6
0
def test_search_optimized_models(model_args, other_args):
    """
    Check that Zoo.search_optimized_models returns at least one model and that
    every result matches the query model on domain, sub-domain, architecture,
    sub-architecture, framework, repo, dataset and training scheme.
    """
    query = Zoo.load_model(**model_args, **other_args)
    results = Zoo.search_optimized_models(query)
    assert len(results) > 0

    # Attributes every optimized variant must share with the query model,
    # compared in this order for each result.
    shared_attrs = (
        "domain",
        "sub_domain",
        "architecture",
        "sub_architecture",
        "framework",
        "repo",
        "dataset",
        "training_scheme",
    )
    for result in results:
        assert result
        for attr in shared_attrs:
            assert getattr(result, attr) == getattr(query, attr)
Beispiel #7
0
def test_search_sparse_models(model_args, other_args):
    """
    Check that Zoo.search_sparse_models returns at least one model, that no
    result is a base model (``is_base`` falsy), and that every result matches
    the query model on domain, sub-domain, architecture, sub-architecture,
    framework, repo, dataset and training scheme.
    """
    query = Zoo.load_model(**model_args, **other_args)
    results = Zoo.search_sparse_models(query)
    assert len(results) > 0

    # Attributes every sparse variant must share with the query model,
    # compared in this order for each result.
    shared_attrs = (
        "domain",
        "sub_domain",
        "architecture",
        "sub_architecture",
        "framework",
        "repo",
        "dataset",
        "training_scheme",
    )
    for result in results:
        assert result
        assert not result.is_base
        for attr in shared_attrs:
            assert getattr(result, attr) == getattr(query, attr)
Beispiel #8
0
def test_quantize_model_post_training_resnet50_imagenette():
    """
    Post-training quantize the base sparseml ResNet-50 (imagenette) ONNX model
    and verify the result.

    Verification runs _test_resnet_identity_quant, then checks that Convs and
    MatMuls are quantized, then compares quantized vs. original outputs on the
    imagenette validation split. The temporary quantized model is always
    removed, even when a step or assertion fails.
    """
    # Prepare model paths
    resnet50_imagenette_path = Zoo.load_model(
        domain="cv",
        sub_domain="classification",
        architecture="resnet_v1",
        sub_architecture="50",
        framework="pytorch",
        repo="sparseml",
        dataset="imagenette",
        training_scheme=None,
        sparse_name="base",
        sparse_category="none",
        sparse_target=None,
    ).onnx_file.downloaded_path()

    # Reserve a temp path, closing our handle right away so the quantizer can
    # write to it (an open handle blocks reopening/removal on Windows).
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp_file:
        quant_model_path = tmp_file.name

    try:
        # Prepare sample validation dataset
        batch_size = 1
        val_dataset = ImagenetteDataset(train=False, dataset_size=ImagenetteSize.s320)
        input_dict = [{"input": img.numpy()} for (img, _) in val_dataset]
        data_loader = DataLoader(input_dict, None, batch_size)

        # Run calibration and quantization
        quantize_model_post_training(
            resnet50_imagenette_path,
            data_loader,
            quant_model_path,
            show_progress=False,
            run_extra_opt=False,
        )

        # Verify the ResNet identity-quantization result
        _test_resnet_identity_quant(quant_model_path, True, True)

        # Verify Convs and MatMuls are quantized
        _test_model_is_quantized(resnet50_imagenette_path, quant_model_path)

        # Verify quant model accuracy; a fresh loader is needed because the
        # first one was consumed during calibration.
        test_data_loader = DataLoader(input_dict, None, batch_size)
        _test_quant_model_output(
            resnet50_imagenette_path,
            quant_model_path,
            test_data_loader,
            [1],
            batch_size,
        )
    finally:
        # Clean up regardless of outcome; guard so a missing file does not
        # mask the original failure.
        if os.path.exists(quant_model_path):
            os.remove(quant_model_path)
Beispiel #9
0
def test_search_sparse_recipes(model_args, other_args, other_recipe_args):
    """
    Check that Zoo.search_sparse_recipes returns at least one recipe and that
    each recipe's model metadata matches the query model. The metadata is
    compared on domain, sub-domain, architecture, sub-architecture, framework,
    repo, dataset and training scheme. When other_recipe_args contains
    "recipe_type", each recipe must also have that recipe_type.
    """
    query = Zoo.load_model(**model_args, **other_args)
    results = Zoo.search_sparse_recipes(query, **other_recipe_args)
    assert len(results) > 0

    # Metadata attributes that must agree with the query model, in order.
    shared_attrs = (
        "domain",
        "sub_domain",
        "architecture",
        "sub_architecture",
        "framework",
        "repo",
        "dataset",
        "training_scheme",
    )
    for recipe in results:
        assert recipe
        metadata = recipe.model_metadata
        for attr in shared_attrs:
            assert getattr(metadata, attr) == getattr(query, attr)

        if "recipe_type" not in other_recipe_args:
            continue
        assert recipe.recipe_type == other_recipe_args["recipe_type"]
Beispiel #10
0
    def create_zoo_model(
        key: str,
        pretrained: Union[bool, str] = True,
        pretrained_dataset: Union[str, None] = None,
    ) -> Model:
        """
        Create a sparsezoo Model for the desired model in the zoo

        :param key: the model key (name) to retrieve
        :param pretrained: True to load pretrained weights; to load a specific version
            give a string with the name of the version (optim, optim-perf), default True.
            Any non-string value uses the registry's default description for the key.
        :param pretrained_dataset: The dataset to load for the model; None uses the
            registry's default dataset for the key
        :return: the sparsezoo Model reference for the given model
        :raises ValueError: if key is not in the model registry
        """
        if key not in ModelRegistry._CONSTRUCTORS:
            # List only the registered keys; formatting the registry itself
            # would dump every constructor object into the message.
            raise ValueError(
                "key {} is not in the model registry; available: {}".format(
                    key, list(ModelRegistry._CONSTRUCTORS)
                )
            )

        attributes = ModelRegistry._ATTRIBUTES[key]

        optim_name, optim_category, optim_target = parse_optimization_str(
            pretrained if isinstance(pretrained, str) else attributes.default_desc
        )
        dataset = (
            attributes.default_dataset
            if pretrained_dataset is None
            else pretrained_dataset
        )

        # Positional arguments: the order must match Zoo.load_model's signature.
        # The None in the training-scheme position mirrors the original call.
        return Zoo.load_model(
            attributes.domain,
            attributes.sub_domain,
            attributes.architecture,
            attributes.sub_architecture,
            KERAS_FRAMEWORK,
            attributes.repo_source,
            dataset,
            None,
            optim_name,
            optim_category,
            optim_target,
        )
Beispiel #11
0
def test_load_model(model_args, other_args):
    """
    Load a zoo model, force a fresh download, validate the downloaded files,
    and always remove the model's download directory afterwards.
    """
    model = Zoo.load_model(**model_args, **other_args)
    try:
        model.download(overwrite=True)
        validate_downloaded_model(model, model_args, other_args)
    finally:
        # Run cleanup even if download/validation fails so failed runs do not
        # leave downloaded models behind. ignore_errors prevents a missing or
        # partial directory from masking the original failure.
        shutil.rmtree(model.dir_path, ignore_errors=True)