Пример #1
0
def test_quantize_model_post_training_mnist():
    """Quantize the Zoo MNIST model post-training and validate the result.

    Downloads the pytorch ``mnistnet`` ONNX model from the Zoo. It then
    calibrates and quantizes that model using the MNIST validation split
    (``MNISTDataset(train=False)``), writing the result to a temporary
    ``.onnx`` file. Finally it runs these checks on the quantized model:

    - ResNet identity handling (``_test_resnet_identity_quant``)
    - Conv/MatMul quantization (``_test_model_is_quantized``)
    - output accuracy (``_test_quant_model_output``)

    The temporary quantized model file is always removed, even when one of
    the checks fails.
    """
    # Prepare model paths
    mnist_model_path = Zoo.search_models(
        domain="cv",
        sub_domain="classification",
        architecture="mnistnet",
        framework="pytorch",
    )[0].onnx_file.downloaded_path()
    # Only a unique path is needed. Close the handle immediately so it does
    # not leak; on Windows an open NamedTemporaryFile also cannot be reopened
    # by name. delete=False keeps the file on disk until the cleanup below.
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp_file:
        quant_model_path = tmp_file.name

    try:
        # Prepare sample validation dataset
        batch_size = 1
        val_dataset = MNISTDataset(train=False)
        input_dict = [{"input": img.numpy()} for (img, _) in val_dataset]
        data_loader = DataLoader(input_dict, None, batch_size)

        # Run calibration and quantization
        quantize_model_post_training(
            mnist_model_path, data_loader, quant_model_path, show_progress=False
        )

        # Verify that ResNet identity has no effect
        _test_resnet_identity_quant(quant_model_path, False, False)

        # Verify Convs and MatMuls are quantized
        _test_model_is_quantized(mnist_model_path, quant_model_path)

        # Verify quant model accuracy. Build a new loader (a fresh generator)
        # instead of reusing the one handed to calibration, keeping the same
        # batch size that is passed to the output check.
        test_data_loader = DataLoader(input_dict, None, batch_size)
        _test_quant_model_output(
            mnist_model_path, quant_model_path, test_data_loader, [0], batch_size
        )
    finally:
        # Clean up even if quantization or any verification step fails
        os.remove(quant_model_path)
Пример #2
0
def test_search_models(model_args, other_args):
    """Check that ``Zoo.search_models`` honours its filter and paging arguments.

    ``model_args`` and ``other_args`` are forwarded together as keyword
    arguments. Every returned model must expose each key of ``model_args`` as
    an attribute whose value equals the requested one. When ``other_args``
    contains ``page_length``, the number of returned models must not exceed it.
    """
    found = Zoo.search_models(**model_args, **other_args)

    expected_pairs = list(model_args.items())
    for candidate in found:
        for attr_name, wanted in expected_pairs:
            assert getattr(candidate, attr_name) == wanted

    # Paging is only checked when the caller actually requested a page size
    if "page_length" not in other_args:
        return
    assert len(found) <= other_args["page_length"]
Пример #3
0
def test_onnx_node_sparsities():
    """Check total and per-node sparsity reporting on a pruned MobileNet V1.

    Searches the Zoo for the sparseml "pruned"/"moderate" imagenet
    ``mobilenet_v1`` (pytorch) models and runs ``onnx_nodes_sparsities`` on
    each model's ONNX file. The checks are:

    - the report covers 28 nodes;
    - the total measurement has more than 50% sparsity and exactly 4209088
      params, with more than half of them zero;
    - every node reports a positive param count.

    Nodes whose name mentions "sections" or "classifier" must additionally
    have a sparsity strictly between 0.2 and 0.95 and at least one zero param.
    Nodes whose name contains "depth", "output", or refers to sections 0/1
    are exempt from this extra check.
    """
    # runs through nearly all other onnx functions imported above as well
    search_kwargs = dict(
        domain="cv",
        sub_domain="classification",
        architecture="mobilenet_v1",
        dataset="imagenet",
        framework="pytorch",
        sparse_name="pruned",
        sparse_category="moderate",
        repo="sparseml",
    )
    candidates = Zoo.search_models(**search_kwargs)
    assert len(candidates) > 0

    # A node must match one of these to get the per-node sparsity checks...
    required_tags = ("sections", "classifier")
    # ...and must match none of these.
    excluded_tags = (
        "depth",
        "sections.0",
        "sections_0",
        "sections.1",
        "sections_1",
        "output",
    )

    for candidate in candidates:
        onnx_path = candidate.onnx_file.downloaded_path()
        total, per_node = onnx_nodes_sparsities(onnx_path)

        assert len(per_node) == 28

        assert isinstance(total, SparsityMeasurement)
        assert total.sparsity > 0.5
        assert total.params_count == 4209088
        assert total.params_zero_count > 0.5 * total.params_count

        for node_name, measurement in per_node.items():
            assert isinstance(measurement, SparsityMeasurement)
            assert measurement.params_count > 0

            if not any(tag in node_name for tag in required_tags) or any(
                tag in node_name for tag in excluded_tags
            ):
                continue

            assert 0.2 < measurement.sparsity < 0.95
            assert measurement.params_zero_count > 0