def test_non_const_infer(self):
    """StridedSlice.infer must reject a ``begin`` input that has no constant value.

    ``data_1`` (whose ``value`` is None) is connected both to input port 0
    (the sliced tensor) and to input port 1 (``begin``). Because ``begin`` is
    therefore non-constant, inference is expected to raise ``Error`` with a
    message about constant begin/end inputs.
    """
    # Non-constant path case: port 1 (begin) is fed by a data node without a value.
    graph = build_graph(nodes_attributes,
                        [('input', 'data_1'),
                         ('data_1', 'strided_slice', {'in': 0}),
                         ('data_1', 'strided_slice', {'in': 1}),
                         ('end', 'end_data'),
                         ('end_data', 'strided_slice', {'in': 2}),
                         ('stride', 'stride_data'),
                         ('stride_data', 'strided_slice', {'in': 3}),
                         ('strided_slice', 'data_2')],
                        {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                         'end': {'value': [1, 0], 'shape': [2]},
                         'stride': {'value': [1, 2], 'shape': [2]},
                         'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
                                           'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
                                           'ellipsis_mask': np.array([1, 0])},
                         'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                         })
    graph.graph['layout'] = "NHWC"

    slice_node = Node(graph, 'strided_slice')
    with self.assertRaises(Error) as error:
        StridedSlice.infer(slice_node)
    # assertIn reports the actual exception text on failure, unlike assertTrue(... in ...).
    self.assertIn('Strided slice layer supports only constant begin and end inputs', str(error.exception))
def test_ss_shrink_only_short(self):
    """Shrink on axis 1 with 3-element begin/end/stride against a 4D input.

    The input shape is [1, 1, 127, 54] and ``shrink_axis_mask`` is [0, 1, 0].
    After ``StridedSlice.infer`` the test expects every mask and the
    begin/end/stride constant values to hold 4 elements. It also expects the
    output shape to be [1, 127, 54].
    """
    edges = [('input', 'data_1'),
             ('data_1', 'strided_slice', {'in': 0}),
             ('begin', 'begin_data'),
             ('begin_data', 'strided_slice', {'in': 1}),
             ('end', 'end_data'),
             ('end_data', 'strided_slice', {'in': 2}),
             ('stride', 'stride_data'),
             ('stride_data', 'strided_slice', {'in': 3}),
             ('strided_slice', 'data_2')]
    attrs = {'data_1': {'shape': np.array([1, 1, 127, 54]), 'value': None},
             'begin': {'value': [0, 0, 0], 'shape': [3]},
             'end': {'value': [0, 0, 0], 'shape': [3]},
             'stride': {'value': [1, 1, 1], 'shape': [3]},
             'begin_data': {'value': [0, 0, 0], 'shape': [3]},
             'end_data': {'value': [0, 0, 0], 'shape': [3]},
             'stride_data': {'value': [1, 1, 1], 'shape': [3]},
             'strided_slice': {'begin_mask': np.array([0, 0, 0]),
                               'end_mask': np.array([0, 0, 0]),
                               'new_axis_mask': np.array([0, 0, 0]),
                               'shrink_axis_mask': [0, 1, 0],
                               'ellipsis_mask': np.array([0, 0, 0])},
             'data_2': {'shape': None}}
    graph = build_graph(nodes_attributes, edges, attrs, nodes_with_edges_only=True)
    graph.graph['layout'] = 'NCHW'

    slice_node = Node(graph, 'strided_slice')
    begin_node = Node(graph, 'begin')
    end_node = Node(graph, 'end')
    stride_node = Node(graph, 'stride')
    out_node = Node(graph, 'data_2')

    StridedSlice.infer(slice_node)

    # Each attribute is read only when its check runs, in the same order as a
    # sequence of individual assertions would read it.
    checks = [(slice_node, 'begin_mask', [0, 0, 0, 0]),
              (slice_node, 'end_mask', [0, 0, 0, 0]),
              (slice_node, 'shrink_axis_mask', [0, 1, 0, 0]),
              (slice_node, 'new_axis_mask', [0, 0, 0, 0]),
              (slice_node, 'ellipsis_mask', [0, 0, 0, 0]),
              (begin_node, 'value', [0, 0, 0, 0]),
              (end_node, 'value', [0, 0, 0, 0]),
              (stride_node, 'value', [1, 1, 1, 1]),
              (out_node, 'shape', [1, 127, 54])]
    for node, attr_name, expected in checks:
        self.assertTrue(np.array_equal(getattr(node, attr_name), np.array(expected)))
def test_permute_begin_end_ellipsis_infer(self):
    """Constant begin/end/stride with an ellipsis mask on an NHWC 4D input.

    All three inputs are 2-element constants, and ``ellipsis_mask`` is [1, 0].
    After ``StridedSlice.infer`` the test expects the masks and the
    begin/end/stride constant values to be extended to 4 elements.
    """
    # Constant path case: begin, end and stride all come from nodes with values.
    edges = [('input', 'data_1'),
             ('data_1', 'strided_slice', {'in': 0}),
             ('begin', 'begin_data'),
             ('begin_data', 'strided_slice', {'in': 1}),
             ('end', 'end_data'),
             ('end_data', 'strided_slice', {'in': 2}),
             ('stride', 'stride_data'),
             ('stride_data', 'strided_slice', {'in': 3}),
             ('strided_slice', 'data_2')]
    attrs = {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
             'begin': {'value': [0, 1], 'shape': [2]},
             'end': {'value': [1, 0], 'shape': [2]},
             'stride': {'value': [1, 2], 'shape': [2]},
             'begin_data': {'value': [0, 1], 'shape': [2]},
             'end_data': {'value': [1, 0], 'shape': [2]},
             'stride_data': {'value': [1, 2], 'shape': [2]},
             'strided_slice': {'begin_mask': np.array([0, 0]),
                               'end_mask': np.array([1, 0]),
                               'new_axis_mask': np.array([0]),
                               'shrink_axis_mask': [0],
                               'ellipsis_mask': np.array([1, 0])},
             'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
             }
    graph = build_graph(nodes_attributes, edges, attrs)
    graph.graph['layout'] = "NHWC"

    slice_node = Node(graph, 'strided_slice')
    begin_node = Node(graph, 'begin')
    end_node = Node(graph, 'end')
    stride_node = Node(graph, 'stride')

    StridedSlice.infer(slice_node)

    # NOTE(review): ellipsis_mask and the output shape are not checked by this test.
    checks = [(slice_node, 'begin_mask', [0, 0, 0, 0]),
              (slice_node, 'end_mask', [1, 0, 0, 0]),
              (slice_node, 'shrink_axis_mask', [0, 0, 0, 0]),
              (slice_node, 'new_axis_mask', [0, 0, 0, 0]),
              (begin_node, 'value', [0, 1, 0, 0]),
              (end_node, 'value', [1, 0, 0, 0]),
              (stride_node, 'value', [1, 2, 1, 1])]
    for node, attr_name, expected in checks:
        self.assertTrue(np.array_equal(getattr(node, attr_name), np.array(expected)))