Пример #1
0
    def test_tensors_check_duplicates(self):
        """Expect ``Graph.tensors(check_duplicates=True)`` to raise on a name clash.

        The graph is built from two separate ``Variable`` objects that both
        carry the name "x": one used as the input and one as the output.
        """
        graph_inputs = [Variable(name="x")]
        # A second, distinct Variable object that reuses the same name.
        graph_outputs = [Variable(name="x")]
        add_node = Node(op="Add", name="Test", inputs=graph_inputs, outputs=graph_outputs)
        graph = Graph(nodes=[add_node], inputs=graph_inputs, outputs=graph_outputs)

        with pytest.raises(OnnxGraphSurgeonException):
            graph.tensors(check_duplicates=True)
Пример #2
0
    def test_tensors_with_duplicates_check_disabled(self):
        """Expect ``Graph.tensors(check_duplicates=False)`` to tolerate a name clash.

        The graph is built from two separate ``Variable`` objects that both
        carry the name "x"; the call below must complete without raising.
        """
        graph_inputs = [Variable(name="x")]
        # A second, distinct Variable object that reuses the same name.
        graph_outputs = [Variable(name="x")]
        add_node = Node(op="Add", name="Test", inputs=graph_inputs, outputs=graph_outputs)
        graph = Graph(nodes=[add_node], inputs=graph_inputs, outputs=graph_outputs)

        # No exception is expected here; the test passes if this call returns.
        graph.tensors(check_duplicates=False)
Пример #3
0
    def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto:
        """
        Convert an onnx-graphsurgeon Graph into an ONNX GraphProto.

        Args:
            graph (Graph): The graph to convert.

            do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
                This flag is forwarded to the node and value-info exporters.

        Returns:
            onnx.GraphProto: The graph assembled by ``onnx.helper.make_graph`` from the exported
            nodes, inputs, outputs, initializers and value_info entries.
        """
        exported_nodes = [OnnxExporter.export_node(n, do_type_check) for n in graph.nodes]
        exported_inputs = [
            OnnxExporter.export_value_info_proto(t, do_type_check) for t in graph.inputs
        ]
        exported_outputs = [
            OnnxExporter.export_value_info_proto(t, do_type_check) for t in graph.outputs
        ]

        remaining = graph.tensors()

        # Every Constant in the tensor map is exported as an initializer. This is
        # collected before graph inputs/outputs are removed from the map below.
        initializers = [
            OnnxExporter.export_tensor_proto(t)
            for t in remaining.values()
            if isinstance(t, Constant)
        ]

        # Graph inputs and outputs were already exported above, so drop them
        # from the map before building value_info.
        for io_tensor in graph.inputs + graph.outputs:
            if io_tensor.name in remaining:
                del remaining[io_tensor.name]

        # Only Variables with a known dtype or a known shape get a value_info entry.
        value_infos = []
        for t in remaining.values():
            if not isinstance(t, Variable):
                continue
            if t.dtype is None and t.shape is None:
                continue
            value_infos.append(OnnxExporter.export_value_info_proto(t, do_type_check))

        return onnx.helper.make_graph(
            nodes=exported_nodes,
            name=graph.name,
            inputs=exported_inputs,
            outputs=exported_outputs,
            initializer=initializers,
            doc_string=graph.doc_string,
            value_info=value_infos,
        )
Пример #4
0
 def test_tensors_includes_non_node_tensors(self):
     """Check that ``Graph.tensors()`` reports a tensor that no node references.

     The graph has no nodes and no inputs; its only tensor is a Constant
     listed as a graph output.
     """
     ones = np.ones(shape=(64, 64), dtype=np.float32)
     const = Constant("X", values=ones)
     graph = Graph(inputs=[], outputs=[const])

     name_to_tensor = graph.tensors()

     assert "X" in name_to_tensor
     assert name_to_tensor["X"] == const