Beispiel #1
0
        def test_slice_infer(self, inp_value, starts, ends, axes, steps, expected, inp_shape=None):
            """Run Slice.infer on a small graph and compare its output with ``expected``.

            If ``inp_value`` is given, the input data node holds that value and the inferred
            output *value* is checked. Otherwise the input node carries only the shape
            ``inp_shape`` and the inferred output *shape* is checked.
            ``starts``, ``ends``, ``axes`` and ``steps`` are converted with ``int64_array``
            and connected as constant inputs to Slice ports 1-4.
            """
            has_value = inp_value is not None
            input_node = (valued_data('data_1', int64_array(inp_value)) if has_value
                          else shaped_data('data_1', int64_array(inp_shape)))

            # (const node name, const value, destination port on the Slice op)
            port_inputs = (('starts', starts, '1:slice'),
                           ('ends', ends, '2:slice'),
                           ('axes', axes, '3:slice'),
                           ('steps', steps, '4:slice'))

            nodes = dict(input_node)
            nodes.update(regular_op_with_empty_data('slice', {'op': 'Slice'}))
            for const_name, const_value, _ in port_inputs:
                nodes.update(valued_const_with_data(const_name, int64_array(const_value)))

            edges = [('data_1', 'slice')]
            for const_name, _, dst_port in port_inputs:
                edges.extend(connect(const_name, dst_port))
            edges.extend(connect('slice', 'slice_d'))

            graph = build_graph(nodes, edges)
            graph.stage = 'middle'
            slice_node = Node(graph, 'slice')

            Slice.infer(slice_node)

            out_data = slice_node.out_node()
            actual = out_data.value if has_value else out_data.shape
            self.assertTrue(np.array_equal(actual, expected))
Beispiel #2
0
        def test_slice_infer_negative(self, inp_value, starts, ends, axes, steps, expected, inp_shape=None):
            """Check that Slice.infer raises ``Error`` for an invalid combination of inputs.

            Any of ``starts``/``ends``/``axes``/``steps`` may be ``None``. Such an input still
            gets a constant node, but that node has only the placeholder shape ``[0]`` and no
            value. ``expected`` is not used here (presumably it is kept so the parameter list
            matches the positive test).
            """
            if inp_value is not None:
                input_node = valued_data('data_1', int64_array(inp_value))
            else:
                input_node = shaped_data('data_1', int64_array(inp_shape))

            def make_const(const_name, const_value):
                # A missing input still needs a node: give it a placeholder shape, no value.
                if const_value is None:
                    return shaped_const_with_data(const_name, [0])
                return valued_const_with_data(const_name, int64_array(const_value))

            # (const node name, const value, destination port on the Slice op)
            port_inputs = (('starts', starts, '1:slice'),
                           ('ends', ends, '2:slice'),
                           ('axes', axes, '3:slice'),
                           ('steps', steps, '4:slice'))

            const_nodes = [make_const(const_name, const_value)
                           for const_name, const_value, _ in port_inputs]

            nodes = {**input_node, **regular_op_with_empty_data('slice', {'op': 'Slice'})}
            for const_node in const_nodes:
                nodes.update(const_node)

            edges = [('data_1', 'slice')]
            for const_name, _, dst_port in port_inputs:
                edges.extend(connect(const_name, dst_port))
            edges.extend(connect('slice', 'slice_d'))

            graph = build_graph(nodes, edges)
            graph.stage = 'middle'
            slice_node = Node(graph, 'slice')
            with self.assertRaises(Error):
                Slice.infer(slice_node)