def test_shape_slice_single_input(self):
    """Fold Shape -> Slice when the slice parameters are given as node attributes.

    The Slice node gets only its data input (the Shape output). ``axes``/``starts``/
    ``ends``/``steps`` are set as attributes afterwards. The input shape is fully
    static, so ``fold_constants`` should replace the graph output with a Constant
    equal to the Python slice ``shape[1:3:2]``.
    """
    inp = Variable("input", dtype=np.int64, shape=(5, 6, 3, 2))
    graph = Graph(inputs=[inp])

    inp_shape = graph.shape(inp)
    graph.outputs = [graph.slice(inp_shape)]

    # Supply the slice parameters as attributes rather than as extra inputs.
    slice_node = graph.outputs[0].inputs[0]
    slice_node.attrs = {
        "axes": [0],
        "starts": [1],
        "ends": [3],
        "steps": [2],
    }

    graph.fold_constants()

    assert isinstance(graph.outputs[0], Constant)
    # array_equal also checks the shape. A plain `==` + np.all would broadcast,
    # so a wrong-length result could still pass.
    assert np.array_equal(graph.outputs[0].values, inp.shape[1:3:2])
def test_shape_slice(self, shape, starts, ends, axes, steps, expected):
    """Fold Shape -> Slice where the slice parameters are graph inputs.

    Args:
        shape: Shape of the graph input; may contain non-static dims.
        starts, ends, axes, steps: Slice parameters, passed as constant arrays.
        expected: Expected folded shape values, or ``None`` when the slice
            cannot be folded and the output should remain a Variable.
            NOTE(review): assumes the parametrization uses ``None`` (not another
            falsy value) for the non-foldable cases; confirm against the decorator.
    """
    inp = Variable("input", dtype=np.float32, shape=shape)
    graph = Graph(inputs=[inp])

    inp_shape = graph.shape(inp)
    graph.outputs = [
        graph.slice(inp_shape, np.array(starts), np.array(ends), axes=np.array(axes), steps=np.array(steps))
    ]

    graph.fold_constants()

    # Compare with None explicitly: an empty expected slice is a valid folded
    # result and must not be mistaken for "cannot fold".
    if expected is not None:
        assert isinstance(graph.outputs[0], Constant)
        # array_equal also checks the shape, so a broadcasting match is rejected.
        assert np.array_equal(graph.outputs[0].values, expected)
    else:
        assert isinstance(graph.outputs[0], Variable)