예제 #1
0
    def test_shape_of_constant_node(self):
        """Check that the shape of a constant folds down to a plain constant.

        The graph's only output is ``graph.shape(...)`` applied to a constant
        built from a ``(1, 3, 3)`` int64 array. After
        ``fold_constants().cleanup()`` the graph must contain no nodes and the
        output must be a ``Constant`` whose values equal the input's shape.
        """
        graph = Graph()
        values = np.ones((1, 3, 3), dtype=np.int64)
        const = graph.constant(values=values)
        graph.outputs = [graph.shape(const)]

        graph.fold_constants().cleanup()

        assert not graph.nodes
        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        # np.array_equal also compares shapes; `np.all(a == b)` broadcasts and
        # would accept e.g. a (1, 3) or (2, 3) result whose rows are [1, 3, 3].
        assert np.array_equal(folded.values, values.shape)
예제 #2
0
    def test_const_node(self):
        """Check that a lone constant-producing output is folded into a Constant.

        Before folding, the output returned by ``graph.constant(...)`` is
        asserted to be a ``Variable``. After ``fold_constants().cleanup()`` the
        output must be a ``Constant`` carrying the original values, and the
        graph must contain no nodes.
        """
        graph = Graph()
        values = np.ones((1, 3, 3), dtype=np.int64)
        graph.outputs = [graph.constant(values=values)]

        assert isinstance(graph.outputs[0], Variable)

        graph.fold_constants().cleanup()

        folded = graph.outputs[0]
        assert isinstance(folded, Constant)
        # np.array_equal also compares shapes; `np.all(a == b)` broadcasts, so
        # an all-ones result of shape (3,) or () would wrongly pass here.
        assert np.array_equal(folded.values, values)
        assert not graph.nodes