コード例 #1
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_const_inp_but_non_foldable_nested_graph(self):
        """An `If` whose branches capture outer-scope variables must not be folded.

        The `If` node's only direct input is the constant condition, so it
        looks foldable. However, both branch subgraphs read the graph input
        ``X`` from the enclosing scope. Folding must therefore leave the node,
        and the bodies of both branches, untouched.
        """
        condition = gs.Constant("cond", values=np.array(True))
        x_in = gs.Variable("X", dtype=np.float32, shape=(1, ))

        outer = Graph(inputs=[x_in])

        # Then-branch: X + X, i.e. a single Add node.
        then_branch = Graph(name="Then")
        then_branch.outputs = [then_branch.add(x_in, x_in)]

        # Else-branch: X + (X + X), i.e. two Add nodes.
        else_branch = Graph(name="Else")
        doubled = else_branch.add(x_in, x_in)
        else_branch.outputs = [else_branch.add(x_in, doubled)]

        # The If node's direct inputs are all constant, which makes it look
        # foldable. Its branches, however, reference X from the outer scope.
        outer.outputs = [outer.if_op(condition, then_branch, else_branch)]

        # error_ok=False: folding must not raise here, because the If node
        # is expected to be skipped by the constant-folding pass.
        folded = outer.fold_constants(error_ok=False)
        folded.cleanup()

        # Nothing was folded away: the If node and both branch bodies remain.
        assert outer.nodes[0].op == "If"
        assert len(then_branch.nodes) == 1
        assert len(else_branch.nodes) == 2
コード例 #2
0
ファイル: test_graph.py プロジェクト: phongphuhanam/TensorRT
    def test_with_nested_graph(self):
        """Constant folding must reach into the subgraphs of an `If` node.

        ``Y + Y`` depends only on constants, so it should be folded in both
        branches. In the else-branch, the outer ``X + (Y + Y)`` must stay as a
        node, because ``X`` is a graph input rather than a constant.
        """
        # `np.bool` was only an alias of the builtin `bool`. It was deprecated
        # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        # Using `bool` directly keeps the exact same dtype on every version.
        cond = gs.Variable("cond", dtype=bool, shape=(1, ))

        X = gs.Variable("X", dtype=np.float32, shape=(1, ))
        Y = gs.Constant("Y", values=np.ones((1, ), dtype=np.float32))
        graph = Graph(inputs=[X, cond])

        # Then-branch: Y + Y, which is fully constant and so foldable.
        then_graph = Graph(name="Then")
        then_graph.outputs = [then_graph.add(Y, Y)]

        # Else-branch: X + (Y + Y). Only the inner add is foldable.
        else_graph = Graph(name="Else")
        else_graph.outputs = [else_graph.add(X, else_graph.add(Y, Y))]

        graph.outputs = [graph.if_op(cond, then_graph, else_graph)]

        graph.fold_constants()
        graph.cleanup()

        # The then-branch is reduced to a constant output equal to 2 * Y.
        assert len(then_graph.nodes) == 0
        assert np.all(then_graph.outputs[0].values == (Y.values * 2))

        # The else-branch keeps only the outer add. Its second input has been
        # replaced by the folded constant 2 * Y.
        assert len(else_graph.nodes) == 1
        assert isinstance(else_graph.nodes[0].inputs[1], Constant)
        assert np.all(else_graph.nodes[0].inputs[1].values == (Y.values * 2))