def test_no_steps_no_axes(self):
    """Slice fed only starts and ends (no axes, no steps) becomes StridedSlice.

    The reference StridedSlice receives an additional strides input (port 3),
    has begin_mask/end_mask set to ones for every input dimension and
    new_axis/shrink_axis/ellipsis masks set to zeros. The expected output
    shape is ``ends - starts`` per dimension ([2, 6, 8]).
    """
    input_shape = int64_array([5, 10, 20])
    starts_value = int64_array([3, 2, 7])
    ends_value = int64_array([5, 8, 15])
    # Masks that must stay all-zero on the produced StridedSlice.
    masks_value = np.zeros([len(input_shape)], dtype=np.int64)
    graph = build_graph(self.nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1_data', 'slice', {'in': 0}),
                         ('starts', 'starts_data'),
                         ('starts_data', 'slice', {'in': 1}),
                         ('ends', 'ends_data'),
                         ('ends_data', 'slice', {'in': 2}),
                         ('slice', 'slice_data'),
                         ('slice_data', 'output_op'),
                         ('output_op', 'output_data'),
                         ('output_data', 'op_output')
                         ],
                        {'placeholder_1_data': {'shape': input_shape},
                         'starts': {'shape': starts_value.shape, 'value': starts_value},
                         'starts_data': {'shape': starts_value.shape, 'value': starts_value},
                         'ends': {'shape': ends_value.shape, 'value': ends_value},
                         'ends_data': {'shape': ends_value.shape, 'value': ends_value},
                         },
                        nodes_with_edges_only=True
                        )
    slice_node = Node(graph, 'slice')
    Slice.infer(slice_node)

    pattern = ConvertSlice()
    pattern.find_and_replace_pattern(graph)

    ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
    assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'

    graph_ref = build_graph(self.nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'strided_slice', {'in': 0}),
                             ('starts', 'starts_data'),
                             ('starts_data', 'strided_slice', {'in': 1}),
                             ('ends', 'ends_data'),
                             ('ends_data', 'strided_slice', {'in': 2}),
                             ('strides', 'strides_data'),
                             ('strides_data', 'strided_slice', {'in': 3}),
                             ('strided_slice', 'slice_data'),
                             ('slice_data', 'output_op'),
                             ('output_op', 'output_data'),
                             ('output_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': input_shape},
                             'strided_slice': {'new_axis_mask': masks_value,
                                               'shrink_axis_mask': masks_value,
                                               'ellipsis_mask': masks_value,
                                               # One mask bit per input dimension, same as masks_value.
                                               'begin_mask': np.ones(len(input_shape)),
                                               'end_mask': np.ones(len(input_shape))},
                             'slice_data': {'shape': int64_array([2, 6, 8])}
                             },
                            nodes_with_edges_only=True
                            )

    (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_1(self):
    """Slice over every axis of a non-constant input is rewritten into a Crop.

    ``axis`` is None and start/end cover all three dimensions; the reference
    Crop lists axes 0..2 with offsets equal to ``start`` and sizes ``dim``
    equal to ``end - start``.
    """
    output_edges = [('slice_data', 'output_op'),
                    ('output_op', 'output_data'),
                    ('output_data', 'op_output')]

    graph = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'slice'),
         ('slice', 'slice_data')] + output_edges,
        {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {'start': np.array([1, 2, 3]),
                      'end': np.array([3, 4, 4]),
                      'axis': None},
        },
        nodes_with_edges_only=True,
    )

    Slice.infer(Node(graph, 'slice'))
    ConvertSlice().find_and_replace_pattern(graph)
    graph.clean_up()

    converted = Node(graph, graph.get_node_id_by_name('slice_node'))
    assert converted.type == 'Crop', 'Something wrong with transformed Slice node'

    expected = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'crop'),
         ('crop', 'slice_data')] + output_edges,
        {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'crop': {'axis': np.array([0, 1, 2]),
                     'offset': np.array([1, 2, 3]),
                     'dim': np.array([2, 2, 1])},
        },
        nodes_with_edges_only=True,
    )

    flag, resp = compare_graphs(graph, expected, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_2(self):
    """Slice of one dimension on a constant path is turned into a StridedSlice.

    Only the leading axis has start/end (1:3, ``axis`` is None); the reference
    expects the other axes to be taken whole with step 1 and no axis shrinking.
    """
    sliced = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'slice'),
         ('slice', 'slice_data'),
         ('slice_data', 'output_op'),
         ('output_op', 'output_data')],
        {'placeholder_1_data': {'shape': np.array([4, 5, 6])},
         'slice': {'start': np.array([1]), 'end': np.array([3]), 'axis': None},
         'output_op': {'is_output': True}})

    Slice.infer(Node(sliced, 'slice'))
    ConvertSlice().find_and_replace_pattern(sliced)

    expected_slices = np.array([slice(1, 3, 1), slice(0, 5, 1), slice(0, 6, 1)])
    expected = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'strided_slice'),
         ('strided_slice', 'slice_data'),
         ('slice_data', 'output_op'),
         ('output_op', 'output_data')],
        {'placeholder_1_data': {'shape': np.array([4, 5, 6])},
         'strided_slice': {'slices': expected_slices,
                           'shrink_axis_mask': np.array([False, False, False])},
         'output_op': {'is_output': True}})

    flag, resp = compare_graphs(sliced, expected, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_1(self):
    """Slice over all three axes of a non-constant input becomes a Crop.

    ``axis`` is None and start/end are given for every dimension, so the
    reference Crop lists axes 0..2 with offsets equal to ``start`` and sizes
    (``dim``) equal to ``end - start``.
    """
    graph = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'slice'),
         ('slice', 'slice_data'),
         ('slice_data', 'output_op'),
         ('output_op', 'output_data')],
        {
            'placeholder_1_data': {
                'shape': np.array([4, 5, 6])
            },
            'slice': {
                'start': np.array([1, 2, 3]),
                'end': np.array([3, 4, 4]),
                'axis': None
            },
            'output_op': {
                'is_output': True
            },
        })
    slice_node = Node(graph, 'slice')
    Slice.infer(slice_node)

    pattern = ConvertSlice()
    pattern.find_and_replace_pattern(graph)

    graph_ref = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'crop'),
         ('crop', 'slice_data'),
         ('slice_data', 'output_op'),
         ('output_op', 'output_data')],
        {
            'placeholder_1_data': {
                'shape': np.array([4, 5, 6])
            },
            # 'dim' is an attribute of the Crop node itself (end - start = [2, 2, 1]);
            # a separate top-level 'dim' key would target a node that is not among
            # this graph's edges instead of the Crop being checked.
            'crop': {
                'axis': np.array([0, 1, 2]),
                'offset': np.array([1, 2, 3]),
                'dim': np.array([2, 2, 1]),
            },
            'output_op': {
                'is_output': True
            },
        })

    (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
    """Slice with only starts/ends is converted into StridedSlice (Gather/Concat form).

    The reference graph extends ``pattern_ref_graph``: for both begin and end,
    the cast input feeds four per-axis Gathers whose outputs are concatenated.
    Strides are ones and begin_mask/end_mask are all ones.
    """
    def slice_bounds():
        # Fresh starts/ends attributes for each graph that is built.
        return {
            'starts': {'value': int64_array([0, 0, 0, 0]), 'shape': [4]},
            'ends': {'value': int64_array([1, 2, 150, 150]), 'shape': [4]},
        }

    def gather_to_concat_edges(prefix):
        # Axis 0 is wired with connect(), the remaining axes with connect_data();
        # every Gather output goes to the matching Concat input port.
        edges = []
        for axis in range(4):
            link = connect if axis == 0 else connect_data
            edges.extend(link('{}_cast:0'.format(prefix),
                              '0:{}_gather_{}'.format(prefix, axis)))
            edges.extend(connect('{}_gather_{}:0'.format(prefix, axis),
                                 '{}:{}_concat'.format(axis, prefix)))
        return edges

    graph = build_graph(
        nodes_attrs=nodes_attributes,
        edges=[
            *connect('input:0', '0:slice'),
            *connect('starts:0', '1:slice'),
            *connect('ends:0', '2:slice'),
            *connect('slice:0', '0:result'),
        ],
        update_attributes=slice_bounds(),
        nodes_with_edges_only=True,
    )

    ref_attributes = slice_bounds()
    ref_attributes['ss_strides'] = {'value': int64_array([1, 1, 1, 1]), 'shape': [4]}
    ref_attributes['ss'] = {'begin_mask': int64_array([1, 1, 1, 1]),
                            'end_mask': int64_array([1, 1, 1, 1])}
    ref_edges = (pattern_ref_graph
                 + gather_to_concat_edges('ss_begin')
                 + gather_to_concat_edges('ss_end'))
    ref_graph = build_graph(
        nodes_attrs=nodes_attributes,
        edges=ref_edges,
        update_attributes=ref_attributes,
    )

    ConvertSlice().find_and_replace_pattern(graph)
    flag, resp = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_3(self):
    """Slice of axis 1 on an NHWC graph is turned into a StridedSlice.

    ``axis`` is explicitly [1] with start 1 and end 3; the reference keeps axes
    0 and 2 whole (step 1) and slices axis 1 as 1:3 without shrinking.
    """
    output_edges = [('slice_data', 'output_op'),
                    ('output_op', 'output_data'),
                    ('output_data', 'op_output')]

    graph = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'slice'),
         ('slice', 'slice_data')] + output_edges,
        {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'slice': {'start': np.array([1]),
                      'end': np.array([3]),
                      'axis': np.array([1])},
        },
        nodes_with_edges_only=True,
    )
    graph.graph['layout'] = 'NHWC'

    Slice.infer(Node(graph, 'slice'))
    ConvertSlice().find_and_replace_pattern(graph)
    graph.clean_up()

    converted = Node(graph, graph.get_node_id_by_name('slice_node'))
    assert converted.type == 'StridedSlice', 'Something wrong with transformed Slice node'

    expected = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_2', 'placeholder_2_data'),
         ('placeholder_3', 'placeholder_3_data'),
         ('placeholder_1_data', 'strided_slice'),
         ('placeholder_2_data', 'strided_slice'),
         ('placeholder_3_data', 'strided_slice'),
         ('strided_slice', 'slice_data')] + output_edges,
        {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'strided_slice': {
                'slices': np.array([slice(0, 1, 1), slice(1, 3, 1), slice(0, 6, 1)]),
                'shrink_axis_mask': np.array([False, False, False]),
            },
        },
        nodes_with_edges_only=True,
    )

    flag, resp = compare_graphs(graph, expected, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_convert_slice_to_strided_slice(self, input_shape, start, end, axes, steps,
                                        ss_begin_parts: tuple, ss_end_parts: tuple,
                                        ss_steps, ss_begin_mask, ss_end_mask):
    """Check that a five-input Slice is rewritten into the expected StridedSlice.

    In the reference, StridedSlice ``begin``/``end`` are Concats (axis 0) of a
    constant prefix, the original start/end cast to int64, and a constant
    suffix (``ss_begin_parts``/``ss_end_parts``). Strides and begin/end masks
    come from the parameters; new-axis, shrink-axis and ellipsis masks are zero.
    """
    rank = len(input_shape)

    def zero_mask():
        return np.zeros(rank, dtype=np.int64)

    def cast_to_i64():
        return {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}

    def concat_axis0():
        return {'type': 'Concat', 'op': 'Concat', 'axis': 0}

    slice_nodes = {
        **regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
        **valued_const_with_data('start', start),
        **valued_const_with_data('end', end),
        **valued_const_with_data('axes', axes),
        **valued_const_with_data('steps', steps),
        **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
        **result('result'),
    }
    slice_edges = [
        *connect('input', 'slice'),
        *connect('start', '1:slice'),
        *connect('end', '2:slice'),
        *connect('axes', '3:slice'),
        *connect('steps', '4:slice'),
        *connect('slice', 'result'),
    ]
    graph = build_graph(nodes_attrs=slice_nodes, edges=slice_edges)

    ref_nodes = {
        **regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
        **valued_const_with_data('start', start),
        **valued_const_with_data('begin_first_part', ss_begin_parts[0]),
        **valued_const_with_data('begin_last_part', ss_begin_parts[1]),
        **regular_op_with_empty_data('convert_start', cast_to_i64()),
        **regular_op_with_empty_data('ss_begin', concat_axis0()),
        **valued_const_with_data('end', end),
        **valued_const_with_data('end_first_part', ss_end_parts[0]),
        **valued_const_with_data('end_last_part', ss_end_parts[1]),
        **regular_op_with_empty_data('convert_end', cast_to_i64()),
        **regular_op_with_empty_data('ss_end', concat_axis0()),
        **const('ss_steps', ss_steps),
        **empty_data('ss_steps_d'),
        **regular_op_with_empty_data('ss', {
            'op': 'StridedSlice',
            'type': 'StridedSlice',
            'begin_mask': ss_begin_mask,
            'end_mask': ss_end_mask,
            'new_axis_mask': zero_mask(),
            'shrink_axis_mask': zero_mask(),
            'ellipsis_mask': zero_mask(),
        }),
        **result('result'),
    }
    ref_edges = [
        *connect('input', 'ss'),
        *connect('begin_first_part', 'ss_begin'),
        *connect('start', 'convert_start'),
        *connect('convert_start', '1:ss_begin'),
        *connect('begin_last_part', '2:ss_begin'),
        *connect('ss_begin', '1:ss'),
        *connect('end_first_part', 'ss_end'),
        *connect('end', 'convert_end'),
        *connect('convert_end', '1:ss_end'),
        *connect('end_last_part', '2:ss_end'),
        *connect('ss_end', '2:ss'),
        *connect('ss_steps', '3:ss'),
        *connect('ss', 'result'),
    ]
    ref_graph = build_graph(nodes_attrs=ref_nodes, edges=ref_edges)

    ConvertSlice().find_and_replace_pattern(graph)
    flag, resp = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
    """Slice with only start/end inputs converts to a full-mask StridedSlice.

    Without axes and steps, the reference begin/end are the original start/end
    cast to int64 and wrapped in an axis-0 Concat with empty constant prefix
    and suffix. Strides are ones and begin_mask/end_mask are all ones.
    """
    def parameter_input():
        return regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'})

    def cast_to_i64():
        return {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}

    def concat_axis0():
        return {'type': 'Concat', 'op': 'Concat', 'axis': 0}

    def zero_mask():
        return np.zeros(3, dtype=np.int64)

    slice_nodes = {
        **parameter_input(),
        **valued_const_with_data('start', np.array([0, 0, 0])),
        **valued_const_with_data('end', np.array([1, 3, 5])),
        **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
        **result('result'),
    }
    slice_edges = [
        *connect('input', 'slice'),
        *connect('start', '1:slice'),
        *connect('end', '2:slice'),
        *connect('slice', 'result'),
    ]
    graph = build_graph(nodes_attrs=slice_nodes, edges=slice_edges)

    ref_nodes = {
        **parameter_input(),
        **valued_const_with_data('start', np.array([0, 0, 0])),
        **valued_const_with_data('begin_first_part', int64_array([])),
        **valued_const_with_data('begin_last_part', int64_array([])),
        **regular_op_with_empty_data('convert_start', cast_to_i64()),
        **regular_op_with_empty_data('ss_begin', concat_axis0()),
        **valued_const_with_data('end', np.array([1, 3, 5])),
        **valued_const_with_data('end_first_part', int64_array([])),
        **valued_const_with_data('end_last_part', int64_array([])),
        **regular_op_with_empty_data('convert_end', cast_to_i64()),
        **regular_op_with_empty_data('ss_end', concat_axis0()),
        **const('ss_steps', int64_array([1, 1, 1])),
        **empty_data('ss_steps_d'),
        **regular_op_with_empty_data('ss', {
            'op': 'StridedSlice',
            'type': 'StridedSlice',
            'begin_mask': int64_array([1, 1, 1]),
            'end_mask': int64_array([1, 1, 1]),
            'new_axis_mask': zero_mask(),
            'shrink_axis_mask': zero_mask(),
            'ellipsis_mask': zero_mask(),
        }),
        **result('result'),
    }
    ref_edges = [
        *connect('input', 'ss'),
        *connect('begin_first_part', 'ss_begin'),
        *connect('start', 'convert_start'),
        *connect('convert_start', '1:ss_begin'),
        *connect('begin_last_part', '2:ss_begin'),
        *connect('ss_begin', '1:ss'),
        *connect('end_first_part', 'ss_end'),
        *connect('end', 'convert_end'),
        *connect('convert_end', '1:ss_end'),
        *connect('end_last_part', '2:ss_end'),
        *connect('ss_end', '2:ss'),
        *connect('ss_steps', '3:ss'),
        *connect('ss', 'result'),
    ]
    ref_graph = build_graph(nodes_attrs=ref_nodes, edges=ref_edges)

    ConvertSlice().find_and_replace_pattern(graph)
    flag, resp = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)