Beispiel #1
0
    def test_input_is_output(self):
        """Cleanup must leave the graph untouched when graph inputs are also outputs."""
        graph = Graph()

        lhs, rhs = (
            Variable(name, dtype=np.float32, shape=(1, 1)) for name in ("A", "B")
        )
        total = graph.add(lhs, rhs)

        graph.inputs = [lhs, rhs]
        # The outputs list the two input tensors in the opposite order to the
        # one in which they feed the Add node.
        graph.outputs = [total, rhs, lhs]

        graph.cleanup()

        # Neither the graph-level I/O tensors nor the single node may change.
        assert graph.inputs == [lhs, rhs]
        assert graph.outputs == [total, rhs, lhs]
        assert len(graph.nodes) == 1

        only_node = graph.nodes[0]
        assert only_node.inputs == [lhs, rhs]
        assert only_node.outputs == [total]
Beispiel #2
0
    def test_node_used_only_in_nested_graph(self):
        """After cleanup, the first node is still the Identity feeding only the subgraph."""
        tensor_spec = dict(dtype=np.float32, shape=(1, ))

        outer_x = Variable("X", **tensor_spec)
        outer_y = Variable("Y", **tensor_spec)
        graph = Graph(inputs=[outer_x, outer_y])

        # This tensor has no consumer in the outer graph; only the subgraph
        # below reads it.
        x_copy = graph.identity(outer_x)

        inner_input = Variable("subgraph_input", **tensor_spec)
        subgraph = Graph(inputs=[inner_input])
        inner_sum = subgraph.add(inner_input, x_copy)
        subgraph.outputs = [inner_sum]

        nested_out = graph.nested(outer_y, subgraph)
        graph.outputs = [nested_out]

        graph.cleanup(remove_unused_graph_inputs=True)

        first_node = graph.nodes[0]
        assert first_node.op == "Identity"
        assert first_node.inputs == [outer_x]
Beispiel #3
0
    def test_with_nested_graph(self):
        """Check constant folding inside the branch graphs passed to ``graph.if_op``.

        Both branches contain ``add(Y, Y)`` on the constant ``Y``. After
        ``fold_constants()`` and ``cleanup()`` on the outer graph:

        * the "Then" graph has no nodes left and its output carries ``Y * 2``;
        * the "Else" graph keeps one node (the add involving the variable
          ``X``) whose second input is a ``Constant`` holding ``Y * 2``.
        """
        # `np.bool` was a deprecated alias removed in NumPy 1.24; `np.bool_`
        # is available in every NumPy release.
        cond = gs.Variable("cond", dtype=np.bool_, shape=(1, ))

        X = gs.Variable("X", dtype=np.float32, shape=(1, ))
        Y = gs.Constant("Y", values=np.ones((1, ), dtype=np.float32))
        graph = Graph(inputs=[X, cond])

        # Branch made only of constants.
        then_graph = Graph(name="Then")
        then_graph.outputs = [then_graph.add(Y, Y)]

        # Branch mixing the outer-graph variable X with a constant sub-expression.
        else_graph = Graph(name="Else")
        else_graph.outputs = [else_graph.add(X, else_graph.add(Y, Y))]

        graph.outputs = [graph.if_op(cond, then_graph, else_graph)]

        graph.fold_constants()
        graph.cleanup()

        assert len(then_graph.nodes) == 0
        assert np.all(then_graph.outputs[0].values == (Y.values * 2))

        assert len(else_graph.nodes) == 1
        assert isinstance(else_graph.nodes[0].inputs[1], Constant)
        assert np.all(else_graph.nodes[0].inputs[1].values == (Y.values * 2))