def test_quantize_model_post_training_mnist():
    """End-to-end check of post-training quantization on the MNIST model.

    Fetches the PyTorch ``mnistnet`` classification model from the Zoo,
    calibrates/quantizes it with the MNIST validation split via
    ``quantize_model_post_training``, and then runs the ``_test_*`` helpers
    against the quantized ONNX file. The temporary output file is removed even
    when one of the checks fails.
    """
    # Prepare model paths
    models = Zoo.search_models(
        domain="cv",
        sub_domain="classification",
        architecture="mnistnet",
        framework="pytorch",
    )
    assert len(models) > 0, "no mnistnet model found in the Zoo"
    mnist_model_path = models[0].onnx_file.downloaded_path()

    # Only the path is needed; close the handle immediately instead of leaving
    # it to garbage collection (an open handle can block writers on Windows).
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp_file:
        quant_model_path = tmp_file.name

    try:
        # Prepare sample validation dataset
        batch_size = 1
        val_dataset = MNISTDataset(train=False)
        input_dict = [{"input": img.numpy()} for (img, _) in val_dataset]
        data_loader = DataLoader(input_dict, None, batch_size)

        # Run calibration and quantization
        quantize_model_post_training(
            mnist_model_path, data_loader, quant_model_path, show_progress=False
        )

        # Verify that ResNet identity has no effect
        _test_resnet_identity_quant(quant_model_path, False, False)

        # Verify Convs and MatMuls are quantized
        _test_model_is_quantized(mnist_model_path, quant_model_path)

        # Verify quant model accuracy. A new DataLoader (new generator) is
        # built rather than reusing ``data_loader`` -- presumably because the
        # calibration pass consumed it.
        test_data_loader = DataLoader(input_dict, None, batch_size)
        _test_quant_model_output(
            mnist_model_path, quant_model_path, test_data_loader, [0], batch_size
        )
    finally:
        # Clean up even if a check above failed, so temp files don't pile up.
        if os.path.exists(quant_model_path):
            os.remove(quant_model_path)
def test_search_models(model_args, other_args):
    """Check that ``Zoo.search_models`` honors its filter and paging arguments.

    Every returned model must expose each attribute named in ``model_args``
    with exactly the requested value. When ``other_args`` contains
    ``page_length``, the number of returned models must not exceed it.
    """
    found = Zoo.search_models(**model_args, **other_args)

    for candidate in found:
        for attr_name, expected in model_args.items():
            assert getattr(candidate, attr_name) == expected

    # Paging is only checked when the caller asked for a page size.
    if "page_length" not in other_args:
        return
    assert len(found) <= other_args["page_length"]
def test_onnx_node_sparsities():
    """Validate ``onnx_nodes_sparsities`` on a moderately pruned MobileNet V1.

    Searches the Zoo for the sparseml pruned/moderate ImageNet ``mobilenet_v1``
    PyTorch model(s) and, for each one, checks the overall measurement (28
    nodes, > 50% sparse, 4209088 params) and the per-node measurements. Only
    nodes whose name mentions "sections" or "classifier" and none of the
    exempt markers below get the 0.2-0.95 sparsity range check.
    """
    # runs through nearly all other onnx functions imported above as well
    models = Zoo.search_models(
        domain="cv",
        sub_domain="classification",
        architecture="mobilenet_v1",
        dataset="imagenet",
        framework="pytorch",
        sparse_name="pruned",
        sparse_category="moderate",
        repo="sparseml",
    )
    assert len(models) > 0

    # Node-name fragments excluded from the range check -- presumably layers
    # (depthwise, first sections, output) not pruned into that range by the
    # recipe; not established here.
    exempt_markers = (
        "depth",
        "sections.0",
        "sections_0",
        "sections.1",
        "sections_1",
        "output",
    )

    for model in models:
        onnx_path = model.onnx_file.downloaded_path()
        total, node_sparsities = onnx_nodes_sparsities(onnx_path)

        assert len(node_sparsities) == 28
        assert isinstance(total, SparsityMeasurement)
        assert total.sparsity > 0.5
        assert total.params_count == 4209088
        assert total.params_zero_count > 0.5 * total.params_count

        for node_name, measurement in node_sparsities.items():
            assert isinstance(measurement, SparsityMeasurement)
            assert measurement.params_count > 0

            if not ("sections" in node_name or "classifier" in node_name):
                continue
            if any(marker in node_name for marker in exempt_markers):
                continue

            assert 0.2 < measurement.sparsity < 0.95
            assert measurement.params_zero_count > 0