Exemple #1
0
    def test_shape_gather(self, shape, indices):
        """Check constant folding of a Shape -> Gather -> Add chain.

        The first graph output (the gathered shape added to itself) must fold
        into a Constant holding twice the selected dimensions when ``shape``
        is not None, and must remain a Variable otherwise. The two outputs
        that gather index 0 directly from the shape tensor are asserted to
        still be Variables after folding.
        """
        gather_idx = np.array(indices)

        data = Variable("input", dtype=np.float32, shape=shape)
        graph = Graph(inputs=[data])

        full_shape = graph.shape(data)
        gathered = graph.gather(full_shape, indices=gather_idx)
        doubled = graph.add(gathered, gathered)
        gather_list_idx = graph.gather(full_shape, indices=[0])
        gather_scalar_idx = graph.gather(full_shape, indices=np.array(0))
        graph.outputs = [doubled, gather_list_idx, gather_scalar_idx]

        graph.fold_constants()

        first_out = graph.outputs[0]
        if shape is None:
            # Nothing is known about the input shape, so nothing can fold.
            assert isinstance(first_out, Variable)
        else:
            assert isinstance(first_out, Constant)
            expected_shape = np.array(shape)[gather_idx].astype(np.int64) * 2
            assert np.all(first_out.values == expected_shape)

        for position in (1, 2):
            assert isinstance(graph.outputs[position], Variable)
Exemple #2
0
    def test_shape_of_variable_tensor_multiple_shapes(self):
        """Check that several Shape nodes over static inputs fold in one pass.

        Two inputs are used: a rank-3 tensor and a scalar (empty shape). After
        folding and cleanup only the Identity node may remain, and both Shape
        outputs must have become Constants holding the respective shapes.
        """
        graph = Graph()
        var = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        var2 = Variable("var2", dtype=np.float32, shape=tuple())  # Scalar
        graph.inputs = [var, var2]
        graph.outputs = [
            graph.shape(var),
            graph.identity(var),
            graph.shape(var2),
        ]

        graph.fold_constants().cleanup()

        assert len(graph.nodes) == 1
        assert graph.nodes[0].op == "Identity"
        assert isinstance(graph.outputs[0], Constant)
        assert np.array_equal(graph.outputs[0].values, (1, 3, 4))
        assert isinstance(graph.outputs[2], Constant)
        # An elementwise `==` against an empty tuple produces an empty array,
        # and np.all() of an empty array is True regardless of the folded
        # value. np.array_equal also compares shapes, so it really requires
        # the folded scalar shape to be empty.
        assert np.array_equal(graph.outputs[2].values, tuple())
Exemple #3
0
    def test_shape_of_variable_tensor_dynamic_shape(self):
        """Check that Shape is left alone when the input shape is not static.

        The input is declared with shape ``("", -1, 0, 4)``; after folding and
        cleanup the single Shape node must still be present and its output
        must still be a Variable.
        """
        dynamic_input = Variable("var", dtype=np.float32, shape=("", -1, 0, 4))
        graph = Graph(inputs=[dynamic_input])
        shape_out = graph.shape(dynamic_input)
        graph.outputs = [shape_out]

        graph.fold_constants().cleanup()

        remaining = graph.nodes
        assert len(remaining) == 1
        assert remaining[0].op == "Shape"
        assert isinstance(graph.outputs[0], Variable)
Exemple #4
0
    def test_shape_of_variable_tensor_static_shape_no_fold(self):
        """Check that ``fold_shapes=False`` keeps a Shape node over a static input.

        Even though the input shape ``(1, 3, 4)`` is fully known, the Shape
        node must survive and its output must remain a Variable.
        """
        graph = Graph()
        static_input = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        graph.inputs = [static_input]
        shape_out = graph.shape(static_input)
        graph.outputs = [shape_out]

        graph.fold_constants(fold_shapes=False).cleanup()

        remaining = graph.nodes
        assert len(remaining) == 1
        assert remaining[0].op == "Shape"
        assert isinstance(graph.outputs[0], Variable)
Exemple #5
0
    def test_shape_of_variable_tensor_static_shape(self):
        """Check that Shape over an input with a static shape folds away.

        After folding and cleanup no nodes may remain, and the graph output
        must be a Constant holding the input's shape ``(1, 3, 4)``.
        """
        var = Variable("var", dtype=np.float32, shape=(1, 3, 4))
        # The inputs are supplied through the constructor; assigning
        # graph.inputs a second time with the same list is unnecessary.
        graph = Graph(inputs=[var])
        graph.outputs = [graph.shape(var)]

        graph.fold_constants().cleanup()

        assert not graph.nodes
        assert isinstance(graph.outputs[0], Constant)
        assert np.all(graph.outputs[0].values == (1, 3, 4))
Exemple #6
0
    def test_shape_of_constant_node(self):
        """Check that Shape applied to a constant folds into a Constant.

        The constant holds a ``(1, 3, 3)`` array of ones; after folding and
        cleanup the graph must contain no nodes and its output must be a
        Constant equal to that shape.
        """
        graph = Graph()
        ones = np.ones((1, 3, 3), dtype=np.int64)
        const_tensor = graph.constant(values=ones)
        shape_out = graph.shape(const_tensor)
        graph.outputs = [shape_out]

        graph.fold_constants().cleanup()

        assert not graph.nodes
        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        assert np.all(folded.values == (1, 3, 3))
Exemple #7
0
    def test_shape_slice_single_input(self):
        """Check folding of Shape -> Slice when Slice is configured by attributes.

        The Slice node receives only the shape tensor as input; axes, starts,
        ends and steps are set as node attributes. The folded output must be a
        Constant equal to ``shape[1:3:2]`` of the input.
        """
        data = Variable("input", dtype=np.int64, shape=(5, 6, 3, 2))
        graph = Graph(inputs=[data])

        sliced = graph.slice(graph.shape(data))
        graph.outputs = [sliced]

        # The node producing the graph output is the Slice node itself.
        producer = graph.outputs[0].inputs[0]
        producer.attrs = {
            "axes": [0],
            "starts": [1],
            "ends": [3],
            "steps": [2],
        }

        graph.fold_constants()

        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        assert np.all(folded.values == data.shape[1:3:2])
Exemple #8
0
    def test_shape_slice(self, shape, starts, ends, axes, steps, expected):
        """Check folding of Shape -> Slice with slice parameters given as inputs.

        ``starts`` and ``ends`` are passed positionally and ``axes`` and
        ``steps`` by keyword, each wrapped in a NumPy array. When ``expected``
        is truthy the output must fold into a Constant equal to it; otherwise
        the output must remain a Variable.
        """
        data = Variable("input", dtype=np.float32, shape=shape)
        graph = Graph(inputs=[data])

        sliced = graph.slice(
            graph.shape(data),
            np.array(starts),
            np.array(ends),
            axes=np.array(axes),
            steps=np.array(steps),
        )
        graph.outputs = [sliced]

        graph.fold_constants()

        result = graph.outputs[0]
        if not expected:
            # No expected value: the slice must not have been folded.
            assert isinstance(result, Variable)
            return

        assert isinstance(result, Constant)
        assert np.all(result.values == expected)