Example #1
0
def test_conv_node_params(prunable_onnx_model):
    """Check ``conv_node_params`` on the first Conv node of the fixture model.

    With ``include_values=False`` each returned ``NodeParam`` carries only the
    initializer name (value ``None``). When called without that argument, the
    weight and bias entries carry values whose shapes are checked.
    """
    graph = prunable_onnx_model.graph
    conv_nodes = [
        candidate for candidate in graph.node if candidate.op_type == "Conv"
    ]
    conv_node = conv_nodes[0]

    names_only = conv_node_params(
        prunable_onnx_model, conv_node, include_values=False
    )
    expected_names_only = (
        NodeParam("node1.weight", None),
        NodeParam("node1.bias", None),
    )
    assert names_only == expected_names_only

    with_values = conv_node_params(prunable_onnx_model, conv_node)
    weight_value = with_values[0][1]
    bias_value = with_values[1][1]
    assert weight_value.shape == (2, 3, 3, 3)
    assert bias_value.shape == (2,)
Example #2
0
def test_get_node_params(prunable_onnx_model):
    """Check ``get_node_params`` across every node of the fixture model.

    The final graph node must make ``get_node_params`` raise ``ValueError``
    (presumably an op type with no parameter extraction — confirm against the
    fixture). Every preceding node must yield the listed ``NodeParam`` names
    (weight, plus bias or ``None``) when values are excluded.
    """
    nodes = prunable_onnx_model.graph.node
    with pytest.raises(ValueError):
        get_node_params(prunable_onnx_model, nodes[-1])

    expected = [
        (NodeParam("node1.weight", None), NodeParam("node1.bias", None)),
        (NodeParam("node2.weight", None), None),
        (NodeParam("node3.weight", None), None),
    ]
    prunable_nodes = nodes[:-1]
    # zip() silently stops at the shorter input; pin the node count so an
    # extra or missing node cannot make the loop skip checks and pass vacuously.
    assert len(prunable_nodes) == len(expected)
    for node, expected_params in zip(prunable_nodes, expected):
        assert (
            get_node_params(prunable_onnx_model, node, include_values=False)
            == expected_params
        )
Example #3
0
def test_matmul_node_params(prunable_onnx_model):
    """Check ``matmul_node_params`` on the first MatMul node of the fixture model.

    The node reports a weight initializer and ``None`` in the bias slot. When
    called without ``include_values=False``, the weight's value shape is checked.
    """
    graph = prunable_onnx_model.graph
    matmul_nodes = [
        candidate for candidate in graph.node if candidate.op_type == "MatMul"
    ]
    matmul_node = matmul_nodes[0]

    names_only = matmul_node_params(
        prunable_onnx_model, matmul_node, include_values=False
    )
    assert names_only == (NodeParam("node3.weight", None), None)

    with_values = matmul_node_params(prunable_onnx_model, matmul_node)
    weight_value = with_values[0][1]
    assert weight_value.shape == (3, 4)
Example #4
0
def test_gemm_node_params(prunable_onnx_model):
    """Check ``gemm_node_params`` on the first Gemm node of the fixture model.

    The node reports a weight initializer and ``None`` in the bias slot. When
    called without ``include_values=False``, the weight's value shape is checked.
    """
    graph = prunable_onnx_model.graph
    gemm_nodes = [
        candidate for candidate in graph.node if candidate.op_type == "Gemm"
    ]
    gemm_node = gemm_nodes[0]

    names_only = gemm_node_params(
        prunable_onnx_model, gemm_node, include_values=False
    )
    expected_names_only = (
        NodeParam("node2.weight", None),
        None,
    )
    assert names_only == expected_names_only

    with_values = gemm_node_params(prunable_onnx_model, gemm_node)
    weight_value = with_values[0][1]
    assert weight_value.shape == (12, 3)