def test_conv_node_params(prunable_onnx_model):
    """Check ``conv_node_params`` against the first Conv node of the fixture.

    With ``include_values=False`` the result must equal the pair of
    ``NodeParam`` entries named ``node1.weight`` and ``node1.bias``, each
    carrying ``None``. When called without ``include_values``, element 1 of
    each returned param must expose shapes ``(2, 3, 3, 3)`` and ``(2,)``.
    """
    conv_candidates = [
        candidate
        for candidate in prunable_onnx_model.graph.node
        if candidate.op_type == "Conv"
    ]
    conv_node = conv_candidates[0]

    names_only = conv_node_params(
        prunable_onnx_model, conv_node, include_values=False
    )
    assert names_only == (
        NodeParam("node1.weight", None),
        NodeParam("node1.bias", None),
    )

    with_values = conv_node_params(prunable_onnx_model, conv_node)
    weight_param = with_values[0]
    bias_param = with_values[1]
    assert weight_param[1].shape == (2, 3, 3, 3)
    assert bias_param[1].shape == (2,)
def test_get_node_params(prunable_onnx_model):
    """Check ``get_node_params`` across the nodes of the fixture graph.

    The last node in the graph must make ``get_node_params`` raise
    ``ValueError``. Every earlier node is compared, in graph order and with
    ``include_values=False``, against its expected (weight, bias) pair.
    """
    graph_nodes = prunable_onnx_model.graph.node

    with pytest.raises(ValueError):
        get_node_params(prunable_onnx_model, graph_nodes[-1])

    expected_in_order = [
        (NodeParam("node1.weight", None), NodeParam("node1.bias", None)),
        (NodeParam("node2.weight", None), None),
        (NodeParam("node3.weight", None), None),
    ]
    # zip stops at the shorter sequence, so only nodes that have a matching
    # expectation entry are checked.
    for node, expected_params in zip(graph_nodes[:-1], expected_in_order):
        actual_params = get_node_params(
            prunable_onnx_model, node, include_values=False
        )
        assert actual_params == expected_params
def test_matmul_node_params(prunable_onnx_model):
    """Check ``matmul_node_params`` against the first MatMul node.

    With ``include_values=False`` the result must equal
    ``(NodeParam("node3.weight", None), None)``. When called without
    ``include_values``, element 1 of the first returned param must have
    shape ``(3, 4)``.
    """
    matmul_candidates = [
        candidate
        for candidate in prunable_onnx_model.graph.node
        if candidate.op_type == "MatMul"
    ]
    matmul_node = matmul_candidates[0]

    names_only = matmul_node_params(
        prunable_onnx_model, matmul_node, include_values=False
    )
    assert names_only == (NodeParam("node3.weight", None), None)

    with_values = matmul_node_params(prunable_onnx_model, matmul_node)
    weight_param = with_values[0]
    assert weight_param[1].shape == (3, 4)
def test_gemm_node_params(prunable_onnx_model):
    """Check ``gemm_node_params`` against the first Gemm node.

    With ``include_values=False`` the result must equal
    ``(NodeParam("node2.weight", None), None)``. When called without
    ``include_values``, element 1 of the first returned param must have
    shape ``(12, 3)``.
    """
    gemm_candidates = [
        candidate
        for candidate in prunable_onnx_model.graph.node
        if candidate.op_type == "Gemm"
    ]
    gemm_node = gemm_candidates[0]

    names_only = gemm_node_params(
        prunable_onnx_model, gemm_node, include_values=False
    )
    assert names_only == (
        NodeParam("node2.weight", None),
        None,
    )

    with_values = gemm_node_params(prunable_onnx_model, gemm_node)
    weight_param = with_values[0]
    assert weight_param[1].shape == (12, 3)