def replace_pattern(self, graph: Graph, match: dict):
    """
    Collapse a SpaceToBatch -> Convolution -> BatchToSpace chain into a single dilated Convolution.

    The SpaceToBatch block size is written to the convolution dilation over the spatial dimensions and
    the difference between the SpaceToBatch pads and the BatchToSpace crops becomes the explicit
    convolution padding. The convolution is then connected directly to the nodes that used to surround
    the whole chain.

    :param graph: graph to operate on.
    :param match: matched nodes keyed by the pattern node name.
    :return: None
    """
    conv_node = match['conv']
    stb_node = match['space_to_batch']
    bts_node = match['batch_to_space']
    block_size_data = match['stb_bs']
    src_data = match['input']
    dst_data = match['output']
    stb_out_data = match['stb_output']
    conv_out_data = match['conv_output']

    # Attributes of the outer edges are re-used for the new connections of the convolution
    src_edge_attrs = graph.get_edge_data(src_data.id, stb_node.id)[0]
    dst_edge_attrs = graph.get_edge_data(bts_node.id, dst_data.id)[0]

    # Cut the convolution out of the chain and the chain out of its surroundings
    edges_to_drop = (
        (src_data.id, stb_node.id),
        (stb_out_data.id, conv_node.id),
        (conv_node.id, conv_out_data.id),
        (bts_node.id, dst_data.id),
    )
    for producer_id, consumer_id in edges_to_drop:
        graph.remove_edge(producer_id, consumer_id)

    spatial_dims = conv_node.spatial_dims
    conv_node.dilation[spatial_dims] = block_size_data.value[spatial_dims]

    begin_pads = match['stb_pad_begin'].value - match['bts_crop_begin'].value
    end_pads = match['stb_pad_end'].value - match['bts_crop_end'].value
    conv_node.pad[spatial_dims] = [[begin_pads[dim], end_pads[dim]] for dim in spatial_dims]
    conv_node['auto_pad'] = None

    # Values saved from the old edges take precedence over the 'in'/'out' defaults given here
    graph.add_edges_from([
        (src_data.id, conv_node.id, {'in': 0, **src_edge_attrs}),
        (conv_node.id, dst_data.id, {'out': 0, **dst_edge_attrs}),
    ])
def _insert_pooling(graph: Graph, first_node: Node, second_node: Node, spatial_dims):
    """
    Put a point-wise (window of ones) max Pooling between two directly connected nodes.

    The pooling stride is taken from ``second_node.stride_prop``; the attributes of the removed edge are
    re-used for the edge from the pooling output data node to ``second_node``.

    :param graph: graph to operate on.
    :param first_node: producer side of the edge to split.
    :param second_node: consumer side of the edge to split.
    :param spatial_dims: value stored as the ``spatial_dims`` attribute of the new Pooling.
    :return: None
    """
    log.debug("STRIDE PROP: Insert pooling between {} and {}".format(first_node.name, second_node.name))

    stride_prop = second_node.stride_prop

    # Exactly one edge is expected between the two nodes
    connecting_edges = graph.get_edge_data(first_node.id, second_node.id)
    assert len(connecting_edges) == 1
    eattrs = connecting_edges[0]
    graph.remove_edge(first_node.id, second_node.id)

    pooling_attrs = {
        'name': 'Pooling_',
        'spatial_dims': spatial_dims,
        'window': np.array([1, 1, 1, 1]),
        'output_spatial_shape': None,
        'stride': np.array(stride_prop),
        'pad_spatial_shape': np.array([[0, 0], [0, 0]]),
        'pad': np.array([[0, 0], [0, 0], [0, 0], [0, 0]]),
        'pool_method': 'max',
        'is_partial_inferred': False,
    }
    pooling_data = Pooling(graph, pooling_attrs).create_node_with_data([first_node])
    _clean_fw_tensor_attrs(pooling_data)

    graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
def reordering_inputs(graph: Graph, match: dict):
    """
    Renumber the initial state input ports of the matched RNN layer.

    The hidden state input is moved from port 5 to port 4 and, for LSTM only, the cell state input is
    moved from port 6 to port 5.

    :param graph: graph to operate on.
    :param match: matched nodes; only 'rnn_layer' is used.
    :return: None
    """
    rnn_layer = match['rnn_layer']

    def move_input(old_port: int, new_port: int):
        # The state must be connected to the old port, otherwise there is nothing to renumber
        assert old_port in rnn_layer.in_nodes()
        state_edge = graph.get_edge_data(rnn_layer.in_node(old_port).id, rnn_layer.id)
        state_edge[0]['in'] = new_port

    move_input(5, 4)
    if rnn_layer.op == 'LSTM':
        move_input(6, 5)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Handle the pattern SoftmaxActivation -> DetectionOutput.

    DetectionOutput in IE expects flattened input from SoftMax, so a Reshape with the target shape
    [0, -1] is inserted between the two matched nodes and the softmax axis is set to -1.

    Parameters
    ----------
    graph : Graph
       Graph with loaded model.
    match : dict
       Patterns which were found in graph structure.
    """
    softmax = match['softmax_activation']
    detection = match['multi_box_detection']

    softmax['axis'] = -1

    # Remember which ports the direct connection used before dropping it
    link_attrs = graph.get_edge_data(softmax.id, detection.id)[0]
    src_port = link_attrs['out']
    dst_port = link_attrs['in']
    graph.remove_edge(softmax.id, detection.id)

    flatten = create_op_node_with_second_input(
        graph, Reshape, int64_array([0, -1]),
        dict(op='Reshape', name=detection.name + '/Reshape_'),
        softmax)
    graph.create_edge(flatten, detection, in_port=dst_port, out_port=src_port)
def permute_data_nodes_attrs(graph: Graph):
    """
    Apply the 'permutation' attribute of data nodes to their shape and, when present, to their value.

    A data node is left untouched when it has no valid permutation, when every outgoing edge is marked
    with 'input_permutation', when its producer output is already in the correct layout or when the
    permutation is empty.

    :param graph: graph to operate on.
    :return: None
    """
    for data_node in graph.get_data_nodes():
        if not data_node.has_valid('permutation'):
            continue
        if all(attrs.get('input_permutation', False)
               for _, _, attrs in graph.out_edges(data_node.id, data=True)):
            continue

        # there are data nodes without input operation node inside the TensorIterator
        if len(data_node.in_nodes()) != 0:
            producer = data_node.in_node(0)
            producer_port = graph.get_edge_data(producer.id, data_node.id)[0]['out']
            if is_output_data_in_correct_layout(producer, producer_port):
                log.debug('Do not permute data node attrs for node "{}" output port "{}"'
                          .format(producer.id, producer_port))
                continue

        order = data_node.permutation.perm
        if len(order) == 0:
            continue

        data_node.shape = shape_array(data_node.shape)[order]
        if data_node.has_valid('value'):
            assert len(data_node.value.shape) == len(order), \
                'Node {} has shape {} and permutation {} that does not match. Their lengths should be equal' \
                ''.format(data_node.name, data_node.value.shape, order)
            data_node.value = mo_array(data_node.value.transpose(order))
def add_reshape_after_data_node(graph: Graph, data_node_name: str):
    """
    Adds reshape operation which changes shape of the tensor produced by TFSubgraphCall from 4D to real dimension
    of the tensor. The data_node_name node contains real dimensions of the tensor but they will be changed in the
    add_reshapes_for_tf_subgraph_calls function to a 4D because IE TF call layer supports output in 4D only.

    All former consumers of the data node are re-connected to the new reshaped data node. Consumers are
    addressed by their graph node id and every outgoing edge keeps its own attributes, so the function
    also works when a consumer name differs from its id or when a consumer is connected more than once.

    :param graph: graph to operate on.
    :param data_node_name: name of the data node to be reshaped to correct dimensions.
    :return: None
    """
    data_node = Node(graph, data_node_name)

    # if the data node was previously marked as output then we need to mark as output new reshaped data node
    is_out_node = False
    if len(data_node.out_nodes()) == 1 and data_node.out_node().has('op') and data_node.out_node().op == 'Result':
        is_out_node = True
        graph.remove_node(data_node.out_node().id)

    # save old consumers (by node id) together with the attributes of every single outgoing edge
    old_consumer_nodes_with_attrs = [
        (consumer_id, edge_attrs)
        for _, consumer_id, edge_attrs in graph.out_edges(data_node_name, data=True)
    ]

    # remove old consumers from the data node: one removal per saved edge
    for consumer_id, _ in old_consumer_nodes_with_attrs:
        graph.remove_edge(data_node_name, consumer_id)

    # reshape operation node
    reshape_node_name = graph.unique_id("Reshape_")
    graph.add_node(reshape_node_name, kind='op', type='Reshape', name=reshape_node_name, op='Reshape',
                   data_type=data_node['data_type'])
    update_ie_fields(graph.node[reshape_node_name])

    # reshape shape data node
    reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
    graph.add_node(reshape_shape_data_node_name, kind='data', name=reshape_shape_data_node_name,
                   value=np.array(data_node['shape']), shape=[1])

    # reshaped data node
    reshaped_value = None
    if data_node['value'] is not None:
        reshaped_value = np.array(data_node['value'])
    reshaped_data_node_name = graph.unique_id("reshaped_data_")
    graph.add_node(reshaped_data_node_name, kind='data', name=reshaped_data_node_name,
                   shape=np.array(data_node['shape']), value=reshaped_value, nchw_layout=True)

    if is_out_node:
        add_opoutput(graph, reshaped_data_node_name, 0, False)

    graph.add_edges_from([
        (data_node_name, reshape_node_name, {'in': 0}),
        (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
        (reshape_node_name, reshaped_data_node_name, {'out': 0}),
    ])

    # connect the former consumers to the reshaped data node
    graph.add_edges_from([
        (reshaped_data_node_name, consumer_id, edge_attrs)
        for consumer_id, edge_attrs in old_consumer_nodes_with_attrs
    ])
def find_and_replace_pattern(self, graph: Graph):
    """
    Re-use outputs of existing Concat operations as inputs of other Concat operations.

    When a contiguous run of inputs of one Concat is exactly the full input list of another Concat,
    the run is replaced with a single edge from the output of that other Concat. Longer runs are
    tried first and every input of a Concat can be replaced at most once. After each replacement the
    'in' attributes of the Concat input edges are renumbered consecutively starting from 0.

    :param graph: graph to operate on.
    :return: None
    """
    # tuple of input data node ids -> (Concat id, id of the Concat output data node)
    concat_by_inputs = {}
    # Concat id -> {input data node id: True if the input has already been replaced}
    consumed = {}
    for concat in graph.get_op_nodes(type='Concat'):
        input_ids = tuple(concat.in_node(port).id for port in range(len(concat.in_nodes())))
        concat_and_output = (concat.id, concat.out_node().id)
        if input_ids in concat_by_inputs:
            log.warning("Something is weird! {} and {}".format(concat.id, concat_by_inputs[input_ids]))
            continue
        concat_by_inputs[input_ids] = concat_and_output
        consumed[concat.id] = dict.fromkeys(input_ids, False)

    for input_ids, (concat_id, _) in concat_by_inputs.items():
        # Collect every proper contiguous sub-sequence (length >= 2) that is produced by another Concat
        candidates = []
        for start in range(len(input_ids)):
            for stop in range(start + 1, len(input_ids)):
                window = tuple(input_ids[start:stop + 1])
                if window != input_ids and window in concat_by_inputs:
                    candidates.append((len(window), window))
        candidates.sort(reverse=True)

        for _, window in candidates:
            # Skip the candidate if any of its inputs is already covered by a previous replacement
            if any(consumed[concat_id][data_id] for data_id in window):
                continue
            for data_id in window:
                consumed[concat_id][data_id] = True

            # The new edge inherits the attributes of the first edge of the replaced run
            edge_attrs = graph.get_edge_data(window[0], concat_id)[0]
            for data_id in window:
                graph.remove_edge(data_id, concat_id)

            replacement_id = concat_by_inputs[window][1]
            edge_attrs['out'] = len(Node(graph, replacement_id).out_nodes()) + 1
            graph.add_edge(replacement_id, concat_id, **edge_attrs)

            # Renumber 'in' attrs
            concat_node = Node(graph, concat_id)
            for new_port, old_port in enumerate(sorted(concat_node.in_nodes().keys())):
                producer = concat_node.in_nodes()[old_port]
                graph[producer.id][concat_id][0]['in'] = new_port
def check_init_states(graph: Graph, match: dict):
    """
    Make sure the matched RNN cell has its initial states and renumber their ports.

    The hidden state (input port 2) ends up on port 5; for LSTM the cell state (input port 3) ends up
    on port 6. A state that is not connected is created as a float32 tensor of zeros with the shape
    [num_directions, batch_size, hidden_size].

    :param graph: graph to operate on.
    :param match: matched nodes; only 'rnn_layer' is used.
    :return: None
    """
    rnn_cell = match['rnn_layer']
    num_directions = 2 if rnn_cell.direction == 'bidirectional' else 1
    batch_size = rnn_cell.in_node(0).shape[rnn_cell.batch_dim]

    def provide_state(source_port: int, target_port: int):
        if source_port in rnn_cell.in_nodes():
            # The state is already connected: only the port number has to change
            state_edge = graph.get_edge_data(rnn_cell.in_node(source_port).id, rnn_cell.id)
            state_edge[0]['in'] = target_port
            return

        # from ONNX spec
        zero_state = np.zeros([num_directions, batch_size, rnn_cell.hidden_size], dtype=np.float32)
        Op.create_and_connect_input_data_node(
            graph,
            rnn_cell,
            {'value': zero_state, 'shape': int64_array(zero_state.shape)},
            {'in': target_port, 'permutation': None})

    provide_state(2, 5)
    if rnn_cell.op == 'LSTM':
        provide_state(3, 6)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Put a multiplication by 1 / temperature in front of a softmax that has a non-unit temperature.

    Nothing is changed when the matched softmax has no 'temperature' attribute or when it equals 1.0.

    :param graph: graph to operate on.
    :param match: matched nodes; only 'softmax' is used.
    :return: None
    """
    softmax = match['softmax']
    if 'temperature' not in softmax or softmax['temperature'] == 1.0:
        return

    producer = softmax.in_node()
    consumers = list(softmax.out_nodes().values())
    graph.remove_edge(softmax.in_node().id, softmax.id)

    scale = mo_array([1.0 / softmax.temperature])
    scale_const_op = Const(graph, dict(value=scale, shape=scale.shape,
                                       symbol_dict={'name': softmax.id + '/const'}))
    scale_mul_op = Mul(graph, dict(name=softmax.id + '/mul_',
                                   symbol_dict={'name': softmax.id + '/mul_'}))
    scale_const_node = scale_const_op.create_node()
    mul_node = scale_mul_op.create_node(inputs=[producer, scale_const_node])

    # NOTE(review): the new input edge re-uses the attributes of the edge from softmax to its first
    # consumer, not the attributes of the removed input edge - confirm this is intended.
    edge_attrs = graph.get_edge_data(softmax.id, consumers[0].id)[0]
    graph.add_edges_from([(mul_node.id, softmax.id, edge_attrs)])
def replace_pattern(graph: Graph, match: dict):
    """
    Swap the producers connected to inputs 0 and 1 of the matched _contrib_MultiBoxDetection node.

    DetectionOutput layer has another order of inputs unlike mxnet, so the inputs have to be
    reordered for correct conversion to DetectionOutput layer. Each producer takes over both the
    input and the output port numbers of the edge that belonged to the other producer.

    Parameters
    ----------
    graph : Graph
       Graph with loaded model.
    match : dict
       Matched nodes; only 'multi_box_detection' is used.
    """
    detection = match['multi_box_detection']
    conf_node = detection.in_node(0)
    loc_node = detection.in_node(1)

    conf_edge = graph.get_edge_data(conf_node.id, detection.id)[0]
    conf_ports = dict(in_port=conf_edge['in'], out_port=conf_edge['out'])

    loc_edge = graph.get_edge_data(loc_node.id, detection.id)[0]
    loc_ports = dict(in_port=loc_edge['in'], out_port=loc_edge['out'])

    graph.remove_edge(conf_node.id, detection.id)
    graph.remove_edge(loc_node.id, detection.id)

    # Cross-connect: every producer gets the ports of the other one
    graph.create_edge(loc_node, detection, **conf_ports)
    graph.create_edge(conf_node, detection, **loc_ports)
def pad_op_transform(graph: Graph, match: dict):
    """
    Fuse a constant zero Pad into the padding attributes of the operation that consumes it.

    The fusion is skipped when the pad mode is not 'constant', when the consumer is a max Pooling,
    when the fill value is unknown or non-zero, or when the batch or feature dimension is padded.
    On success the pads are added to ``op.pad`` and the operation is connected directly to the input
    data node of the Pad.

    :param graph: graph to operate on.
    :param match: matched nodes: 'op', 'pad_op' and 'pad_output'.
    :return: None
    """
    op = match['op']
    pad_op = match['pad_op']
    pad_output = match['pad_output']
    input_data = pad_op.in_node(0)

    if pad_op.mode != 'constant':
        log.info('The pad node "{}" with pad mode "{}" cannot be fused.'.format(pad_op.soft_get('name'),
                                                                                pad_op.mode))
        return

    if op.type == 'Pooling' and op.pool_method == 'max':
        return

    # Only the 'constant' mode reaches this point, so the fill value is always checked
    fill_value = pad_op.in_port(3).data.get_value()
    if fill_value is None or fill_value != 0.0:
        log.info('The pad node "{}" with non-zero fill value cannot be fused.'.format(pad_op.soft_get('name')))
        return

    rank = len(pad_output.shape)
    layout = op.graph.graph['layout']
    for pads_port in (1, 2):
        pads = pad_op.in_port(pads_port).data.get_value()
        if pads[get_features_dim(layout, rank)] != 0 or pads[get_batch_dim(layout, rank)] != 0:
            log.info('The pad node "{}" with padding over feature/batch dimension cannot be fused.'
                     .format(pad_op.soft_get('name')))
            return

    # Column 0 holds the values from input port 1 and column 1 the values from input port 2
    first_column = pad_op.in_port(1).data.get_value().reshape([-1, 1])
    second_column = pad_op.in_port(2).data.get_value().reshape([-1, 1])
    op.pad += np.concatenate([first_column, second_column], axis=1)
    op.pad_spatial_shape = op.pad[op.spatial_dims]
    op['auto_pad'] = None
    if op.type == 'Pooling':
        op['exclude_pad'] = False

    assert (graph[pad_output.node][op.node][0]['in'] == 0)

    # Bypass the Pad: connect its input data node straight to the operation
    edge_attrs = graph.get_edge_data(pad_output.id, op.id)[0]
    graph.remove_edge(pad_output.id, op.id)
    graph.add_edge(input_data.id, op.id, **{'in': 0, **edge_attrs})
def replace_pattern(graph: Graph, match: dict):
    """
    Reorder the values coming to input port 2 of the matched operation with a Gather.

    The transformation is applied only when input port 2 is connected, the 'shape_input' attribute is
    set, the node layout does not start with 'NC' and the graph layout is 'NCHW'. The Gather indices
    are the NHWC -> NCHW permutation for the rank of the tensor on input port 0, the axis is 0.

    :param graph: graph to operate on.
    :param match: matched nodes; only 'op' is used.
    :return: None
    """
    node = match['op']

    if not node.has_port('in', 2) or node.in_port(2).disconnected() or not node.has_and_set('shape_input'):
        return
    if not node.has_valid('layout') or node.layout.startswith('NC') or graph.graph['layout'] != 'NCHW':
        return

    rank = len(node.in_port(0).data.get_shape())
    permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)

    shape_data = node.in_node(2)
    gather_name = node.soft_get('name', node.id) + '/ShapeGather'

    indices_data = Const(graph, {'value': permutation.perm,
                                 'name': gather_name + '/Const',
                                 'need_shape_inference': True}).create_node_with_data()
    axis_data = Const(graph, {'value': int64_array(0),
                              'name': gather_name + '/Axis'}).create_node_with_data()
    gather_data = Gather(graph, {'name': gather_name,
                                 'need_shape_inference': True}).create_node_with_data(
        [shape_data, indices_data, axis_data])

    # Re-route input 2 through the Gather, keeping a copy of the original edge attributes
    edge_attrs = graph.get_edge_data(shape_data.id, node.id, key=0).copy()
    graph.add_edge(gather_data.id, node.id, **edge_attrs)
    graph.remove_edge(shape_data.id, node.id)
def find_and_replace_pattern(self, graph: Graph):
    """
    Replace a group of StridedSlice operations reading the same data node with one VariadicSplit.

    For every non-constant data node with a known shape the StridedSlice consumers are collected.
    The group is replaced only if all slices are static, have the same rank, have stride 1, slice
    exactly one common dimension and their ranges over that dimension do not intersect and stay
    within the input shape. Parts of the input that are not covered by any slice get extra "fake"
    output data nodes which are marked as graph outputs.

    :param graph: graph to operate on.
    :return: None
    """
    # Iterate over all data nodes and find all with >= 1 consumers
    for input_data in list(graph.get_data_nodes()):
        # We don't use constant data nodes
        if input_data.value is not None:
            continue

        if input_data.shape is None:
            continue
        input_shape = shape_array(input_data.shape)

        # Get all unique StridedSlice consumers
        out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and node.in_node(0).id == input_data.id]

        # A single StridedSlice is not worth replacing with a split
        if len(out_nodes) <= 1:
            continue

        valid_for_replacement = True
        for n in out_nodes:
            if any(not isinstance(s, slice) for s in n.slices):
                # this is a slice with dynamic dimension. Such operation is not valid for replacement
                valid_for_replacement = False
        if not valid_for_replacement:
            continue

        # Drop StridedSlice operations that are equal according to strided_slices_equality
        sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
        out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)

        # All remaining operations must slice tensors of the same rank
        for node in out_nodes:
            if len(node.slices) != len(out_nodes[0].slices):
                valid_for_replacement = False

        # Detect dimension for splitting
        split_channel_dim = None
        for dim_id, s in enumerate(out_nodes[0].slices):
            l, r, stride = s.start, s.stop, s.step
            # if both l and r are None then the dimension is not sliced
            if (l != 0 or r != input_shape[dim_id]) and (l is not None or r is not None):
                if split_channel_dim is None:
                    split_channel_dim = dim_id
                else:
                    # more than one dimension is sliced
                    valid_for_replacement = False

        if split_channel_dim is None:
            valid_for_replacement = False

        # split_dims contains tuples with split range and output data node
        split_dims = []
        for out_id, node in enumerate(out_nodes):
            # Check that StridedSlice op has stride eq 1 and splits only feature channel
            for id, s in enumerate(node.slices):
                l, r, stride = s.start, s.stop, s.step
                # We don't support StridedSlice with stride != 1
                if stride != 1:
                    valid_for_replacement = False
                if id == split_channel_dim:
                    split_dims.append((s.start, s.stop, node.out_node()))

        if not valid_for_replacement:
            continue

        # Check feature split intersection
        final_data_nodes_list = []
        sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

        # check if we have similar StridedSlice operations with different outputs
        # NOTE(review): prev_sd is assigned once and never advanced, so every entry is compared with the
        # first element of the sorted list only - confirm this is intended.
        prev_sd = sorted_split_dims[0]
        to_remove = []
        for i in range(1, len(sorted_split_dims)):
            if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
                # Same range, different output data node: move its consumers to the first output
                cur_node = sorted_split_dims[i][2]
                for out in cur_node.out_nodes():
                    attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                    graph.remove_edge(cur_node.id, out.id)
                    graph.add_edge(prev_sd[2].id, out.id, **attrs)
                to_remove.append(i)

        # Pop from the end so that the remaining indices stay valid
        for ind in reversed(to_remove):
            sorted_split_dims.pop(ind)

        size_splits = []
        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Split dims shouldn't intersect
            if l < prev_r:
                valid_for_replacement = False
            prev_r = r

        # The last range must not go beyond the input tensor
        if prev_r > input_shape[split_channel_dim]:
            valid_for_replacement = False

        if not valid_for_replacement:
            continue

        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Save missing tensor part
            if l > prev_r:
                shape = mo_array(input_shape)
                size_splits.append(l - prev_r)
                shape[split_channel_dim] = l - prev_r
                data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
                final_data_nodes_list.append(data_node)

            prev_r = r
            size_splits.append(r - l)
            final_data_nodes_list.append(out)

        if prev_r < input_shape[split_channel_dim]:
            # Add last part of tensor
            shape = input_shape.copy()
            shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
            size_splits.append(input_shape[split_channel_dim] - prev_r)
            data_node = Op._create_data_node(graph, 'fake_data_'+out_nodes[0].name, {'shape': shape})
            add_opoutput(graph, data_node.id, 0, False, keep_output_port=True)
            final_data_nodes_list.append(data_node)

        for node in out_nodes:
            if not np.all([x == 0 for x in node.shrink_axis_mask]):
                # Remember the current output before the helper methods are called
                out_node = node.out_node()
                if np.any(node['shrink_axis_mask']):
                    self.add_squeeze_for_shrink(graph, node)
                if np.any(node['new_axis_mask']):
                    self.add_unsqueeze_for_new(graph, node)

                # Replace the remembered output in the split outputs with the current output of the node
                for i in range(len(final_data_nodes_list)):
                    if final_data_nodes_list[i].name == out_node.name:
                        final_data_nodes_list[i] = node.out_node()
                        break

        # Insert Split layer and remove old StridedSlice layers
        # 1. Remove connections from input_data to StridedSlice ops
        out_data_nodes = []
        name_for_future_split = out_nodes[0].name

        for node in out_nodes:
            out_data_nodes.append(node.out_node())
            graph.remove_edge(input_data.id, node.id)
            graph.remove_edge(node.id, node.out_node().id)
            graph.remove_node(node.id)
            log.debug("Removed: {}".format(node.id))

        # 2. Create Split layer and reorder outputs
        name = name_for_future_split + "/Split"
        axis_const = Const(graph, {'value': int64_array(split_channel_dim), 'name': name + '/Axis'}).create_node_with_data()
        size_splits_const = Const(graph, {'value': int64_array(size_splits), 'name': name + '/Sizes'}).create_node_with_data()
        split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

        split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const], data_nodes=final_data_nodes_list)
def build_graph_with_attrs(nodes_with_attrs: list, edges_with_attrs: list, new_nodes_with_attrs: list = None,
                           new_edges_with_attrs: list = None, update_edge_attrs: dict = None,
                           update_nodes_attributes: list = None, nodes_with_edges_only: bool = False,
                           add_nodes_from_edges: bool = False):
    """
    Build the Graph with specific nodes and edges. Also update of edge and node parameters is supported.

    The attribute dictionaries passed by the caller are never modified: the default 'name' attribute
    is added to a copy, so one dictionary can safely be shared between several nodes.

    :param nodes_with_attrs: list of tuples ('node_name', {node_attrs})
    :param edges_with_attrs: list of tuples like (start node, end node, (optional) {attrs of the edge}).
    :param new_nodes_with_attrs: analogically nodes_with_attrs; None (default) means no new nodes.
    :param new_edges_with_attrs: analogically edges_with_attrs; None (default) means no new edges.
    :param update_edge_attrs: optional dictionary like {('from_node', 'to_node', key): {edge_attrs}}.
    :param update_nodes_attributes: optional list of tuples which specifies nodes names and their attributes to be
    updated. The first element is a node name to update attribute and the second element is a dictionary with attribute
    name and its value.
    :param nodes_with_edges_only: add nodes which has at least one incoming or outcoming edge.
    :param add_nodes_from_edges: whether nodes that is not listed in all_nodes but are in all_edges is allowed.
    :return: generated graph.
    :raises Error: if new nodes or new edges duplicate existing ones, or if an edge refers to an unknown node
    while add_nodes_from_edges is False.
    """
    # None instead of a mutable default value: a list default would be shared between all calls
    if new_nodes_with_attrs is None:
        new_nodes_with_attrs = []
    if new_edges_with_attrs is None:
        new_edges_with_attrs = []

    if not_all_new([node[0] for node in nodes_with_attrs], [node[0] for node in new_nodes_with_attrs]):
        raise Error('Some nodes from new_nodes_with_attrs are already in nodes.'
                    ' Please, add to new_nodes_with_attrs only NEW nodes.')

    if not_all_new([(edge[0], edge[1]) for edge in edges_with_attrs],
                   [(edge[0], edge[1]) for edge in new_edges_with_attrs]):
        raise Error('Some edges from new_edges_with_attrs are already in edges.'
                    ' Please, add to new_edges_with_attrs only NEW edges.')

    # Check that all nodes from list of edges are in nodes
    all_nodes = nodes_with_attrs + new_nodes_with_attrs
    all_edges = edges_with_attrs + new_edges_with_attrs
    all_nodes_names = [node[0] for node in all_nodes]
    if not add_nodes_from_edges and not all_edges_in_nodes(nodes=all_nodes_names, edges=all_edges):
        raise Error("Some nodes from list of edges is not in nodes. Please, add all necessary nodes.")

    graph = Graph()

    # Create dict for nodes with attrs
    nodes_attrs = {}
    for node_name, attrs in all_nodes:
        # Copy first: writing 'name' into the caller's dictionary would leak the name of this node
        # into every other node that shares the same attributes dictionary
        attrs = dict(attrs)
        attrs.setdefault('name', node_name)
        nodes_attrs[node_name] = attrs

    if nodes_with_edges_only:
        # filter nodes to keep only ones with edges connected
        filtered_nodes = {}
        for edge in all_edges:
            node_1, node_2 = edge[0], edge[1]
            filtered_nodes[node_1] = nodes_attrs[node_1]
            filtered_nodes[node_2] = nodes_attrs[node_2]
        nodes_attrs = filtered_nodes

    # Create all nodes
    for node, attrs in nodes_attrs.items():
        graph.add_node(node, **deepcopy(attrs))

    # Connect nodes with edges (also unpack edge params)
    for edge in all_edges:
        node_1, node_2 = edge[0], edge[1]
        edge_attrs = edge[2] if len(edge) == 3 else {}
        graph.add_edge(node_1, node_2, **edge_attrs)

    # Update attributes of edges
    if update_edge_attrs:
        # it will work in 2.x networkx only
        for edge, attr in update_edge_attrs.items():
            for k, v in attr.items():
                nx.set_edge_attributes(G=graph, name=k, values={edge: v})

    # Update attributes of nodes
    if update_nodes_attributes is not None:
        for node_name, new_attrs in update_nodes_attributes:
            assert (node_name in graph.nodes())
            for attr, value in new_attrs.items():
                graph.node[node_name][attr] = value

    # Validate/update the port information of every node from the attributes of its edges
    for node_id in graph.nodes():
        node = Node(graph, node_id)
        check_and_update_ports(node, [graph.get_edge_data(edge[0], node_id)[0]
                                      for edge in graph.in_edges(node_id)], True)
        check_and_update_ports(node, [graph.get_edge_data(node_id, edge[1])[0]
                                      for edge in graph.out_edges(node_id)], False)

    for node in graph.get_op_nodes():
        # Add in_ports attribute
        in_edges = node.in_edges()
        for i in range(len(in_edges)):
            node.add_input_port(idx=i)

        # Add out_ports attribute
        out_edges = node.out_edges()
        for i in range(len(out_edges)):
            node.add_output_port(idx=i)
    return graph
def find_and_replace_pattern(self, graph: Graph):
    """
    Surround every operation whose layout differs from the graph layout with Transpose operations.

    For such an operation the input is transposed with ``permutation.perm`` before the operation and
    the output is transposed with ``permutation.inv`` after it; the attributes of the operation itself
    are permuted through the 'permutation' attribute temporarily set on its data nodes.

        data -> op -> data
                 ||
                 \/
        data -> Transpose -> data -> op -> data -> Transpose -> data

    :param graph: graph to operate on.
    :return: None
    """
    for node_id in list(graph.nodes()):
        node = Node(graph, node_id)
        node_name = node.soft_get('name', node.id)

        # Skip everything except operations whose layout mismatches the graph layout
        # For example: NHWC and NCHW or NCDHW and NDHWC
        if node.kind != 'op' or not node.has_valid('layout') or \
                node.layout == indices_mapping[len(node.layout)][graph.graph['layout']]:
            continue

        src_data = node.in_node()
        dst_data = node.out_node()

        # Calculate permutation for further Transpose operations
        layout_rank = len(node.layout)
        if graph.graph['layout'] == 'NCHW':
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(layout_rank)
        else:
            permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(layout_rank)

        # 1. Insert input Transpose
        #    The operation consumes the transposed data through the same edge attributes as before
        in_edge_attrs = graph.get_edge_data(src_data.id, node.id)[0]
        graph.remove_edge(src_data.id, node.id)

        in_transpose_name = node_name + '/input_transpose'
        in_order_data = Const(graph, {'name': in_transpose_name + '/order',
                                      'value': permutation.perm}).create_node_with_data()
        in_transpose = Transpose(graph, {'name': in_transpose_name})
        in_transposed_data = in_transpose.create_node_with_data([src_data, in_order_data])

        graph.add_edge(in_transposed_data.id, node.id, **in_edge_attrs)

        # 2. Insert output Transpose
        #    The operation gets a new output data node, the original output data node becomes
        #    the output of the Transpose
        out_edge_attrs = graph.get_edge_data(node.id, dst_data.id)[0]
        graph.remove_edge(node.id, dst_data.id)

        op_out_data = Op.create_data_node(graph, node, {'shape': dst_data.shape[permutation.perm]},
                                          out_edge_attrs)

        out_transpose_name = node_name + '/output_transpose'
        out_order_data = Const(graph, {'name': out_transpose_name + '/order',
                                       'value': permutation.inv}).create_node_with_data()
        Transpose(graph, {'name': out_transpose_name}).create_node_with_data(
            [op_out_data, out_order_data], data_nodes=dst_data)

        # 3. Add permutations for Node
        #    Data nodes carry the permutation only while permute_attrs runs, then it is reset
        node.in_node()['permutation'] = permutation
        node.out_node()['permutation'] = permutation
        node.permute_attrs.permute_attrs(node)

        node.in_node()['permutation'] = None
        node.out_node()['permutation'] = None
def replace_pattern(graph: Graph, match: dict):
    """
    Convert the matched BlockLSTM sub-graph in place.

    The steps are performed in this order:
    1. Check that the concatenated cell states are used only to take one cell state with a Gather
       whose single index equals time_len; otherwise Error is raised.
    2. Reorder the gates of weights and biases, add forget_bias to the first gate of the biases and
       transpose the weights.
    3. Connect gather_1_data directly to BlockLSTM (edge with 'out' 2), renumber the initial state
       input ports and update the node attributes with LSTM.update_node_stat.
    4. Optionally remove the multiplication by ones after the concatenated hidden states and connect
       gather_0_data directly to BlockLSTM (edge with 'out' 1). The function returns silently when
       this part of the pattern has an unexpected form.

    :param graph: graph to operate on.
    :param match: matched nodes keyed by the pattern node name.
    :return: None
    :raises Error: if the concatenated cell states output cannot be replaced (see step 1).
    """
    time_len = match['concatenated_hidden_states'].shape[0]
    r""" Working with concatenated_cell_states_data part first, because IE TensorIterator primitive doesn't have concatenated cell states output and if we can not collapse it, then we does not support this type of BlockLSTM We simplify the sub-graph below by taking another output of BlockLSTM: concatenated cell states over the whole time sequence -> last cell state BlockLSTM || out 1 (concatenated cell states coming out of BlockLSTM) \/ in 1 ConcatV2 || (concatenation with initial state or another unused data) \/ Reshape || \/ Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len) """
    # check that there are no other consumers of concatenated_cell_states_data data flow
    valid_output_names = ['concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data', 'gather_1', 'gather_1_data']
    valid_output_node_ids = [match[name].id for name in valid_output_names]
    node_names_to_check_outputs = ['concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data']
    for name in node_names_to_check_outputs:
        for node in match[name].out_nodes():
            if node.id not in valid_output_node_ids:
                raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                            "time sequence. It is not replaceable by another output and is not supported "
                            "originally".format(match['BlockLSTM'].id))

    # check that we really take the last cell state data by Gather
    gather_indexes = match['gather_1'].in_node(1).value
    if len(gather_indexes) == 1:
        gather_index = gather_indexes[0]
    else:
        raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                    "time sequence. It is not replaceable by another output and is not supported "
                    "originally".format(match['BlockLSTM'].id))
    if gather_index != time_len:
        raise Error("BlockLSTM node {} has output which contains concatenated cell states over the whole "
                    "time sequence. It is not replaceable by another output and is not supported "
                    "originally".format(match['BlockLSTM'].id))

    """ We passed #1 and #2 stages from class description. It means that we can translate the rest of the pattern to LSTMSequence even without following optimizations """

    node = match['BlockLSTM']
    weights_node = node.in_node(1)
    biases_node = node.in_node(2)
    shift_const = node.forget_bias

    # Assign temporary shape for them for easier manipulation
    # TF stores weights in IO order
    input_size = node.in_node(0).shape[-1]
    hidden_size = node.in_node(3).shape[-1]
    weights = weights_node.value
    biases = biases_node.value
    assert weights.shape[0] == input_size + hidden_size, \
        "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
    assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
        "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

    weights = weights.reshape([
        weights.shape[0],
        4,  # gates
        hidden_size
    ])

    biases = biases.reshape([
        4,  # gates
        hidden_size
    ])

    # Reorder gates icfo --> fico for both weights and biases
    gate_reorder = [2, 0, 1, 3]
    weights = np.take(weights, gate_reorder, axis=1)
    biases = np.take(biases, gate_reorder, axis=0)

    # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
    # Note: in case of moving this code up before gate reordering, the addition
    # should be applied at different place
    biases[0] += shift_const

    # Return to the original shapes
    weights = weights.reshape([weights.shape[0], -1])
    biases = biases.flatten()

    # TF stores weights in IO, but IE requires it in OI: transpose
    weights = weights.transpose()

    # Write the converted values back to the constant data nodes
    weights_node.value = weights
    weights_node.shape = int64_array(weights.shape)
    biases_node.value = biases
    biases_node.shape = int64_array(biases.shape)

    # Connect gather_1_data directly to BlockLSTM: the attributes of the Gather output edge are
    # copied and 'out' is set to 2
    attrs = dict(graph.get_edge_data(match['gather_1'].id, match['gather_1_data'].id)[0])
    attrs.update({'out': 2})
    graph.remove_edge(match['BlockLSTM'].id, match['concatenated_cell_states_data'].id)
    graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)

    match['BlockLSTM'].add_output_port(attrs['out'])
    graph.add_edge(match['BlockLSTM'].id, match['gather_1_data'].id, **attrs)

    """ #3 Renumbering h_init_state, c_init_state input ports to match RNNSequence ports order. """
    h_init_port = 4
    c_init_port = 5
    # c_init_state
    # (moved first: port 4 -> 5, which frees port 4 for the hidden state below)
    if 4 in node.in_nodes():
        assert c_init_port not in node.in_nodes()
        cell_state_edge = graph.get_edge_data(node.in_node(4).id, node.id)
        cell_state_edge[0]['in'] = c_init_port

    #h_init_state
    if 3 in node.in_nodes():
        assert h_init_port not in node.in_nodes()
        hidden_state_edge = graph.get_edge_data(node.in_node(3).id, node.id)
        hidden_state_edge[0]['in'] = h_init_port

    new_attrs = {'sequence_dim': 0,
                 'batch_dim': 1,
                 'direction': 'forward',
                 'hidden_size': match['concatenated_hidden_states'].shape[-1],
                 'format': 'tf',
                 }

    LSTM.update_node_stat(match['BlockLSTM'], new_attrs)

    """ Optional #4 optimization from class description following """
    # The second input of mul (everything except the concatenated hidden states) must be constant ones
    data_to_mul = [n for n in match['mul'].in_nodes().values() if n.id != match['concatenated_hidden_states'].id]
    if len(data_to_mul) != 1:
        return  # unexpected type of mul
    data_to_mul = data_to_mul[0]
    if not data_to_mul.has_valid('value'):
        return  # unexpected type of mul
    data_to_mul_value = data_to_mul.value
    if not np.all(data_to_mul_value == 1):
        return  # unexpected type of mul

    # remove useless mul
    attrs = dict(graph.get_edge_data(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)[0])
    graph.remove_edge(match['BlockLSTM'].id, match['concatenated_hidden_states'].id)
    graph.remove_edge(match['mul'].id, match['mul_data'].id)
    graph.add_edge(match['BlockLSTM'].id, match['mul_data'].id, **attrs)

    # find true usages of concatenated hidden states data (not last hidden state)
    valid_output_names = ['mul_data', 'concat_0', 'concat_0_data', 'reshape_0', 'reshape_0_data', 'gather_0',
                          'gather_0_data']
    valid_output_node_ids = [match[name].id for name in valid_output_names]
    node_names_to_check_outputs = ['mul_data', 'concat_0_data', 'reshape_0_data']

    list_of_concatenated_hidden_states_children_node_ids = []
    for name in node_names_to_check_outputs:
        for node in match[name].out_nodes():
            if node.id not in valid_output_node_ids:
                list_of_concatenated_hidden_states_children_node_ids.append(node.id)

    # Exactly one consumer outside of the pattern is allowed and it must be the matched one
    if len(list_of_concatenated_hidden_states_children_node_ids) != 1:
        return  # not supported placement of pattern
    conacenated_child_node_id = list_of_concatenated_hidden_states_children_node_ids[0]
    if conacenated_child_node_id != match['after_mul_op_to_the_rest_of_model'].id:
        return  # not supported placement of pattern

    gather_indexes = match['gather_0'].in_node(1).value
    if len(gather_indexes) == 1:
        gather_index = gather_indexes[0]
    else:
        return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is
    if gather_index != time_len:
        return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is

    # Connect gather_0_data directly to BlockLSTM: the attributes of the Gather output edge are
    # copied and 'out' is set to 1
    attrs = dict(graph.get_edge_data(match['gather_0'].id, match['gather_0_data'].id)[0])
    attrs.update({'out': 1})
    graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
    graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)

    graph.add_edge(match['BlockLSTM'].id, match['gather_0_data'].id, **attrs)