def add_reshape_after_data_node(graph: Graph, data_node_name: str):
    """
    Adds reshape operation which changes shape of the tensor produced by TFSubgraphCall from 4D to real dimension
    of the tensor. The data_node_name node contains real dimensions of the tensor but they will be changed in the
    add_reshapes_for_tf_subgraph_calls function to a 4D because IE TF call layer supports output in 4D only.
    :param graph: graph to operate on.
    :param data_node_name: name of the data node to be reshaped to correct dimensions.
    :return: None
    """
    data_node = Node(graph, data_node_name)

    # if the data node was previously marked as output then we need to mark as output new reshaped data node
    is_out_node = False
    if len(data_node.out_nodes()) == 1 and data_node.out_node().has('op') and data_node.out_node().op == 'OpOutput':
        is_out_node = True
        graph.remove_node(data_node.out_node().id)

    # save old consumers nodes with edge attributes
    old_consumer_nodes_with_attrs = list()
    for index, out_op in enumerate(data_node.out_nodes()):
        edge_attrs = graph.get_edge_data(data_node_name, out_op.name)[0]
        old_consumer_nodes_with_attrs.append((out_op.name, edge_attrs))

    # remove old consumers from the data node
    for out_op in list(data_node.out_nodes()):
        graph.remove_edge(data_node_name, out_op.name)

    # reshape operation node
    reshape_node_name = graph.unique_id("Reshape_")
    graph.add_node(reshape_node_name, kind='op', precision="FP32", type='Reshape', name=reshape_node_name,
                   op='Reshape', data_type=data_node['data_type'])
    update_ie_fields(graph.node[reshape_node_name])

    # reshape shape data node
    reshape_shape_data_node_name = graph.unique_id("Reshape_shape_")
    graph.add_node(reshape_shape_data_node_name, kind='data', precision="FP32", name=reshape_shape_data_node_name,
                   value=np.array(data_node['shape']), shape=[1])

    # reshaped data node
    reshaped_value = None
    if data_node['value'] is not None:
        reshaped_value = np.array(data_node['value'])
    reshaped_data_node_name = graph.unique_id("reshaped_data_")
    graph.add_node(reshaped_data_node_name, kind='data', precision="FP32", name=reshaped_data_node_name,
                   shape=np.array(data_node['shape']), value=reshaped_value, nchw_layout=True)

    if is_out_node:
        add_opoutput(graph, reshaped_data_node_name, 0, False)

    graph.add_edges_from([
        (data_node_name, reshape_node_name, {'in': 0}),
        (reshape_shape_data_node_name, reshape_node_name, {'in': 1}),
        (reshape_node_name, reshaped_data_node_name, {'out': 0}),
    ])

    for out_node_name, edge_attrs in old_consumer_nodes_with_attrs:
        graph.add_edges_from([(reshaped_data_node_name, out_node_name, edge_attrs)])
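# Note: the snippets in this file lean heavily on the add_opoutput() helper. The sketch
# below is NOT the real implementation; it is a hypothetical, simplified illustration of
# the behaviour the callers above appear to rely on: attach a Result/OpOutput sink to a
# given port of a node and return the id of the created sink node (which is why some
# callers wrap the result as Node(graph, add_opoutput(...))).
def add_opoutput_sketch(graph, node_name: str, port: int, cut: bool = True):
    # create a unique id for the sink node derived from the producer name and port
    result_id = graph.unique_id(node_name + '/sink_port_' + str(port))
    # add a Result-like op node and wire the requested output port into it
    graph.add_node(result_id, kind='op', op='Result', type='Result', name=result_id)
    graph.add_edge(node_name, result_id, **{'out': port, 'in': 0})
    return result_id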
def infer(node: Node):
    # there are limitations coming from ONNX LSTM definition and normalization rules
    assert len(node.in_nodes()) >= 3  # X, W and R
    assert len(node.in_nodes()) <= 7
    assert len(node.out_nodes()) <= 3
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim
    assert node.direction in ['forward', 'reverse', 'bidirectional']

    if node.blobs_wrb:
        mark_input_bins(node, ['W', 'R', 'B'])
    else:
        mark_input_bins(node)
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    for port in [2, 3]:
        if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
                'zero_shapes' in node.in_node(port).in_node():
            for i in node.in_node(port).in_node().zero_shapes:
                if node.in_node(port).shape[i] != input_shape[i]:
                    node.in_node(port).value = np.repeat(node.in_node(port).value, input_shape[i], axis=i)
                    node.in_node(port).shape[i] = input_shape[i]

    out_shape = np.array([input_shape[node.sequence_dim], input_shape[node.batch_dim], node.hidden_size],
                         dtype=np.int64)
    assert not node.has_num_directions or node.sequence_dim == 0, \
        'If has_num_directions == True, then node.sequence_dim should be equal 0, but it is {}'.format(
            node.sequence_dim)
    num_directions = 2 if node.direction in ['bidirectional'] else 1
    num_layers = node.num_layers
    if node.has_num_directions:
        # insert extra dimension to output shape for num_directions
        out_shape = np.insert(out_shape, 1, np.int64(num_directions))
    node.out_node(0).shape = out_shape

    # extra outputs for hidden/cell states
    state_size = np.array([input_shape[1], node.hidden_size], dtype=np.int64)
    if node.has_num_directions:
        state_size = np.insert(state_size, 0, num_directions * num_layers)
    for i in [1, 2]:
        if i not in node.out_nodes():
            data_node = Op._create_data_node(node.graph, name=node.node + '/ExtraOutput/' + str(i),
                                             attrs={'executable': True})
            node.graph.add_edge(node.id, data_node.id, key=0, out=i)
            add_opoutput(node.graph, data_node.id, 0, False)
        else:
            data_node = node.out_node(i)
        data_node.shape = state_size.copy()
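# To make the shape arithmetic in infer() concrete, a small self-contained example.
# The dimension values below are illustrative, not taken from the source:
# ONNX-like layout with sequence_dim=0, batch_dim=1, a bidirectional single-layer LSTM.
import numpy as np

seq_len, batch, hidden_size = 10, 4, 128
num_directions, num_layers = 2, 1

# main output: [seq_len, batch, hidden_size] with num_directions inserted at axis 1
out_shape_example = np.array([seq_len, batch, hidden_size], dtype=np.int64)
out_shape_example = np.insert(out_shape_example, 1, np.int64(num_directions))      # [10, 2, 4, 128]

# hidden/cell state outputs: [num_directions * num_layers, batch, hidden_size]
state_size_example = np.array([batch, hidden_size], dtype=np.int64)
state_size_example = np.insert(state_size_example, 0, num_directions * num_layers)  # [2, 4, 128]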
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
    :param body_parameter_names: a (unchanged) list of parameters in the loop body graph
    :param body_results: a list of Result nodes that is extended with a list from a sub-graph
    """
    # create a map from a node name in original model to a name in a loop body graph assuming
    # that names in the original model are unique
    # initially, the map contains names for parameters that are common for the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
    for pb_node in subgraph_proto['node_def']:
        # create an NX node
        id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = id
        body_graph.add_node(id, pb=pb_node, kind='op')
        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op'):
            body_graph.op_names_statistic[pb_node.op] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            orig_src_id = inp.split(":")[0]

            # TODO: avoid this temporal workaround for TF 2.4 or higher RNN layers:
            #  skip control flow dependency
            if orig_src_id[0] == '^':
                continue

            src_id = map_original_name[orig_src_id]
            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
            assert (body_graph.has_node(src_id))

            body_graph.add_edges_from([create_tf_edge(src_id + ":" + str(src_port), id, dst_port)])

    # create Result nodes in the loop body graph
    for output in subgraph_proto['output_arg']:
        output_name = subgraph_proto['ret'][output.name]
        orig_src_id = output_name.split(":")[0]
        src_id = map_original_name[orig_src_id]
        src_port = 0 if len(output_name.split(":")) == 1 else int(output_name.split(":")[-1])
        assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(src_id)

        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
    return True
def update_body_graph(body_graph: Graph, subgraph_proto: dict,
                      body_parameter_names: list, body_results: list):
    """
    Updates the loop body graph with a sub-graph (for body or condition functions)
    :param body_graph: a loop body graph to be updated
    :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph
    :param body_parameter_names: a (unchanged) list of parameters in the loop body graph
    :param body_results: a list of Result nodes that is extended with a list from a sub-graph
    """
    # create a map from a node name in original model to a name in a loop body graph assuming
    # that names in the original model are unique
    # initially, the map contains names for parameters that are common for the body and condition graphs
    map_original_name = {}
    for idx, pb_node in enumerate(subgraph_proto['input_arg']):
        map_original_name[pb_node.name] = body_parameter_names[idx]

    # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph
    for pb_node in subgraph_proto['node_def']:
        # create an NX node
        id = body_graph.unique_id(pb_node.name)
        map_original_name[pb_node.name] = id
        body_graph.add_node(id, pb=pb_node, kind='op')

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            orig_src_id = inp.split(":")[0]
            src_id = map_original_name[orig_src_id]
            src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1])
            assert (body_graph.has_node(src_id))

            edge_attrs = {
                'out': src_port,
                'in': dst_port,
                'name': src_id,
                'fw_tensor_debug_info': [(src_id, src_port)],
                'in_attrs': ['in', 'name'],
                'out_attrs': ['out', 'name'],
                'data_attrs': ['fw_tensor_debug_info']
            }
            body_graph.add_edge(src_id, id, **edge_attrs)

    # create Result nodes in the loop body graph
    for output in subgraph_proto['output_arg']:
        output_name = subgraph_proto['ret'][output.name]
        orig_src_id = output_name.split(":")[0]
        src_id = map_original_name[orig_src_id]
        src_port = 0 if len(output_name.split(":")) == 1 else int(output_name.split(":")[-1])
        assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format(src_id)

        body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False)))
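# Both update_body_graph() variants rely on the TF convention that an input string is
# either 'node' or 'node:port' (port 0 when omitted). A tiny standalone illustration of
# that parsing, with made-up tensor names; the helper name is ours, not from the source.
def parse_tf_tensor_name(inp: str):
    parts = inp.split(":")
    return parts[0], 0 if len(parts) == 1 else int(parts[-1])

assert parse_tf_tensor_name("while/add") == ("while/add", 0)
assert parse_tf_tensor_name("while/split:2") == ("while/split", 2)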
def replace_pattern(self, graph: Graph, match: dict):
    if match['rnn_layer']['op'] == 'LSTM':
        return

    rnn_layer = match['rnn_layer']

    # Build TensorIterator body first
    body = Graph(name=rnn_layer.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    inputs = [Op._create_data_node(body, rnn_layer.name + '/inport/' + str(inp),
                                   {'shape': rnn_layer.in_node(inp).shape.copy(),
                                    'value': rnn_layer.in_node(inp).value.copy()
                                    if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None})
              for inp in [0, 4, 1, 2]]  # X, h_init, WR, B

    inputs[0].shape[rnn_layer.sequence_dim] = 1
    input_squeeze = Squeeze(body, dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
    input_squeeze_dim = Const(body, dict(name=rnn_layer.name + '/input_squeeze_dim',
                                         value=rnn_layer.sequence_dim)).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], input_squeeze_dim],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    outputs = [Op._create_data_node(body, rnn_layer.name + '/outport/' + str(out),
                                    {'shape': rnn_layer.out_node(out).shape.copy()
                                     if out in rnn_layer.out_nodes() else None})
               for out in [0]]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    outputs[0].shape = np.delete(outputs[0].shape.copy(), rnn_layer.sequence_dim)
    output_unsqueeze_dim = Const(body, dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                                            value=rnn_layer.sequence_dim)).create_node_with_data()
    output_unsqueeze = Unsqueeze(body, dict(name=rnn_layer.name + '/output_unsqueeze/', internal_layer_id=2))

    additional_attrs = dict(activations=rnn_layer.activations,
                            activation_alpha=rnn_layer.activation_alpha,
                            activation_beta=rnn_layer.activation_beta,
                            clip=rnn_layer.clip)
    if rnn_layer.op == 'GRU':
        additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

    # 3. ***Cell
    rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(body,
                                                     dict(hidden_size=rnn_layer.hidden_size,
                                                          name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                                                          **additional_attrs,
                                                          internal_layer_id=1))

    gru_cell = rnn_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                 edge_attrs=[{}, {'internal_port_id': 1},
                                                             {'internal_port_id': 2},
                                                             {'bin': 'weights'},
                                                             {'bin': 'biases'}])

    # internal ports for outputs of cell
    gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

    gru_cell = output_unsqueeze.create_node_with_data([gru_cell, output_unsqueeze_dim])
    gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, gru_cell.id, 0, False)

    # 4. TensorIterator layer creating
    assert rnn_layer.direction in ['forward', 'reverse']
    if rnn_layer.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert rnn_layer.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    # stacked h_state
    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,

        'axis': rnn_layer.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding last h_state to outputs
    if len(rnn_layer.out_nodes()) == 2:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }])

    ti_op = TensorIterator(graph, {
        'name': rnn_layer.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 4,
        'out_ports_count': len(rnn_layer.out_nodes()),

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,

                'axis': rnn_layer.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
        ],

        'output_port_map': output_port_map,
        # only for h state
        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
        ]
    })

    assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
        "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

    outs = ti_op.create_node_with_data([rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
                                       data_nodes=[rnn_layer.out_node(i)
                                                   for i in range(len(rnn_layer.out_nodes()))],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

    if not isinstance(outs, list):
        outs = list([outs])

    graph.remove_node(rnn_layer.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
def find_and_replace_pattern(self, graph: Graph):
    # Iterate over all data nodes and find all with >= 1 consumers
    for input_data in list(graph.get_data_nodes()):
        # We don't use constant data nodes
        if input_data.value is not None:
            continue

        input_shape = np.array(input_data.shape)

        # Get all unique StridedSlice consumers
        out_nodes = [node for node in input_data.out_nodes()
                     if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]

        sorted_out_nodes = sorted(out_nodes, key=lambda n: list(n.slices))
        out_nodes = unique_by(sorted_out_nodes, strided_slices_equality)

        if len(out_nodes) <= 1:
            continue

        valid_for_replacement = True
        for node in out_nodes:
            if len(node.slices) != len(out_nodes[0].slices):
                valid_for_replacement = False

        # Detect dimension for splitting
        split_channel_dim = None
        for dim_id, s in enumerate(out_nodes[0].slices):
            l, r, stride = s.start, s.stop, s.step
            if l != 0 or r != input_shape[dim_id]:
                if split_channel_dim is None:
                    split_channel_dim = dim_id
                else:
                    valid_for_replacement = False

        if split_channel_dim is None:
            valid_for_replacement = False

        # split_dims contains tuples with split range and output data node
        split_dims = []
        for out_id, node in enumerate(out_nodes):
            # Check that StridedSlice op has stride eq 1 and splits only feature channel
            for id, s in enumerate(node.slices):
                l, r, stride = s.start, s.stop, s.step
                # We don't support StridedSlice with stride != 1
                if stride != 1:
                    valid_for_replacement = False
                if id == split_channel_dim:
                    split_dims.append((s.start, s.stop, node.out_node()))

        if not valid_for_replacement:
            continue

        # Check feature split intersection
        final_data_nodes_list = []
        sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

        # check if we have similar StridedSlice operations with different outputs
        prev_sd = sorted_split_dims[0]
        to_remove = []
        for i in range(1, len(sorted_split_dims)):
            if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] \
                    and sorted_split_dims[i][2].name != prev_sd[2].name:
                cur_node = sorted_split_dims[i][2]
                for out in cur_node.out_nodes():
                    attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                    graph.remove_edge(cur_node.id, out.id)
                    graph.add_edge(prev_sd[2].id, out.id, **attrs)
                to_remove.append(i)

        for ind in reversed(to_remove):
            sorted_split_dims.pop(ind)

        size_splits = []
        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Split dims shouldn't intersect
            if l < prev_r:
                valid_for_replacement = False
            prev_r = r

        if prev_r > input_shape[split_channel_dim]:
            valid_for_replacement = False

        if not valid_for_replacement:
            continue

        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Save missing tensor part
            if l > prev_r:
                shape = np.array(input_shape)
                size_splits.append(l - prev_r)
                shape[split_channel_dim] = l - prev_r
                data_node = Op._create_data_node(graph, 'fake_data_' + out_nodes[0].name, {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False)
                final_data_nodes_list.append(data_node)

            prev_r = r
            size_splits.append(r - l)
            final_data_nodes_list.append(out)

        if prev_r < input_shape[split_channel_dim]:
            # Add last part of tensor
            shape = input_shape.copy()
            shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
            size_splits.append(input_shape[split_channel_dim] - prev_r)
            data_node = Op._create_data_node(graph, 'fake_data_' + out_nodes[0].name, {'shape': shape})
            add_opoutput(graph, data_node.id, 0, False)
            final_data_nodes_list.append(data_node)

        for node in out_nodes:
            if not np.all([x == 0 for x in node.shrink_axis_mask]):
                out_node = node.out_node()
                if np.any(node['shrink_axis_mask']):
                    self.add_squeeze_for_shrink(graph, node)
                if np.any(node['new_axis_mask']):
                    self.add_unsqueeze_for_new(graph, node)

                for i in range(len(final_data_nodes_list)):
                    if final_data_nodes_list[i].name == out_node.name:
                        final_data_nodes_list[i] = node.out_node()
                        break

        # Insert Split layer and remove old StridedSlice layers
        # 1. Remove connections from input_data to StridedSlice ops
        out_data_nodes = []
        name_for_future_split = out_nodes[0].name
        for node in out_nodes:
            out_data_nodes.append(node.out_node())
            graph.remove_edge(input_data.id, node.id)
            graph.remove_edge(node.id, node.out_node().id)
            graph.remove_node(node.id)
            log.debug("Removed: {}".format(node.id))

        # 2. Create Split layer and reorder outputs
        name = name_for_future_split + "/Split"
        axis_const = Const(graph, {'value': int64_array(split_channel_dim),
                                   'name': name + '/Axis'}).create_node_with_data()
        size_splits_const = Const(graph, {'value': int64_array(size_splits),
                                          'name': name + '/Sizes'}).create_node_with_data()
        split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))

        split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                    data_nodes=final_data_nodes_list)
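# The size_splits bookkeeping above pads the gaps between slice ranges with "fake"
# outputs so the resulting VariadicSplit covers the whole split axis. Illustration with
# made-up slice ranges (not from any real model): three StridedSlice consumers cut
# channels [0:16), [24:40) and [48:64) out of an axis of size 64.
axis_size = 64
slice_ranges = [(0, 16), (24, 40), (48, 64)]  # hypothetical sorted (l, r) pairs

size_splits_example = []
prev_r_example = 0
for l, r in slice_ranges:
    if l > prev_r_example:                        # gap -> "fake" output nobody consumes
        size_splits_example.append(l - prev_r_example)
    size_splits_example.append(r - l)             # the real slice
    prev_r_example = r
if prev_r_example < axis_size:                    # trailing gap
    size_splits_example.append(axis_size - prev_r_example)

print(size_splits_example)  # [16, 8, 16, 8, 16] -> VariadicSplit with 5 outputs, 2 of them fake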
def replace_pattern(self, graph: Graph, match: dict):
    lstm = match['lstm']

    # Build TensorIterator body first
    body = Graph(name=lstm.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                   {'shape': lstm.in_node(inp).shape.copy(),
                                    'value': lstm.in_node(inp).value.copy()
                                    if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
              for inp in [0, 4, 5, 1, 2]]  # X, WR, B, h_init, c_init

    inputs[0].shape[lstm.sequence_dim] = 1
    reshape_dim = inputs[0].shape.copy()
    reshape_dim[lstm.batch_dim] = -1
    reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
    input_squeeze = Reshape(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
    squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                    'value': reshape_dim}).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                    {'shape': lstm.out_node(out).shape.copy()
                                     if out in lstm.out_nodes() else lstm.in_node(4).shape.copy()})
               for out in [0, 1]]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    unsqueezed_output_shape = outputs[0].shape.copy()
    unsqueezed_output_shape[lstm.sequence_dim] = 1
    squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
    outputs[0].shape = squeezed_output_shape
    unsqueezed_output_shape[lstm.batch_dim] = -1
    output_unsqueeze = Reshape(body, dict(name=lstm.name + 'output_unsqueeze', internal_layer_id=2))
    unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                      'value': unsqueezed_output_shape}).create_node_with_data()

    # 3. LSTMCell
    lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                       activations=lstm.activations,
                                       activation_alpha=lstm.activation_alpha,
                                       activation_beta=lstm.activation_beta,
                                       clip=lstm.clip,
                                       input_forget=lstm.input_forget,
                                       name=lstm.name + '/LSTMCell',
                                       internal_layer_id=1))

    lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                        edge_attrs=[{}, {'internal_port_id': 1},
                                                                    {'internal_port_id': 2},
                                                                    {'bin': 'weights'},
                                                                    {'bin': 'biases'}])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
    lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
    lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, lstm_cell_node[0].id, 0, False)

    # 4. TensorIterator layer creating
    assert lstm.direction in ['forward', 'reverse']
    if lstm.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert lstm.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,

        'axis': lstm.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding h_state, c_state to outputs
    if len(lstm.out_nodes()) == 3:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }, {
            'external_port_id': 5,
            'internal_layer_id': 1,
            'internal_port_id': 5,
        }])

    ti_op = TensorIterator(graph, {
        'name': lstm.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 3,
        'out_ports_count': len(lstm.out_nodes()),

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,

                'axis': lstm.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
            {
                'external_port_id': 2,
                'internal_layer_id': 1,
                'internal_port_id': 2,
            },
        ],

        'output_port_map': output_port_map,

        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
            {
                'from_layer': 1,
                'from_port': 5,
                'to_layer': 1,
                'to_port': 2,
            },
        ]
    })

    assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
        "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

    outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                       data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                   {'external_port_id': 2}])

    if not isinstance(outs, list):
        outs = list([outs])

    graph.remove_node(lstm.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id
def rnn_infer(node: Node, out_ports=None):
    """
    General infer function for RNN, GRU, LSTM layers.
    Assume that 0-port input of node is input data for recurrent layer and the node has attrs: hidden_size,
    """
    if out_ports is None:
        out_ports = []

    # 1. Necessary checks (from ONNX specification)
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim
    assert node.direction in ['forward', 'reverse', 'bidirectional']

    if node.blobs_wrb:
        mark_input_bins(node, ['W', 'R', 'B'])
    else:
        mark_input_bins(node)

    # 2. Output shape calculations
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    # Reshape input nodes
    for port in [2, 3]:
        if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
                'zero_shapes' in node.in_node(port).in_node():
            for i in node.in_node(port).in_node().zero_shapes:
                if node.in_node(port).shape[i] != input_shape[i]:
                    node.in_node(port).value = np.repeat(node.in_node(port).value, input_shape[i], axis=i)
                    node.in_node(port).shape[i] = input_shape[i]

    out_shape = np.array([input_shape[node.sequence_dim], input_shape[node.batch_dim], node.hidden_size],
                         dtype=np.int64)
    if node.batch_dim == 0:
        out_shape = np.array([input_shape[node.batch_dim], input_shape[node.sequence_dim], node.hidden_size],
                             dtype=np.int64)

    num_directions = 2 if node.direction in ['bidirectional'] else 1
    if node.has_num_directions:
        if node.format == 'mxnet' and node.normalized is False:
            # MXNet RNN layer returns output with shape [seq_len, batch_size, hidden_size * num_directions]
            out_shape[-1] *= num_directions
        else:
            # ONNX-like, insert extra dimension to output shape for num_directions
            out_shape = np.insert(out_shape, 1, np.int64(num_directions))

    # output 0 is required; create it if it doesn't exist
    if 0 not in node.out_nodes():
        data_node = Op._create_data_node(node.graph, name=node.node + '/ExtraOutput/{}'.format(0),
                                         attrs={'executable': True})
        if 0 not in node.out_ports():
            node.add_output_port(0)
        node.graph.add_edge(node.id, data_node.id, key=0, out=0)
        add_opoutput(node.graph, data_node.id, 0, False)
    node.out_port(0).data.set_shape(out_shape)

    # 3. Extra outputs for hidden/cell states shape calculations (optional)
    state_size = np.array([input_shape[node.batch_dim], node.hidden_size], dtype=np.int64)
    if node.has_num_directions:
        state_size = np.insert(state_size, 0, num_directions)

    if node.multilayers:
        # For multilayer case state sizes from every layer will be concatenated by last axis
        num_layers = node.num_layers
        state_size[-1] *= num_layers

    for i in out_ports:
        # If the node has no consumers for hidden/cell states -> create them
        if i not in node.out_nodes():
            data_node = Op._create_data_node(node.graph, name=node.node + '/ExtraOutput/' + str(i),
                                             attrs={'executable': True})
            if i not in node.out_ports():
                node.add_output_port(i)
            node.graph.add_edge(node.id, data_node.id, key=0, out=i)
            add_opoutput(node.graph, data_node.id, 0, False)
        else:
            data_node = node.out_node(i)
        data_node.shape = state_size.copy()
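# rnn_infer() treats the num_directions dimension differently for MXNet-style and
# ONNX-like models. Illustrative values only (seq_len=20, batch=8, hidden_size=64,
# bidirectional); not taken from the source:
import numpy as np

base_shape = np.array([20, 8, 64], dtype=np.int64)
num_directions = 2

# MXNet-style (not yet normalized): directions are folded into the last dimension
mxnet_out = base_shape.copy()
mxnet_out[-1] *= num_directions                                    # [20, 8, 128]

# ONNX-like: an explicit num_directions axis is inserted at position 1
onnx_out = np.insert(base_shape.copy(), 1, np.int64(num_directions))  # [20, 2, 8, 64]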
def replace_pattern(graph, match: dict):
    # Here we will find all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
    # and make all checks needed for TensorIterator work
    cond_data = match['condition'].out_node(0) if not match['condition'].out_port(0).disconnected() else None
    time_data = match['condition'].out_node(1) if len(match['condition'].out_nodes()) >= 1 else None
    name = match['condition'].name

    back_edges = []
    inputs = []
    outputs = []

    if cond_data is not None:
        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

    if time_data is not None:
        for node in time_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)
            else:
                # something goes wrong here
                assert False

    condition = match['condition']
    tensor_sequence_length = condition.in_node(0)

    nodes_to_remove = [n.id for n in (condition, cond_data, time_data, tensor_sequence_length) if n is not None]
    graph.remove_nodes_from(nodes_to_remove)

    body_nodes, extra_inputs = get_body(graph, inputs, outputs)

    if cond_data is not None:
        body_nodes = list(set(body_nodes) - set([cond_data]))

    inputs += extra_inputs

    assert all([node in graph.nodes() for node in body_nodes])

    inputs = [Node(graph, node) for node in inputs]
    outputs = [Node(graph, node) for node in outputs]
    back_edges = [Node(graph, node) for node in back_edges]

    external_inputs = [{
        'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
        'internal_data_id': node.out_node(0),
        'axis': node.axis,
        'start': node.start,
        'end': node.end,
        'stride': node.stride,
        'part_size': node.part_size
    } for node in inputs]

    external_outputs = [{
        'external_data_id': node.out_node(0),
        'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
        'axis': node.axis,
        'start': node.start,
        'end': node.end,
        'stride': node.stride,
        'part_size': node.part_size
    } for node in outputs]

    back_edges_data = [{
        'from_data_id': node.in_node(1),
        'to_data_id': node.out_node(0),
        'init_data_id': node.in_node(0),
    } for node in back_edges]

    body = Graph(name='body')
    body.graph = graph.graph
    body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
    body.add_edges_from([(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
                         if u in body_nodes and v in body_nodes])

    graph.remove_nodes_from(body_nodes + [match['condition'].id] + [inp.id for inp in inputs] +
                            [out.id for out in outputs])
    internal_id_count = 0
    real_back_edges = []
    for edge in back_edges_data:
        assert edge['from_data_id'].id in body.nodes()
        assert edge['to_data_id'].id in body.nodes()
        assert edge['init_data_id'].id in body.nodes()
        edge['from_data_id'] = Node(body, edge['from_data_id'].id)
        edge['to_data_id'] = Node(body, edge['to_data_id'].id)
        edge['init_data_id'] = Node(body, edge['init_data_id'].id)
        add_opoutput(body, edge['from_data_id'].id, 0, False)

        # Assign/reuse ids for the back-edge start; it comes from from_data_id
        assert len(edge['from_data_id'].in_nodes()) == 1
        # layer id
        if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
            edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
            internal_id_count += 1
        edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

        # port id
        if 'internal_port_id' not in edge['from_data_id'].in_edge():
            edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
            internal_id_count += 1
        edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

        # Look at all consumers for a data that ends a back-edge
        # For each such consumer, there will be a separate back-edge (and input)
        current_real_back_edges = []
        for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):
            real_edge = {}
            real_edge.update(edge)  # all real back_edges have the same back-edge start

            consumer = Node(body, consumer)

            if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                assert False
                real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                    real_edge['to_data_id'].in_node().internal_layer_id
            elif not consumer.has_valid('internal_layer_id'):
                consumer['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            real_edge['to_layer'] = consumer['internal_layer_id']

            assert 'internal_port_id' not in edge_attrs
            assert len(real_edge['init_data_id'].out_edges()) == 1
            assert not 'internal_port_id' in real_edge['init_data_id'].out_edge()
            edge_attrs['internal_port_id'] = internal_id_count
            internal_id_count += 1
            real_edge['to_port'] = edge_attrs['internal_port_id']
            real_edge['consumer'] = consumer
            real_edge['consumer_key'] = key
            real_edge['attrs'] = deepcopy(edge_attrs)
            current_real_back_edges.append(real_edge)

        # connect initial data node with each consumer providing actual edge attributes
        body.add_edges_from([(real_edge['init_data_id'].id, real_edge['consumer'].id,
                              real_edge['consumer_key'], real_edge['attrs'])
                             for real_edge in current_real_back_edges])

        body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
        real_back_edges += current_real_back_edges

    real_external_inputs = []
    for ext_inp in external_inputs:
        assert ext_inp['external_data_id'].id not in body.nodes()
        assert ext_inp['internal_data_id'].id in body.nodes()
        ext_inp['internal_data_id'] = Node(body, ext_inp['internal_data_id'].id)

        if ext_inp['axis'] is not None:
            # Insert squeezing resize at input port that has partitioning
            shape = ext_inp['internal_data_id'].shape.copy()
            assert not ext_inp['internal_data_id'].has_valid('value')
            new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                                                  dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

            reshape_op = Squeeze(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze'))
            reshape_dim_data = Const(body, {'name': ext_inp['internal_data_id'].name + '/ReshapeDim',
                                            'value': ext_inp['axis']}).create_node_with_data()
            reshape_op.create_node_with_data([new_input_data, reshape_dim_data],
                                             data_nodes=[ext_inp['internal_data_id']])
            ext_inp['internal_data_id'] = new_input_data

        ext_inp['internal_data_id']['is_input'] = True
        assert len(ext_inp['internal_data_id'].in_nodes()) == 0
        ext_inp['external_port_id'] = internal_id_count
        internal_id_count += 1
        for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
            real_ext_inp = {}
            real_ext_inp.update(ext_inp)
            consumer = Node(body, consumer)
            if not consumer.has_valid('internal_layer_id'):
                consumer['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in edge_attrs:
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
            real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
            real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
            real_external_inputs.append(real_ext_inp)

    for ext_out in external_outputs:
        assert ext_out['external_data_id'].id not in body.nodes()
        assert ext_out['internal_data_id'].id in body.nodes()
        ext_out['internal_data_id'] = Node(body, ext_out['internal_data_id'].id)

        if ext_out['axis'] is not None:
            # Insert unsqueezing resize at output port that has partitioning
            reshape_op = Unsqueeze(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze'))
            reshape_dim_data = Const(body, {'name': ext_out['internal_data_id'].name + '/ReshapeDim',
                                            'value': ext_out['axis']}).create_node_with_data()
            ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id'],
                                                                            reshape_dim_data])

        # TODO: add here working with simple outputs

        if not any([out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()]):
            add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

        # assert len(ext_out['internal_data_id'].out_nodes()) == 0
        assert len(ext_out['internal_data_id'].in_nodes()) == 1
        if not 'internal_layer_id' in ext_out['internal_data_id'].in_node():
            ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
            internal_id_count += 1
        if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
            ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
            internal_id_count += 1
        ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
        ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
        ext_out['external_port_id'] = internal_id_count
        internal_id_count += 1

    # create TensorIterator layer with pre-computed components
    ti_op = TensorIterator(graph, {
        'name': name + '/TensorIterator',
        'body': body,
        'in_ports_count': len(external_inputs),
        'out_ports_count': len(external_outputs),

        'input_port_map': [
            {field: external_input[field] for field in
             ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis',
              'stride', 'part_size', 'start', 'end']}
            for external_input in real_external_inputs],

        'output_port_map': [
            {field: external_output[field] for field in
             ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis',
              'stride', 'part_size', 'start', 'end']}
            for external_output in external_outputs],
        'back_edges': [
            {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
            for edge in real_back_edges],
    })

    ti_outs = ti_op.create_node_with_data(
        inputs=[inp['external_data_id'] for inp in external_inputs],
        edge_attrs=[{'external_port_id': inp['external_port_id']} for inp in external_inputs],
        data_nodes=[out['external_data_id'] for out in external_outputs]
    )

    if not isinstance(ti_outs, list):
        ti_outs = [ti_outs]

    for i, out in enumerate(ti_outs):
        out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']

    ti = ti_outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
def extract(cls, loop_node):
    Loop.update_node_stat(loop_node, {})

    body_graph_proto = onnx_attr(loop_node, 'body', 'g', None)
    main_graph = loop_node.graph

    # create a Graph object for the body and take graph attributes from the main graph
    body_graph = Graph()
    main_graph_attrs_copy = copy.deepcopy(main_graph.graph)
    del main_graph_attrs_copy['tensor_mapping']
    body_graph.graph.update(main_graph_attrs_copy)
    loop_node['body'] = body_graph

    # maps a tensor name to the node that produced it and the node port: str -> (node_id, node_port)
    data_nodes_map = {}
    body_graph.graph['tensor_mapping'] = data_nodes_map  # save mapping for possible Loop inside the Loop

    body_parameters = add_initializers_and_inputs_to_graph(body_graph, body_graph_proto, data_nodes_map)

    external_edges = []  # (src_node, src_out_port), dest_body_parameter_node
    additional_params = {}  # (src_node, src_out_port) -> parameter_node (for manually added Parameters)
    # Go through all nodes in the original model order because data nodes are defined on-the-fly and order matters
    for pb_node in body_graph_proto.node:
        # create an NX node
        id = body_graph.unique_id(node_id(pb_node))
        body_graph.add_node(id, pb=pb_node, kind='op')

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            # should add edge inp --> id
            if inp not in data_nodes_map:
                if inp == '':
                    # input is omitted; most likely it corresponds to an optional input for an operator
                    continue
                elif inp in main_graph.graph['tensor_mapping']:
                    log.debug('The edge between outer and inner graphs detected: {} -> {}'.format(inp, id))
                    if main_graph.graph['tensor_mapping'][inp] not in additional_params:
                        # create new Parameter body node and connect the body node with the outer graph using it
                        param_id = str(inp)
                        body_graph.add_node(param_id, kind='op', op='Parameter', name=param_id, pb=None, shape=None)
                        parameter_node = Node(body_graph, param_id)
                        # need to manually update necessary attrs for the node because extractor will not be called
                        # for it because the node does not have .pb attribute
                        Parameter.update_node_stat(parameter_node, {})
                        external_edges.append((main_graph.graph['tensor_mapping'][inp], parameter_node))
                        src_id, src_port = param_id, 0
                        additional_params[main_graph.graph['tensor_mapping'][inp]] = parameter_node
                    else:
                        src_id, src_port = additional_params[main_graph.graph['tensor_mapping'][inp]].id, 0
                else:
                    raise Error('Reference to "{}" is not satisfied. A node refer not existing data tensor. ONNX '
                                'model is not consistent. Protobuf fragment: {}', inp, pb_node)
            else:
                src_id, src_port = data_nodes_map[inp]

            assert (body_graph.has_node(src_id))
            edge_attrs = {
                'out': src_port,
                'in': dst_port,
                'name': inp,
                'fw_tensor_debug_info': [(inp, inp)],
                'in_attrs': ['in', 'name'],
                'out_attrs': ['out', 'name'],
                'data_attrs': ['fw_tensor_debug_info']
            }
            body_graph.add_edge(src_id, id, **edge_attrs)

        # add outgoing edges to data_nodes_map
        for src_port, out in enumerate(pb_node.output):
            if out in data_nodes_map:
                log.debug("Detected reuse of blob {}.".format(out))
            data_nodes_map[out] = (id, src_port)

    body_results = []
    for output in body_graph_proto.output:
        tensor_name = str(output.name)
        node_name, output_port = data_nodes_map[tensor_name]
        assert body_graph.has_node(node_name), 'The body graph does not contain output with name "{}"'.format(
            node_name)
        body_results.append(Node(body_graph, add_opoutput(body_graph, node_name, output_port, False)))

    # add 'internal_layer_id' attribute which is a must have attribute for the loop body node
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    loop_carried_dependencies_count = len(body_graph_proto.input) - 2
    scan_outputs_count = len(body_graph_proto.output) - 1 - loop_carried_dependencies_count

    # Loop inputs:
    #   0 - trip count
    #   1 - execution condition
    #   2 .. - loop carried dependencies

    # Loop outputs:
    #   0 .. loop_carried_dependencies_count - 1 - loop carried dependencies
    #   loop_carried_dependencies_count .. - scan outputs

    # Body inputs:
    #   0 - iteration number
    #   1 - execution condition
    #   2 .. - loop carried dependencies

    # Body outputs:
    #   0 - execution condition
    #   1 .. loop_carried_dependencies_count - loop carried dependencies
    #   loop_carried_dependencies_count + 1 .. - scan outputs

    body_graph.stage = 'front'
    # some of the inputs/outputs may not be connected but the normalization transformation will take care of it

    # connection Loop body nodes with external input edges
    next_loop_input_port_idx = sorted(loop_node.in_edges().keys())[-1] + 1
    for (src_node, src_port), body_node in external_edges:
        main_graph.add_edge(src_node, loop_node.id, **{'out': src_port,
                                                       'in': next_loop_input_port_idx,
                                                       'name': src_node,
                                                       'fw_tensor_debug_info': [(src_node, src_node)],
                                                       'in_attrs': ['in', 'name'],
                                                       'out_attrs': ['out', 'name'],
                                                       'data_attrs': ['fw_tensor_debug_info']})
        connect_body_input(loop_node, next_loop_input_port_idx, body_node)
        next_loop_input_port_idx += 1

    # mark current iteration input Parameter node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])

    # connect initial value for "execution condition" input of the loop
    connect_body_input(loop_node, 1, body_parameters[1])
    # add back edge with "execution condition"
    Loop.add_back_edge(loop_node, body_parameters[1], body_results[0])
    # mark "execution condition" Result node
    Loop.mark_execution_condition_result_node(loop_node, body_results[0])

    # connect initial value for "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        connect_body_input(loop_node, idx + 2, body_parameters[idx + 2])
    # add back edge for "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.add_back_edge(loop_node, body_parameters[idx + 2], body_results[idx + 1])
    # connect final value for "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        connect_body_output(loop_node, idx, body_results[idx + 1])

    # connect "scan outputs" and mark axis for concatenation
    for idx in range(loop_carried_dependencies_count, loop_carried_dependencies_count + scan_outputs_count):
        connect_body_output(loop_node, idx, body_results[idx + 1], axis=0)

    # run function to parse body nodes attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
    return cls.enabled
def extract(cls, loop_node):
    Loop.update_node_stat(loop_node, {})

    body_graph_proto = onnx_attr(loop_node, 'body', 'g', None)
    main_graph = loop_node.graph

    # create a Graph object for the body and take graph attributes from the main graph
    body_graph = Graph()
    main_graph_attrs_copy = {}
    for attr_key, attr_value in main_graph.graph.items():
        if attr_key not in ['tensor_mapping', 'parent_node']:
            main_graph_attrs_copy[attr_key] = copy.deepcopy(attr_value)
    body_graph.graph.update(main_graph_attrs_copy)
    loop_node['body'] = body_graph
    # save parent node for nested loops to know which node contains body (and which graph is on upper level)
    body_graph.graph['parent_node'] = loop_node

    # maps a tensor name to the node that produced it and the node port: str -> (node_id, node_port)
    data_nodes_map = {}
    body_graph.graph['tensor_mapping'] = data_nodes_map  # save mapping for possible Loop inside the Loop

    body_parameters = add_initializers_and_inputs_to_graph(body_graph, body_graph_proto, data_nodes_map)

    external_edges = []  # (src_node, src_out_port), dest_body_parameter_node
    # save additional edges information for graph on each level, the first one is the deepest
    additional_params = []  # (src_node, src_out_port) -> parameter_node (for manually added Parameters)
    # Go through all nodes in the original model order because data nodes are defined on-the-fly and order matters
    for pb_node in body_graph_proto.node:
        # create an NX node
        id = body_graph.unique_id(node_id(pb_node))
        body_graph.add_node(id, pb=pb_node, kind='op')
        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op_type'):
            body_graph.op_names_statistic[pb_node.op_type] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            # should add edge src_internal_id --> dst_id
            if inp not in data_nodes_map:
                if inp == '':
                    # input is omitted; most likely it corresponds to an optional input for an operator
                    continue
                else:
                    is_finished = create_cross_body_edge(body_graph, external_edges, additional_params,
                                                         inp, id, dst_port)
                    if not is_finished:
                        raise Error(
                            'Reference to "{}" is not satisfied. A node refer not existing data tensor. ONNX '
                            'model is not consistent. Protobuf fragment: {}', inp, pb_node)
            else:
                src_id, src_port = data_nodes_map[inp]
                create_edge_with_attrs(body_graph, inp, src_id, src_port, id, dst_port)

        # add outgoing edges to data_nodes_map
        for src_port, out in enumerate(pb_node.output):
            if out in data_nodes_map:
                log.debug("Detected reuse of blob {}.".format(out))
            data_nodes_map[out] = (id, src_port)

    body_results = []
    for output in body_graph_proto.output:
        tensor_name = str(output.name)
        node_name, output_port = data_nodes_map[tensor_name]
        assert body_graph.has_node(node_name), 'The body graph does not contain output with name "{}"'.format(
            node_name)
        body_results.append(Node(body_graph, add_opoutput(body_graph, node_name, output_port, False)))

    # add 'internal_layer_id' attribute which is a must have attribute for the loop body node
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    loop_carried_dependencies_count = len(body_graph_proto.input) - 2
    scan_outputs_count = len(body_graph_proto.output) - 1 - loop_carried_dependencies_count

    # Loop inputs:
    #   0 - trip count
    #   1 - execution condition
    #   2 .. - loop carried dependencies

    # Loop outputs:
    #   0 .. loop_carried_dependencies_count - 1 - loop carried dependencies
    #   loop_carried_dependencies_count .. - scan outputs

    # Body inputs:
    #   0 - iteration number
    #   1 - execution condition
    #   2 .. - loop carried dependencies

    # Body outputs:
    #   0 - execution condition
    #   1 .. loop_carried_dependencies_count - loop carried dependencies
    #   loop_carried_dependencies_count + 1 .. - scan outputs

    # some of the inputs/outputs may not be connected but the normalization transformation will take care of it

    # connection Loop body nodes with external input edges
    next_loop_input_port_idx = sorted(loop_node.in_edges().keys())[-1] + 1
    cur_graph = body_graph
    for external_edges_subg in external_edges:
        if 'parent_node' not in cur_graph.graph:
            continue
        cur_loop_node = cur_graph.graph['parent_node']
        parent_graph = cur_loop_node.graph
        for (src_node, src_port), body_node, tensor_name in external_edges_subg:
            create_edge_with_attrs(parent_graph, tensor_name, src_node, src_port,
                                   cur_loop_node.id, next_loop_input_port_idx)

            Loop.connect_body_input(cur_loop_node, next_loop_input_port_idx, body_node)
            next_loop_input_port_idx += 1
        cur_graph = parent_graph

    # mark current iteration input Parameter node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])

    # connect initial value for "execution condition" input of the loop
    Loop.connect_body_input(loop_node, 1, body_parameters[1])
    # add back edge with "execution condition"
    Loop.add_back_edge(loop_node, body_parameters[1], body_results[0])
    # mark "execution condition" Result node
    Loop.mark_execution_condition_result_node(loop_node, body_results[0])

    # connect initial value for "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.connect_body_input(loop_node, idx + 2, body_parameters[idx + 2])
    # add back edge for "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.add_back_edge(loop_node, body_parameters[idx + 2], body_results[idx + 1])
    # connect final value for "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.connect_body_output(loop_node, idx, body_results[idx + 1])

    # connect "scan outputs" and mark axis for concatenation
    for idx in range(loop_carried_dependencies_count, loop_carried_dependencies_count + scan_outputs_count):
        Loop.connect_body_output(loop_node, idx, body_results[idx + 1], axis=0)

    # run function to parse body nodes attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
    return cls.enabled
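# Worked example of the carried-dependency / scan-output arithmetic used by both Loop
# extractors above. The counts are illustrative and not taken from a real model:
#   body inputs:  [iteration_number, condition_in, state_a, state_b]         -> 4
#   body outputs: [condition_out, state_a_out, state_b_out, scan_0, scan_1]  -> 5
body_inputs_count = 4
body_outputs_count = 5

loop_carried_dependencies_count = body_inputs_count - 2                         # 2 (state_a, state_b)
scan_outputs_count = body_outputs_count - 1 - loop_carried_dependencies_count   # 2 (scan_0, scan_1)

# Loop node outputs: carried dependencies first (their final values), then stacked scan outputs.
print(loop_carried_dependencies_count, scan_outputs_count)  # 2 2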