def replace_pattern(graph: Graph, match: dict):
    """Erase a StridedSlice that does not actually slice anything.

    The matched op (``match['strided_slice']``) is considered useless when two
    conditions hold:
    - its output shape, after undoing the effects of ``shrink_axis_mask`` and
      ``new_axis_mask``, equals its input shape;
    - every entry of ``slices`` has step 1.

    In that case a Squeeze/Unsqueeze is added through the
    ConvertGroupedStridedSlice helpers where a mask requires it. The
    StridedSlice is then removed together with its output data node.

    StridedSlices whose output value is already known are left untouched.

    :param graph: graph being transformed; modified in place.
    :param match: pattern match; must contain the 'strided_slice' op node.
    """
    node_ss = match['strided_slice']

    if node_ss.out_port(0).data.get_value() is not None:
        # StridedSlices (SS) in shape-calculating sub-graphs should not be deleted that easily.
        # Example: RetinaNetFilteredDetectionsReplacement has an SS that slices the first batch.
        # Such an SS is a no-op for batch 1, but it must still be performed when the model
        # is reshaped.
        return

    output_data_node = node_ss.out_node(0)
    input_data_node = node_ss.in_node(0)

    out_shape = output_data_node.shape

    if not np.all(node_ss.shrink_axis_mask == 0):
        # Re-insert a size-1 dimension at every position flagged in shrink_axis_mask.
        out_shape = list(out_shape)
        for i, bit in enumerate(node_ss.shrink_axis_mask):
            if bit == 1:
                out_shape.insert(i, 1)
        out_shape = int64_array(out_shape)

    if not np.all(node_ss.new_axis_mask == 0):
        # Drop every dimension flagged in new_axis_mask. Walk from the end so
        # that popping does not shift indices that are still to be visited.
        out_shape = list(out_shape)
        for i in reversed(range(len(node_ss.new_axis_mask))):
            if node_ss.new_axis_mask[i] == 1:
                out_shape.pop(i)
        out_shape = int64_array(out_shape)

    if not (np.array_equal(input_data_node.shape, out_shape) and
            all(elem.step == 1 for elem in node_ss.slices)):
        return

    if not np.all(node_ss.shrink_axis_mask == 0):
        ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
    if not np.all(node_ss.new_axis_mask == 0):
        ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

    log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))

    # Check whether the optional 4th input (port 3) is connected BEFORE any edge
    # is removed. After inputs 1 and 2 are detached, in_nodes() has at most two
    # entries, so a post-removal `len(...) > 3` test could never be true.
    has_fourth_input = len(node_ss.in_nodes()) > 3

    # Remove inputs to the StridedSlice so it keeps just its data input and
    # 'remove_op_node_with_data_node' can be used.
    graph.remove_edge(node_ss.in_node(1).id, node_ss.id)
    graph.remove_edge(node_ss.in_node(2).id, node_ss.id)
    if has_fourth_input:
        graph.remove_edge(node_ss.in_node(3).id, node_ss.id)

    remove_op_node_with_data_node(graph, node_ss)
def replace_pattern(graph: Graph, match: dict):
    """Erase a StridedSlice that does not actually slice anything.

    The matched op (``match['strided_slice']``) is considered useless when two
    conditions hold:
    - its output shape, after undoing the effects of ``shrink_axis_mask`` and
      ``new_axis_mask``, equals its input shape;
    - every entry of ``slices`` has step 1.

    In that case a Squeeze/Unsqueeze is added through the
    ConvertGroupedStridedSlice helpers where a mask requires it. The
    StridedSlice is then removed together with its output data node.

    :param graph: graph being transformed; modified in place.
    :param match: pattern match; must contain the 'strided_slice' op node.
    """
    node_ss = match['strided_slice']
    output_data_node = node_ss.out_node(0)
    input_data_node = node_ss.in_node(0)

    out_shape = output_data_node.shape

    if not np.all(node_ss.shrink_axis_mask == 0):
        # Re-insert a size-1 dimension at every position flagged in shrink_axis_mask.
        out_shape = list(out_shape)
        for i, bit in enumerate(node_ss.shrink_axis_mask):
            if bit == 1:
                out_shape.insert(i, 1)
        out_shape = int64_array(out_shape)

    if not np.all(node_ss.new_axis_mask == 0):
        # Drop every dimension flagged in new_axis_mask. Walk from the end so
        # that popping does not shift indices that are still to be visited.
        out_shape = list(out_shape)
        for i in reversed(range(len(node_ss.new_axis_mask))):
            if node_ss.new_axis_mask[i] == 1:
                out_shape.pop(i)
        out_shape = int64_array(out_shape)

    if not (np.array_equal(input_data_node.shape, out_shape) and
            all(elem.step == 1 for elem in node_ss.slices)):
        return

    if not np.all(node_ss.shrink_axis_mask == 0):
        ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
    if not np.all(node_ss.new_axis_mask == 0):
        ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

    log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))

    # Check whether the optional 4th input (port 3) is connected BEFORE any edge
    # is removed. After inputs 1 and 2 are detached, in_nodes() has at most two
    # entries, so a post-removal `len(...) > 3` test could never be true.
    has_fourth_input = len(node_ss.in_nodes()) > 3

    # Remove inputs to the StridedSlice so it keeps just its data input and
    # 'remove_op_node_with_data_node' can be used.
    graph.remove_edge(node_ss.in_node(1).id, node_ss.id)
    graph.remove_edge(node_ss.in_node(2).id, node_ss.id)
    if has_fourth_input:
        graph.remove_edge(node_ss.in_node(3).id, node_ss.id)

    remove_op_node_with_data_node(graph, node_ss)
def test_8(self):
    """Two StridedSlices take the first and last thirds of axis 1 of a
    [1, 54, 54, 3] tensor; the middle third is unused.

    The reference graph expects them to become one 3-way Split on axis 1.
    Its first and third outputs feed the concat and the second output is
    left unconsumed.
    """
    graph = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'sslice_1'),
         ('sslice_1', 'sslice_1_data'),
         ('placeholder_1_data', 'sslice_2'),
         ('sslice_2', 'sslice_2_data'),
         ('sslice_1_data', 'concat_1'),
         ('sslice_2_data', 'concat_1'),
         ('concat_1', 'concat_1_data')],
        {
            'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
            'sslice_1': {
                'slices': np.array([
                    slice(0, 1, 1), slice(0, 18, 1), slice(0, 54, 1), slice(0, 3, 1)
                ])
            },
            'sslice_1_data': {'shape': np.array([1, 18, 54, 3])},
            # Last third of axis 1: must match 'split_3_data' feeding concat_1 in
            # the reference graph below (the middle third stays unused).
            'sslice_2': {
                'slices': np.array([
                    slice(0, 1, 1), slice(36, 54, 1), slice(0, 54, 1), slice(0, 3, 1)
                ])
            },
            'sslice_2_data': {'shape': np.array([1, 18, 54, 3])},
            'concat_1_data': {'shape': np.array([1, 54, 54, 3]), 'is_output': True},
        })
    graph.graph['layout'] = 'NHWC'

    graph_ref = build_graph(
        nodes_attributes,
        [('placeholder_1', 'placeholder_1_data'),
         ('placeholder_1_data', 'split_1'),
         ('split_1', 'split_1_data'),
         ('split_1', 'split_2_data'),
         ('split_1', 'split_3_data'),
         ('split_1_data', 'concat_1'),
         ('split_3_data', 'concat_1'),
         ('concat_1', 'concat_1_data')],
        {
            'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
            'split_1': {'axis': 1},
            'split_1_data': {'shape': np.array([1, 18, 54, 3])},
            'split_2_data': {'shape': np.array([1, 18, 54, 3])},
            'split_3_data': {'shape': np.array([1, 18, 54, 3])},
            'concat_1_data': {'shape': np.array([1, 54, 54, 3]), 'is_output': True},
        })

    pattern = ConvertGroupedStridedSlice()
    pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)