コード例 #1
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_copy_with_subgraph_dup_const_tensors(self):
        """Graph.copy() must not confuse same-named inner/outer Constants.

        The outer graph passes a Constant named "input" to the node built by
        ``graph.nested``, whose "body" subgraph holds a different Constant that
        is also named "input". After copying, each tensor must still be its
        own tensor; the two differ only in shape, which is what we check.
        """
        inp = Constant("input", values=np.ones(dtype=np.float32, shape=(4, 5)))
        graph = Graph()

        # We'll use shape to distinguish inner/outer tensor:
        # outer "input" is (4, 5), inner "input" is (1, 2).
        subgraph_inp = Constant("input",
                                values=np.ones(dtype=np.float32, shape=(1, 2)))
        subgraph = Graph()
        subgraph.outputs = [subgraph.identity(subgraph_inp)]

        graph.outputs = [graph.nested(inp, subgraph)]

        graph_copy = graph.copy()
        nested_node = graph_copy.nodes[0]
        # The outer tensor must not have been replaced by the inner one ...
        assert nested_node.inputs[0].shape == (4, 5)
        # ... and the inner tensor must not have been replaced by the outer one.
        assert nested_node.attrs["body"].nodes[0].inputs[0].shape == (1, 2)
コード例 #2
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_shape_of_variable_tensor_multiple_shapes(self):
        """Fold Shape ops over inputs of different ranks, including a scalar.

        After ``fold_constants().cleanup()`` both Shape outputs are expected to
        be Constants: ``(1, 3, 4)`` for the rank-3 input and an empty vector
        for the scalar input. The Identity node on the Variable is expected to
        remain as the only node.
        """
        graph = Graph()
        var = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        var2 = Variable("var2", dtype=np.float32, shape=tuple())  # Scalar
        graph.inputs = [var, var2]
        graph.outputs = [
            graph.shape(var),
            graph.identity(var),
            graph.shape(var2)
        ]

        graph.fold_constants().cleanup()

        assert len(graph.nodes) == 1
        assert graph.nodes[0].op == "Identity"
        assert isinstance(graph.outputs[0], Constant)
        # np.array_equal also compares shapes, whereas np.all(a == b) broadcasts
        # and can pass for arrays of the wrong length.
        assert np.array_equal(graph.outputs[0].values, (1, 3, 4))
        assert isinstance(graph.outputs[2], Constant)
        # The shape of a scalar is an empty vector. np.all(values == ()) is
        # vacuously True for any length-0 or length-1 array (broadcasting
        # yields an empty result), so compare strictly instead.
        assert np.array_equal(graph.outputs[2].values, tuple())
コード例 #3
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_node_used_only_in_nested_graph(self):
        """cleanup() keeps an outer node whose output feeds only a subgraph.

        The Identity node's output is not an outer-graph output. It is
        referenced solely by the ``add`` inside the subgraph handed to
        ``graph.nested``, so cleanup must still treat the node as used.
        """
        x_tensor = Variable("X", dtype=np.float32, shape=(1, ))
        y_tensor = Variable("Y", dtype=np.float32, shape=(1, ))
        graph = Graph(inputs=[x_tensor, y_tensor])

        # Only the subgraph below consumes this tensor; the outer graph never does.
        x_passthrough = graph.identity(x_tensor)

        inner_inp = Variable("subgraph_input", dtype=np.float32, shape=(1, ))
        body = Graph(inputs=[inner_inp])
        body.outputs = [body.add(inner_inp, x_passthrough)]

        graph.outputs = [graph.nested(y_tensor, body)]

        graph.cleanup(remove_unused_graph_inputs=True)

        first_node = graph.nodes[0]
        assert first_node.op == "Identity"
        assert first_node.inputs == [x_tensor]