Пример #1
0
    def test_shape_gather(self, shape, indices):
        """Check constant folding of Gather nodes fed by a Shape node.

        Three outputs are built on top of ``Shape(input)``: the gathered
        shape added to itself, a gather with the list index ``[0]`` and a
        gather with a 0-d array index. After folding, the first output must
        be a Constant holding twice the gathered dimensions whenever
        ``shape`` is not None, and a Variable otherwise. The other two
        outputs are asserted to stay Variables.
        """
        gather_indices = np.array(indices)

        tensor = Variable("input", dtype=np.float32, shape=shape)
        graph = Graph(inputs=[tensor])

        shape_of_input = graph.shape(tensor)
        gathered = graph.gather(shape_of_input, indices=gather_indices)

        # Create the output tensors one by one, in the order they are exposed.
        doubled = graph.add(gathered, gathered)
        list_indexed = graph.gather(shape_of_input, indices=[0])
        scalar_indexed = graph.gather(shape_of_input, indices=np.array(0))
        graph.outputs = [doubled, list_indexed, scalar_indexed]

        graph.fold_constants()

        folded = graph.outputs[0]
        if shape is None:
            # Nothing is known about the input shape, so nothing can fold.
            assert isinstance(folded, Variable)
        else:
            assert isinstance(folded, Constant)
            doubled_dims = np.array(shape)[gather_indices].astype(np.int64) * 2
            assert np.all(folded.values == doubled_dims)

        for position in (1, 2):
            assert isinstance(graph.outputs[position], Variable)
Пример #2
0
    def test_fold_constants_one_hop(self):
        """Check that a chain of constant Adds collapses into one tensor.

        Graph under test:
            c = (a + b)
            e = (c + d)
            output = input + e
        where a, b and d are constants filled with ones. After folding and
        cleanup only ``output = input + e`` may remain, with ``e`` carrying
        the precomputed value.
        """
        dims = (1, 3)

        graph_input = Variable("input", shape=dims, dtype=np.float32)
        const_a = Constant("a", values=np.ones(shape=dims, dtype=np.float32))
        const_b = Constant("b", values=np.ones(shape=dims, dtype=np.float32))
        sum_ab = Variable("c", shape=dims, dtype=np.float32)
        const_d = Constant("d", values=np.ones(shape=dims, dtype=np.float32))
        sum_abd = Variable("e", shape=dims, dtype=np.float32)
        graph_output = Variable("output", shape=dims, dtype=np.float32)

        graph = Graph(
            nodes=[
                Node("Add", inputs=[const_a, const_b], outputs=[sum_ab]),
                Node("Add", inputs=[sum_ab, const_d], outputs=[sum_abd]),
                Node("Add", inputs=[graph_input, sum_abd], outputs=[graph_output]),
            ],
            inputs=[graph_input],
            outputs=[graph_output],
        )

        graph.fold_constants().cleanup()

        # Only the final Add, which depends on the runtime input, survives.
        assert len(graph.nodes) == 1
        surviving_add = graph.nodes[0]
        assert surviving_add.inputs[0] == graph_input
        assert surviving_add.inputs[1] == sum_abd
        # 1 + 1 + 1 must have been evaluated ahead of time.
        three = np.ones(shape=(1, 3), dtype=np.float32) * 3
        assert np.all(surviving_add.inputs[1].values == three)
Пример #3
0
    def test_shape_of_variable_tensor_dynamic_shape(self):
        """Check that Shape is left alone when the input shape is not static.

        The input shape mixes an empty string, -1 and 0 with a concrete
        dimension; the Shape node must survive folding and its output must
        stay a Variable.
        """
        dynamic_input = Variable("var", dtype=np.float32, shape=("", -1, 0, 4))
        graph = Graph(inputs=[dynamic_input])
        shape_output = graph.shape(dynamic_input)
        graph.outputs = [shape_output]

        graph.fold_constants().cleanup()

        assert len(graph.nodes) == 1
        remaining_node = graph.nodes[0]
        assert remaining_node.op == "Shape"
        assert isinstance(graph.outputs[0], Variable)
Пример #4
0
    def test_shape_of_variable_tensor_static_shape_no_fold(self):
        """Check that ``fold_shapes=False`` keeps a Shape node in the graph.

        The input has a fully static shape, but shape folding is switched
        off, so the Shape node and its Variable output must be untouched.
        """
        static_input = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        graph = Graph()
        graph.inputs = [static_input]
        shape_output = graph.shape(static_input)
        graph.outputs = [shape_output]

        graph.fold_constants(fold_shapes=False).cleanup()

        assert len(graph.nodes) == 1
        remaining_node = graph.nodes[0]
        assert remaining_node.op == "Shape"
        assert isinstance(graph.outputs[0], Variable)
Пример #5
0
    def test_shape_of_variable_tensor_static_shape(self):
        """Check that Shape of a statically-shaped input folds to a Constant.

        After folding and cleanup the graph must contain no nodes and its
        only output must be a Constant holding the input's dimensions.
        """
        var = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        # The inputs are already supplied to the constructor; no separate
        # assignment to graph.inputs is needed.
        graph = Graph(inputs=[var])
        graph.outputs = [graph.shape(var)]

        graph.fold_constants().cleanup()

        assert not graph.nodes
        assert isinstance(graph.outputs[0], Constant)
        assert np.all(graph.outputs[0].values == (1, 3, 4))
Пример #6
0
    def test_shape_of_constant_node(self):
        """Check that Shape applied to a constant's output is folded away.

        A constant of ones with dimensions (1, 3, 3) feeds a Shape node;
        after folding and cleanup no nodes may remain and the output must be
        a Constant equal to those dimensions.
        """
        graph = Graph()
        ones = np.ones((1, 3, 3), dtype=np.int64)
        const_tensor = graph.constant(values=ones)
        shape_output = graph.shape(const_tensor)
        graph.outputs = [shape_output]

        graph.fold_constants().cleanup()

        assert not graph.nodes
        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        assert np.all(folded.values == (1, 3, 3))
Пример #7
0
    def test_const_node(self):
        """Check that the output of a constant-producing node gets folded.

        Before folding, the output of ``graph.constant`` is a Variable;
        afterwards it must be a Constant carrying the original values and
        the graph must contain no nodes.
        """
        graph = Graph()
        payload = np.ones((1, 3, 3), dtype=np.int64)
        produced = graph.constant(values=payload)
        graph.outputs = [produced]

        # Until folding runs, the tensor is only the output of a node.
        assert isinstance(graph.outputs[0], Variable)

        graph.fold_constants().cleanup()

        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        assert np.all(folded.values == payload)
        assert not graph.nodes
Пример #8
0
    def test_no_foldable_constants(self):
        """Check that folding leaves a graph without constants unchanged.

        The single Add node consumes two graph inputs, so after folding and
        cleanup it must still be present with the same inputs.
        """
        dims = (1, 3)
        lhs = Variable("input0", shape=dims, dtype=np.float32)
        rhs = Variable("input1", shape=dims, dtype=np.float32)
        total = Variable("output", shape=dims, dtype=np.float32)

        add_node = Node("Add", inputs=[lhs, rhs], outputs=[total])
        graph = Graph(nodes=[add_node], inputs=[lhs, rhs], outputs=[total])

        graph.fold_constants().cleanup()

        assert len(graph.nodes) == 1
        assert graph.nodes[0].inputs == [lhs, rhs]
Пример #9
0
    def test_shape_of_variable_tensor_multiple_shapes(self):
        """Check that several Shape nodes fold while an Identity is kept.

        The graph exposes Shape of a (1, 3, 4) input, Identity of that same
        input and Shape of a scalar input. After folding and cleanup only
        the Identity node may remain, and both Shape outputs must be
        Constants holding the respective dimensions.
        """
        graph = Graph()
        var = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        var2 = Variable("var2", dtype=np.float32, shape=tuple())  # Scalar
        graph.inputs = [var, var2]
        graph.outputs = [
            graph.shape(var),
            graph.identity(var),
            graph.shape(var2),
        ]

        graph.fold_constants().cleanup()

        assert len(graph.nodes) == 1
        assert graph.nodes[0].op == "Identity"
        assert isinstance(graph.outputs[0], Constant)
        assert np.all(graph.outputs[0].values == (1, 3, 4))
        assert isinstance(graph.outputs[2], Constant)
        # np.all() over an empty comparison is vacuously True, so comparing
        # against an empty tuple proves nothing. Pin the result explicitly:
        # the shape of a scalar is an empty 1-D tensor.
        assert graph.outputs[2].values.shape == (0,)
Пример #10
0
    def test_shape_slice_single_input(self):
        """Check folding of a Slice whose parameters are given as attributes.

        The Slice node receives only the Shape output as input; axes, starts,
        ends and steps are set through ``attrs``. The folded output must
        equal the matching Python slice of the input's shape.
        """
        tensor = Variable("input", dtype=np.int64, shape=(5, 6, 3, 2))
        graph = Graph(inputs=[tensor])

        shape_of_input = graph.shape(tensor)
        sliced = graph.slice(shape_of_input)
        graph.outputs = [sliced]

        # The producer of the output tensor is the Slice node itself.
        producer = graph.outputs[0].inputs[0]
        producer.attrs = {
            "axes": [0],
            "starts": [1],
            "ends": [3],
            "steps": [2],
        }

        graph.fold_constants()

        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        assert np.all(folded.values == tensor.shape[1:3:2])
Пример #11
0
    def test_shape_slice(self, shape, starts, ends, axes, steps, expected):
        """Check folding of a Slice applied to the Shape of the input.

        ``starts``, ``ends``, ``axes`` and ``steps`` are converted to arrays
        and handed to the Slice. ``expected`` holds the values the folded
        output should have.

        NOTE(review): ``None`` is presumed to be the marker for "no folding
        expected"; the parametrization is not visible here -- confirm against
        the decorator.
        """
        inp = Variable("input", dtype=np.float32, shape=shape)
        graph = Graph(inputs=[inp])

        inp_shape = graph.shape(inp)
        graph.outputs = [
            graph.slice(
                inp_shape,
                np.array(starts),
                np.array(ends),
                axes=np.array(axes),
                steps=np.array(steps),
            )
        ]

        graph.fold_constants()

        # Compare against None explicitly: a truthiness test would treat an
        # empty or all-zero expectation as "not folded" and is ambiguous for
        # array-valued expectations.
        if expected is not None:
            assert isinstance(graph.outputs[0], Constant)
            assert np.all(graph.outputs[0].values == expected)
        else:
            assert isinstance(graph.outputs[0], Variable)
Пример #12
0
    def test_const_inp_but_non_foldable_nested_graph(self):
        """Check that an If node with a constant condition is not folded
        when its subgraphs use tensors from the outer graph.

        Folding runs with ``error_ok=False``, so any attempt to evaluate the
        If node would surface as an exception instead of being ignored.
        """
        condition = gs.Constant("cond", values=np.array(True))
        outer_input = gs.Variable("X", dtype=np.float32, shape=(1, ))

        graph = Graph(inputs=[outer_input])

        then_graph = Graph(name="Then")
        then_sum = then_graph.add(outer_input, outer_input)
        then_graph.outputs = [then_sum]

        else_graph = Graph(name="Else")
        else_inner = else_graph.add(outer_input, outer_input)
        else_sum = else_graph.add(outer_input, else_inner)
        else_graph.outputs = [else_sum]

        # All direct inputs of the If node are constant, which makes it look
        # foldable. It is not: both branches read a variable that belongs to
        # the enclosing graph.
        if_output = graph.if_op(condition, then_graph, else_graph)
        graph.outputs = [if_output]

        # Must not raise: the If node has to be left out of constant folding.
        graph.fold_constants(error_ok=False).cleanup()

        assert graph.nodes[0].op == "If"
        assert len(then_graph.nodes) == 1
        assert len(else_graph.nodes) == 2
Пример #13
0
    def test_with_nested_graph(self):
        """Check that constants inside the subgraphs of an If node are folded.

        The "Then" branch computes ``Y + Y`` from a constant only, so it must
        fold to a single Constant with no nodes left. The "Else" branch
        computes ``X + (Y + Y)``; only the inner sum is constant, so one Add
        node must remain with a Constant as its second input.
        """
        # np.bool_ is used instead of the np.bool alias, which was deprecated
        # in NumPy 1.20 and removed in 1.24.
        cond = gs.Variable("cond", dtype=np.bool_, shape=(1, ))

        X = gs.Variable("X", dtype=np.float32, shape=(1, ))
        Y = gs.Constant("Y", values=np.ones((1, ), dtype=np.float32))
        graph = Graph(inputs=[X, cond])

        then_graph = Graph(name="Then")
        then_graph.outputs = [then_graph.add(Y, Y)]

        else_graph = Graph(name="Else")
        else_graph.outputs = [else_graph.add(X, else_graph.add(Y, Y))]

        graph.outputs = [graph.if_op(cond, then_graph, else_graph)]

        graph.fold_constants()
        graph.cleanup()

        # The whole "Then" branch is constant and collapses to its output.
        assert len(then_graph.nodes) == 0
        assert np.all(then_graph.outputs[0].values == (Y.values * 2))

        # In "Else" only the inner Y + Y can be precomputed.
        assert len(else_graph.nodes) == 1
        assert isinstance(else_graph.nodes[0].inputs[1], Constant)
        assert np.all(else_graph.nodes[0].inputs[1].values == (Y.values * 2))