def find_and_replace_pattern(self, graph: Graph):
    """Assign the NHWC-to-NCHW permutation to the edges of data nodes.

    A data node is left untouched when ``has_and_set('nchw_layout')`` is
    true for it, or when at least one of its incoming or outgoing edges
    already has a 'permutation' attribute. For every other data node the
    permutation obtained for ``len(node.shape)`` dimensions is set on all
    of its incoming and outgoing edges via ``PermuteAttrs.set_permutation``.

    :param graph: graph whose data nodes are processed
    """
    for data_node in graph.get_data_nodes():
        if data_node.has_and_set('nchw_layout'):
            continue

        # The permutation depends only on the rank of the data node shape.
        rank = len(data_node.shape)
        nhwc_to_nchw = PermuteAttrs().get_nhwc_to_nchw_permutation(rank)

        # Gather attributes of every edge touching the data node: incoming
        # edges first, then outgoing ones.
        touching_edges = [
            data_node.graph.get_edge_data(producer.id, data_node.id)[0]
            for producer in data_node.in_nodes()
        ]
        touching_edges += [
            data_node.graph.get_edge_data(data_node.id, consumer.id)[0]
            for consumer in data_node.out_nodes()
        ]

        # A single edge that already has a permutation is enough to skip
        # the whole data node.
        if any('permutation' in attrs for attrs in touching_edges):
            continue

        for producer in data_node.in_nodes():
            PermuteAttrs.set_permutation(producer, data_node, nhwc_to_nchw)
        for consumer in data_node.out_nodes():
            PermuteAttrs.set_permutation(data_node, consumer, nhwc_to_nchw)
def permute_data_nodes_attrs(graph: Graph):
    """Apply the stored permutation to the shape and value of data nodes.

    A data node is skipped when any of the following holds:

    * it has no valid 'permutation' attribute;
    * every outgoing edge has a truthy 'input_permutation' attribute
      (this is also the case when the node has no outgoing edges);
    * it has a producer and ``is_output_data_in_correct_layout`` returns
      true for that producer and the 'out' port of the connecting edge;
    * the permutation is empty.

    Otherwise ``node.shape`` is re-indexed with the permutation and, when
    the node has a valid 'value', the value is transposed accordingly.

    :param graph: graph whose data nodes are updated in place
    """
    for data_node in graph.get_data_nodes():
        if not data_node.has_valid('permutation'):
            continue

        consumer_flags = [
            edge_attrs.get('input_permutation', False)
            for _, _, edge_attrs in graph.out_edges(data_node.id, data=True)
        ]
        if all(consumer_flags):
            continue

        # there are data nodes without input operation node inside the TensorIterator
        if len(data_node.in_nodes()) > 0:
            producer = data_node.in_node(0)
            out_port = graph.get_edge_data(producer.id, data_node.id)[0]['out']
            if is_output_data_in_correct_layout(producer, out_port):
                log.debug('Do not permute data node attrs for node "{}" output port "{}"'.format(producer.id, out_port))
                continue

        perm = data_node.permutation.perm
        if len(perm) == 0:
            # Nothing to apply for an empty permutation.
            continue

        data_node.shape = shape_array(data_node.shape)[perm]

        if not data_node.has_valid('value'):
            continue

        # The value rank has to match the permutation length before transposing.
        assert len(data_node.value.shape) == len(perm), \
            'Node {} has shape {} and permutation {} that does not match. Their lengths should be equal'.format(
                data_node.name, data_node.value.shape, data_node.permutation.perm)
        data_node.value = mo_array(data_node.value.transpose(perm))
def convert_blobs(graph: Graph, data_type_str: str):
    """Convert floating-point values of data nodes to the requested data type.

    Only data nodes that have a value whose dtype is float16, float32 or
    float64 are converted; nodes for which
    ``has_and_set('correct_data_type')`` is true are left as they are.

    :param graph: graph whose data nodes are processed
    :param data_type_str: name of the target data type, translated with
        ``data_type_str_to_np``
    :raises Error: when checking or converting a node value fails for any
        reason; the original exception is chained as the cause
    """
    float_types = [np.float32, np.float64, np.float16]
    for node in graph.get_data_nodes():
        if node.value is None:
            continue
        try:
            # The type lookup stays inside the try block on purpose, so that an
            # unsupported data_type_str is also reported as Error.
            if node.value.dtype in float_types and not node.has_and_set('correct_data_type'):
                convert_node_blobs(graph, node, data_type_str_to_np(data_type_str))
        except Exception as e:
            raise Error("Couldn't convert blob {}, details: {}", node.soft_get('name'), e) from e
def find_and_replace_pattern(self, graph: Graph):
    """Remove the 'bin' attribute from input edges of selected nodes.

    Operation nodes are processed when their lower-cased 'type' is listed
    in ``OpVersioning.opset_1_types`` or their 'version' is one of
    "opset2", "opset3", "opset4", "opset8". Data nodes are always
    processed.

    :param graph: graph whose edge attributes are modified in place
    """
    listed_versions = ["opset2", "opset3", "opset4", "opset8"]

    for op_node in graph.get_op_nodes():
        type_is_listed = op_node.soft_get('type').lower() in OpVersioning.opset_1_types
        if type_is_listed or op_node.soft_get('version') in listed_versions:
            for edge_attrs in op_node.in_edges().values():
                edge_attrs.pop('bin', None)

    for data_node in graph.get_data_nodes():
        for edge_attrs in data_node.in_edges():
            edge_attrs.pop('bin', None)
def find_and_replace_pattern(self, graph: Graph):
    """Replace groups of StridedSlice consumers of a data node with one VariadicSplit.

    For every data node that has no value and has a shape, the StridedSlice
    operations consuming it through input 0 are collected. The group is
    replaced only when all of the following hold:

    * there are at least two such consumers and all their ``slices`` entries
      are ``slice`` objects;
    * all remaining (de-duplicated) consumers have the same number of slices;
    * the slices of the first consumer cut exactly one dimension;
    * every slice step equals 1;
    * the ranges cut along that dimension do not overlap and do not exceed
      the input size.

    On success the StridedSlice nodes are removed from the graph and a
    VariadicSplit along the detected dimension is created. Parts of the
    dimension not covered by any StridedSlice get 'fake_data_*' data nodes
    registered through ``add_opoutput``.

    :param graph: graph to transform in place
    """
    # Iterate over all data nodes and find all with >= 1 consumers
    # (the node list is copied because data nodes are added to the graph below)
    for input_data in list(graph.get_data_nodes()):
        # We don't use constant data nodes
        if input_data.value is not None:
            continue

        if input_data.shape is None:
            continue
        input_shape = shape_array(input_data.shape)

        # Get all unique StridedSlice consumers
        out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and
                     node.in_node(0).id == input_data.id]

        # A single StridedSlice is not worth converting into a split
        if len(out_nodes) <= 1:
            continue

        valid_for_replacement = True
        for n in out_nodes:
            if any(not isinstance(s, slice) for s in n.slices):
                # this is a slice with dynamic dimension. Such operation is not valid for replacement
                valid_for_replacement = False
        if not valid_for_replacement:
            continue

        # Sorting puts equal slices next to each other before de-duplication
        sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
        out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)

        # All consumers must have the same number of slices as the first one
        for node in out_nodes:
            if len(node.slices) != len(out_nodes[0].slices):
                valid_for_replacement = False

        # Detect dimension for splitting
        # (only the slices of the first consumer are inspected here)
        split_channel_dim = None
        for dim_id, s in enumerate(out_nodes[0].slices):
            l, r, stride = s.start, s.stop, s.step
            # if both l and r are None then the dimension is not sliced
            if (l != 0 or r != input_shape[dim_id]) and (l is not None or r is not None):
                if split_channel_dim is None:
                    split_channel_dim = dim_id
                else:
                    # a second sliced dimension makes the group unsupported
                    valid_for_replacement = False

        if split_channel_dim is None:
            valid_for_replacement = False

        # split_dims contains tuples with split range and output data node
        split_dims = []
        for out_id, node in enumerate(out_nodes):
            # Check that StridedSlice op has stride eq 1 and splits only feature channel
            # NOTE(review): `id` shadows the builtin and `out_id` is unused in this loop
            for id, s in enumerate(node.slices):
                l, r, stride = s.start, s.stop, s.step
                # We don't support StridedSlice with stride != 1
                if stride != 1:
                    valid_for_replacement = False
                if id == split_channel_dim:
                    split_dims.append((s.start, s.stop, node.out_node()))

        if not valid_for_replacement:
            continue

        # Check feature split intersection
        final_data_nodes_list = []
        sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

        # check if we have similar StridedSlice operations with different outputs
        # NOTE(review): prev_sd is never re-assigned inside the loop, so every entry is
        # compared with the first sorted entry only - confirm this is intended
        prev_sd = sorted_split_dims[0]
        to_remove = []
        for i in range(1, len(sorted_split_dims)):
            if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
                cur_node = sorted_split_dims[i][2]
                # Re-attach consumers of the duplicate data node to the data node of prev_sd,
                # keeping a copy of the original edge attributes
                for out in cur_node.out_nodes():
                    attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                    graph.remove_edge(cur_node.id, out.id)
                    graph.add_edge(prev_sd[2].id, out.id, **attrs)
                to_remove.append(i)

        # Pop from the end so that the remaining indices stay valid
        for ind in reversed(to_remove):
            sorted_split_dims.pop(ind)

        size_splits = []
        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Split dims shouldn't intersect
            if l < prev_r:
                valid_for_replacement = False
            prev_r = r

        # The last range must not go beyond the size of the split dimension
        if prev_r > input_shape[split_channel_dim]:
            valid_for_replacement = False

        if not valid_for_replacement:
            continue

        # Build the list of split sizes and the matching output data nodes
        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Save missing tensor part
            if l > prev_r:
                shape = mo_array(input_shape)
                size_splits.append(l - prev_r)
                shape[split_channel_dim] = l - prev_r
                data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                final_data_nodes_list.append(data_node)

            prev_r = r
            size_splits.append(r - l)
            final_data_nodes_list.append(out)

        if prev_r < input_shape[split_channel_dim]:
            # Add last part of tensor
            shape = input_shape.copy()
            shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
            size_splits.append(input_shape[split_channel_dim] - prev_r)
            data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
            add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
            final_data_nodes_list.append(data_node)

        for node in out_nodes:
            if not np.all([x == 0 for x in node.shrink_axis_mask]):
                # Remember the current output data node before the helpers are called
                out_node = node.out_node()
                if np.any(node['shrink_axis_mask']):
                    self.add_squeeze_for_shrink(graph, node)
                if np.any(node['new_axis_mask']):
                    self.add_unsqueeze_for_new(graph, node)

                # node.out_node() is read again and replaces the remembered entry;
                # presumably the helpers above change the output data node - TODO confirm
                for i in range(len(final_data_nodes_list)):
                    if final_data_nodes_list[i].name == out_node.name:
                        final_data_nodes_list[i] = node.out_node()
                        break

        # Insert Split layer and remove old StridedSlice layers
        # 1. Remove connections from input_data to StridedSlice ops
        # NOTE(review): out_data_nodes is filled but not read anywhere in this function
        out_data_nodes = []
        name_for_future_split = out_nodes[0].name
        for node in out_nodes:
            out_data_nodes.append(node.out_node())
            graph.remove_edge(input_data.id, node.id)
            graph.remove_edge(node.id, node.out_node().id)
            graph.remove_node(node.id)
            log.debug("Removed: {}".format(node.id))

        # 2. Create Split layer and reorder outputs
        name = name_for_future_split + "/Split"
        axis_const = Const(graph, {'value': int64_array(split_channel_dim), 'name': name + '/Axis'}).create_node_with_data()
        size_splits_const = Const(graph, {'value': int64_array(size_splits), 'name': name + '/Sizes'}).create_node_with_data()
        split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

        split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                    data_nodes=final_data_nodes_list)