def test_non_const_infer(self):
        """StridedSlice.infer must reject a begin input that has no constant value.

        ``data_1`` (created with ``'value': None``) is connected both to the data
        port (0) and to the begin port (1) of ``strided_slice``, so the begin
        input is not constant and inference is expected to raise ``Error``.
        """
        # Non-constant path case: begin (port 1) is fed from a data node without a value.
        graph = build_graph(nodes_attributes,
                            [('input', 'data_1'),
                             ('data_1', 'strided_slice', {'in': 0}),
                             ('data_1', 'strided_slice', {'in': 1}),
                             ('end', 'end_data'),
                             ('end_data', 'strided_slice', {'in': 2}),
                             ('stride', 'stride_data'),
                             ('stride_data', 'strided_slice', {'in': 3}),
                             ('strided_slice', 'data_2')],
                            {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             'end': {'value': [1, 0], 'shape': [2]},
                             'stride': {'value': [1, 2], 'shape': [2]},
                             'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                                               'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                                               'ellipsis_mask': np.array([1, 0])},
                             'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             })
        graph.graph['layout'] = "NHWC"

        slice_node = Node(graph, 'strided_slice')
        with self.assertRaises(Error) as error:
            StridedSlice.infer(slice_node)
        # assertIn reports both operands on failure, unlike assertTrue(... in ...).
        self.assertIn('Strided slice layer supports only constant begin and end inputs',
                      str(error.exception))

    def test_ss_shrink_only_short(self):
        """Shrink axis 1 of a 4D input using 3-element begin/end/stride inputs.

        The input shape is [1, 1, 127, 54] and ``shrink_axis_mask`` is [0, 1, 0].
        After ``StridedSlice.infer`` every mask and every constant
        begin/end/stride value is expected to have 4 elements, and the output
        shape to be [1, 127, 54].
        """
        edges = [('input', 'data_1'),
                 ('data_1', 'strided_slice', {'in': 0}),
                 ('begin', 'begin_data'),
                 ('begin_data', 'strided_slice', {'in': 1}),
                 ('end', 'end_data'),
                 ('end_data', 'strided_slice', {'in': 2}),
                 ('stride', 'stride_data'),
                 ('stride_data', 'strided_slice', {'in': 3}),
                 ('strided_slice', 'data_2')]
        ss_attrs = {'begin_mask': np.array([0, 0, 0]),
                    'end_mask': np.array([0, 0, 0]),
                    'new_axis_mask': np.array([0, 0, 0]),
                    'shrink_axis_mask': [0, 1, 0],
                    'ellipsis_mask': np.array([0, 0, 0])}
        update_attrs = {'data_1': {'shape': np.array([1, 1, 127, 54]), 'value': None},
                        'begin': {'value': [0, 0, 0], 'shape': [3]},
                        'end': {'value': [0, 0, 0], 'shape': [3]},
                        'stride': {'value': [1, 1, 1], 'shape': [3]},
                        'begin_data': {'value': [0, 0, 0], 'shape': [3]},
                        'end_data': {'value': [0, 0, 0], 'shape': [3]},
                        'stride_data': {'value': [1, 1, 1], 'shape': [3]},
                        'strided_slice': ss_attrs,
                        'data_2': {'shape': None}
                        }
        graph = build_graph(nodes_attributes, edges, update_attrs, nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'

        ss = Node(graph, 'strided_slice')
        begin, end, stride = (Node(graph, name) for name in ('begin', 'end', 'stride'))
        output = Node(graph, 'data_2')

        StridedSlice.infer(ss)

        # (node, attribute, expected) triples, checked in order; each attribute is
        # read right before its own assertion.
        checks = (
            (ss, 'begin_mask', [0, 0, 0, 0]),
            (ss, 'end_mask', [0, 0, 0, 0]),
            (ss, 'shrink_axis_mask', [0, 1, 0, 0]),
            (ss, 'new_axis_mask', [0, 0, 0, 0]),
            (ss, 'ellipsis_mask', [0, 0, 0, 0]),
            (begin, 'value', [0, 0, 0, 0]),
            (end, 'value', [0, 0, 0, 0]),
            (stride, 'value', [1, 1, 1, 1]),
            (output, 'shape', [1, 127, 54]),
        )
        for node, attr_name, expected in checks:
            self.assertTrue(np.array_equal(getattr(node, attr_name), np.array(expected)))

    def test_permute_begin_end_ellipsis_infer(self):
        """Check StridedSlice.infer with 2-element begin/end/stride and an ellipsis mask.

        The input ``data_1`` is 4D (shape [1, 2, 3, 4]), the graph layout is
        NHWC, and ``ellipsis_mask`` is [1, 0]. After inference, the masks and the
        constant ``begin``/``end``/``stride`` node values are expected to have 4
        elements each, as pinned by the assertions at the end of this test.

        NOTE(review): the expected values (e.g. stride [1, 2] -> [1, 2, 1, 1],
        begin [0, 1] -> [0, 1, 0, 0]) look like the ellipsis expanded in NHWC
        order and then permuted to NCHW, as the test name suggests. Confirm
        against StridedSlice.infer.
        """
        # Testing constant path case
        # (begin, end and stride, and their *_data nodes, all carry constant values.)
        graph = build_graph(nodes_attributes,
                            [('input', 'data_1'),
                             ('data_1', 'strided_slice', {'in': 0}),
                             ('begin', 'begin_data'),
                             ('begin_data', 'strided_slice', {'in': 1}),
                             ('end', 'end_data'),
                             ('end_data', 'strided_slice', {'in': 2}),
                             ('stride', 'stride_data'),
                             ('stride_data', 'strided_slice', {'in': 3}),
                             ('strided_slice', 'data_2')],
                            {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             'begin': {'value': [0, 1], 'shape': [2]},
                             'end': {'value': [1, 0], 'shape': [2]},
                             'stride': {'value': [1, 2], 'shape': [2]},
                             'begin_data': {'value': [0, 1], 'shape': [2]},
                             'end_data': {'value': [1, 0], 'shape': [2]},
                             'stride_data': {'value': [1, 2], 'shape': [2]},
                             'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                                               'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                                               'ellipsis_mask': np.array([1, 0])},
                             # NOTE(review): data_2's shape is preset here, but no
                             # assertion in the visible code checks the output shape.
                             'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                             })
        graph.graph['layout'] = "NHWC"

        slice_node = Node(graph, 'strided_slice')
        begin_node = Node(graph, 'begin')
        end_node = Node(graph, 'end')
        stride_node = Node(graph, 'stride')
        StridedSlice.infer(slice_node)
        # Masks and constant inputs are expected to be extended to the 4D input rank.
        # NOTE(review): ellipsis_mask is not asserted in the visible code (the
        # shrink-only test does check it).
        self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([0, 0, 0, 0])))
        self.assertTrue(np.array_equal(slice_node.end_mask, np.array([1, 0, 0, 0])))
        self.assertTrue(np.array_equal(slice_node.shrink_axis_mask, np.array([0, 0, 0, 0])))
        self.assertTrue(np.array_equal(slice_node.new_axis_mask, np.array([0, 0, 0, 0])))
        self.assertTrue(np.array_equal(begin_node.value, np.array([0, 1, 0, 0])))
        self.assertTrue(np.array_equal(end_node.value, np.array([1, 0, 0, 0])))
        self.assertTrue(np.array_equal(stride_node.value, np.array([1, 2, 1, 1])))