def test_tensors_check_duplicates(self):
    """Graph.tensors(check_duplicates=True) is expected to raise for same-named distinct tensors."""
    # Two separate Variable objects that share the name "x".
    graph_inputs = [Variable(name="x")]
    graph_outputs = [Variable(name="x")]

    add_node = Node(op="Add", name="Test", inputs=graph_inputs, outputs=graph_outputs)
    graph = Graph(nodes=[add_node], inputs=graph_inputs, outputs=graph_outputs)

    with pytest.raises(OnnxGraphSurgeonException):
        graph.tensors(check_duplicates=True)
def test_tensors_with_duplicates_check_disabled(self):
    """Graph.tensors(check_duplicates=False) is expected to succeed despite same-named distinct tensors."""
    # Two separate Variable objects that share the name "x".
    graph_inputs = [Variable(name="x")]
    graph_outputs = [Variable(name="x")]

    add_node = Node(op="Add", name="Test", inputs=graph_inputs, outputs=graph_outputs)
    graph = Graph(nodes=[add_node], inputs=graph_inputs, outputs=graph_outputs)

    # With the duplicate check turned off, this call must complete without raising.
    graph.tensors(check_duplicates=False)
def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto:
    """
    Export an onnx-graphsurgeon Graph to an ONNX GraphProto.

    Args:
        graph (Graph): The graph to export.
        do_type_check (bool): Whether to check that input and output tensors have data types defined,
            and fail if not.

    Returns:
        onnx.GraphProto: The result of ``onnx.helper.make_graph`` built from the exported nodes,
            inputs, outputs, initializers and value_info, carrying the graph's name and doc_string.
    """
    node_protos = [OnnxExporter.export_node(node, do_type_check) for node in graph.nodes]
    input_protos = [
        OnnxExporter.export_value_info_proto(graph_input, do_type_check)
        for graph_input in graph.inputs
    ]
    output_protos = [
        OnnxExporter.export_value_info_proto(graph_output, do_type_check)
        for graph_output in graph.outputs
    ]

    remaining_tensors = graph.tensors()

    # Initializers are gathered before graph inputs/outputs are dropped from the map below,
    # so a Constant that is also a graph input/output is still exported here.
    initializer_protos = [
        OnnxExporter.export_tensor_proto(candidate)
        for candidate in remaining_tensors.values()
        if isinstance(candidate, Constant)
    ]

    # Graph inputs and outputs were already exported as ValueInfoProtos above; drop them
    # so they are not repeated in value_info.
    for boundary_tensor in graph.inputs + graph.outputs:
        if boundary_tensor.name in remaining_tensors:
            del remaining_tensors[boundary_tensor.name]

    # Only Variables with a known dtype or shape contribute to value_info.
    value_info_protos = []
    for candidate in remaining_tensors.values():
        if not isinstance(candidate, Variable):
            continue
        if candidate.dtype is None and candidate.shape is None:
            continue
        value_info_protos.append(OnnxExporter.export_value_info_proto(candidate, do_type_check))

    return onnx.helper.make_graph(
        nodes=node_protos,
        name=graph.name,
        inputs=input_protos,
        outputs=output_protos,
        initializer=initializer_protos,
        doc_string=graph.doc_string,
        value_info=value_info_protos,
    )
def test_tensors_includes_non_node_tensors(self):
    """Graph.tensors() is expected to report a graph output that no node produces or consumes."""
    ones = np.ones(shape=(64, 64), dtype=np.float32)
    const_x = Constant("X", values=ones)

    # The graph has no nodes; the Constant is attached only as a graph output.
    graph = Graph(inputs=[], outputs=[const_x])

    tensors_by_name = graph.tensors()
    assert "X" in tensors_by_name
    assert tensors_by_name["X"] == const_x