コード例 #1
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_copy_with_subgraph_dup_tensors(self):
        """Copying a graph must keep same-named outer/subgraph Variables distinct.

        The outer graph input and the subgraph input are both named "input";
        the test tells them apart only by shape: (4, 5) for the outer tensor,
        (1, 2) for the inner one. Both directions of a mix-up are checked.
        """
        inp = Variable("input", dtype=np.float32, shape=(4, 5))
        graph = Graph(inputs=[inp])

        # We'll use shape to distinguish inner/outer tensor
        subgraph_inp = Variable("input", dtype=np.float32, shape=(1, 2))
        subgraph = Graph(inputs=[subgraph_inp])

        graph.outputs = [graph.nested(inp, subgraph)]

        graph_copy = graph.copy()
        # The copied subgraph must not pick up the outer tensor of the same name...
        assert graph_copy.nodes[0].attrs["body"].inputs[0].shape == (1, 2)
        # ...and the copied outer graph must not pick up the inner one.
        assert graph_copy.inputs[0].shape == (4, 5)
コード例 #2
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_copy_with_subgraph_dup_const_tensors(self):
        """Copying a graph must keep same-named outer/subgraph Constants distinct.

        Both constants are named "input"; only their value shapes differ:
        (4, 5) in the outer graph and (1, 2) inside the subgraph.
        """
        outer_values = np.ones(dtype=np.float32, shape=(4, 5))
        outer_const = Constant("input", values=outer_values)
        graph = Graph()

        # Shape is the only thing that distinguishes the inner tensor from the outer one.
        inner_values = np.ones(dtype=np.float32, shape=(1, 2))
        inner_const = Constant("input", values=inner_values)
        body = Graph()
        body.outputs = [body.identity(inner_const)]

        graph.outputs = [graph.nested(outer_const, body)]

        copied = graph.copy()
        # The Identity node inside the copied body must still consume the inner constant.
        copied_inner = copied.nodes[0].attrs["body"].nodes[0].inputs[0]
        assert copied_inner.shape == (1, 2)
コード例 #3
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_node_used_only_in_nested_graph(self):
        """cleanup() must keep an outer node whose output is consumed only by a subgraph.

        The Identity node's output feeds nothing in the outer graph, only an
        Add inside the nested subgraph, so it has to survive cleanup together
        with its input "X".
        """
        x_in = Variable("X", dtype=np.float32, shape=(1, ))
        y_in = Variable("Y", dtype=np.float32, shape=(1, ))
        graph = Graph(inputs=[x_in, y_in])

        # Referenced only from inside the subgraph, never by the outer graph.
        x_ident = graph.identity(x_in)

        body_inp = Variable("subgraph_input", dtype=np.float32, shape=(1, ))
        body = Graph(inputs=[body_inp])
        body.outputs = [body.add(body_inp, x_ident)]

        graph.outputs = [graph.nested(y_in, body)]

        graph.cleanup(remove_unused_graph_inputs=True)

        survivor = graph.nodes[0]
        assert survivor.op == "Identity"
        assert survivor.inputs == [x_in]