def normalize_body_graph(loop_node: Node):
    loop_name = loop_node.soft_get('name', loop_node.id)
    # connect the "trip count" input with the default value -1 ("infinity") if it is not connected
    if not loop_node.is_in_port_connected(0):
        loop_node.add_input_port(0, skip_if_exist=True)
        Const(loop_node.graph, {'name': loop_name + '/trip_count',
                                'value': int64_array(-1)}).create_node().out_port(0).connect(loop_node.in_port(0))

    # connect the "execution condition" input with the default value True if it is not connected
    if not loop_node.is_in_port_connected(1):
        loop_node.add_input_port(1, skip_if_exist=True)
        Const(loop_node.graph, {'name': loop_name + '/execution_cond',
                                'value': np.array(True, dtype=bool)}).create_node().out_port(0).connect(loop_node.in_port(1))

    # scan outputs need Unsqueeze over axis 0
    for record in loop_node.output_port_map:
        body_node = Loop.get_body_node_by_internal_id(loop_node, record['internal_layer_id'])
        assert body_node is not None
        assert body_node.soft_get('type') == 'Result'

        if record['axis'] is not None:
            unsqueeze = create_op_with_const_inputs(loop_node.body, Unsqueeze, {1: int64_array([0])})
            body_node.in_port(0).get_connection().insert_node(unsqueeze)

    Loop.normalize_input_output_ports(loop_node)
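# An illustrative example (assumed minimal record layout, not the actual MO data structure) of the
# output_port_map records iterated above: each record ties an external Loop output port to an internal
# Result node via 'internal_layer_id', and a non-None 'axis' marks a scan output whose per-iteration
# values are concatenated, which is why an Unsqueeze over axis 0 is inserted before that Result.
example_output_port_map = [
    {'external_port_id': 0, 'internal_layer_id': 5, 'axis': None},  # loop-carried value, passed through as is
    {'external_port_id': 1, 'internal_layer_id': 7, 'axis': 0},     # scan output, slices stacked along axis 0
]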
def find_and_replace_pattern(self, graph: Graph):
    for loop_node in graph.get_op_nodes(op='Loop'):
        loop_name = loop_node.soft_get('name', loop_node.id)
        body_graph = loop_node['body']
        body_pattern = MapFNOutputConcatenation.get_body_pattern()
        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

        for internal_match in internal_matches:
            # check if TensorListReserve from the main graph is connected with the Parameter node from
            # the body graph that stores intermediate output results of the While loop. If yes, the
            # transformation detects intermediate output concatenation by this port and can use the
            # Loop axis attribute
            reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
                                                                  internal_match['container'].internal_layer_id)
            reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
                                               reserve_node[0].op == 'TensorListReserve') else None
            if reserve_node is None:
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 MapFN pattern for intermediate outputs concatenation".format(loop_name))
                continue
            stack_node = Loop.get_external_nodes_by_internal_id(
                loop_node, internal_match['concatenation_result'].internal_layer_id)
            stack_node = stack_node[0] if len(stack_node) == 1 else None
            if stack_node is None:
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 MapFN pattern for intermediate outputs concatenation".format(loop_name))
                continue

            # skip the StopGradient node if it exists between the While loop output port and the
            # TensorListStack operation
            stack_node = skip_nodes_by_condition(stack_node, lambda x: x.has_and_set('identity'), True)
            stack_node = stack_node if stack_node.op == 'TensorListStack' else None
            if stack_node is None:
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 MapFN pattern for intermediate outputs concatenation".format(loop_name))
                continue

            external_match = {'while': loop_node,
                              'reserve': reserve_node,
                              'stack': stack_node}
            # check that back edges connect the Parameter node (the container with intermediate output
            # results) and the concatenation result produced by the TensorListSetItem node
            if Loop.back_edge_exists(loop_node.back_edges,
                                     internal_match['concatenation_result'].internal_layer_id,
                                     internal_match['container'].internal_layer_id) and \
                    Loop.back_edge_exists(loop_node.back_edges,
                                          internal_match['increment_iteration_result'].internal_layer_id,
                                          internal_match['current_iteration'].internal_layer_id):
                MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
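# A minimal standalone sketch of the check Loop.back_edge_exists is assumed to perform above: scan the
# back-edge records for a pair connecting a body Result ('from_layer') back to a body Parameter
# ('to_layer'). Written over plain dicts for illustration; the helper itself lives on the Loop op and
# the record keys are an assumption.
def back_edge_exists_sketch(back_edges, from_layer, to_layer):
    return any(edge.get('from_layer') == from_layer and edge.get('to_layer') == to_layer
               for edge in back_edges)


assert back_edge_exists_sketch([{'from_layer': 3, 'to_layer': 1}], 3, 1)
assert not back_edge_exists_sketch([{'from_layer': 3, 'to_layer': 1}], 1, 3)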
def replace_sub_graph(self, graph: Graph, external_match: dict):
    loop_node = external_match['while']
    body_graph = loop_node['body']
    body_pattern = KerasRNNOutputConcatenation.get_body_pattern()
    internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

    if len(internal_matches) == 1:
        internal_match = internal_matches[0]
        reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node)
        stack_port_idx = external_match['stack'].in_port(0).get_source().idx
        # check that back edges connect the correct Parameter and Result nodes in the body,
        # that body input ports are connected to the expected external input ports of the Loop node,
        # and that body output ports are connected to the expected external output ports of the Loop node
        if Loop.back_edge_exists(loop_node.back_edges,
                                 internal_match['concatenation_result'].internal_layer_id,
                                 internal_match['container'].internal_layer_id) and \
                Loop.back_edge_exists(loop_node.back_edges,
                                      internal_match['increment_iteration_result'].internal_layer_id,
                                      internal_match['current_iteration'].internal_layer_id) and \
                Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
                                       internal_match['container'].internal_layer_id) and \
                Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
                                       internal_match['concatenation_result'].internal_layer_id):
            KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match)
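# By analogy, a standalone sketch of Loop.inter_edge_exists as used above: verify that a port map
# (input_port_map or output_port_map) contains a record linking the given external port of the Loop
# node to the given internal layer of the body. Plain-dict illustration under an assumed record layout.
def inter_edge_exists_sketch(port_map, external_port_id, internal_layer_id):
    return any(record.get('external_port_id') == external_port_id and
               record.get('internal_layer_id') == internal_layer_id
               for record in port_map)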
def find_and_replace_pattern(self, graph: Graph):
    for loop_node in graph.get_op_nodes(op='Loop'):
        loop_name = loop_node.soft_get('name', loop_node.id)
        body_graph = loop_node['body']
        body_pattern = MapFNInputSlicing.get_body_pattern()
        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

        for internal_match in internal_matches:
            # check if TensorListGetItem from the body graph is connected with TensorListFromTensor
            # from the main graph. If yes, the transformation detects input slicing by this port
            # and can use the Loop axis attribute
            unstack_node = Loop.get_external_nodes_by_internal_id(loop_node,
                                                                  internal_match['tensor_list'].internal_layer_id)
            unstack_node = unstack_node[0] if (len(unstack_node) == 1 and
                                               unstack_node[0].op == 'TensorListFromTensor') else None
            if unstack_node is None:
                log.info("A sub-graph around the loop node {} does not match "
                         "TensorFlow 2 MapFN pattern for input slicing".format(loop_name))
                continue

            external_match = {'while': loop_node,
                              'unstack': unstack_node}
            # check that back edges connect the correct Parameter and Result nodes in the body
            # and that body input ports are connected to the expected external input ports of the Loop node
            if Loop.back_edge_exists(loop_node.back_edges,
                                     internal_match['increment_iteration_result'].internal_layer_id,
                                     internal_match['current_iteration'].internal_layer_id):
                MapFNInputSlicing.transform_map_fn_input_slicing(external_match, internal_match)
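# A hedged standalone sketch of skip_nodes_by_condition, used in the output-concatenation pass above
# to step over StopGradient/Identity-like nodes between the Loop output and TensorListStack. The
# assumed behavior: walk while the condition holds, forward=True following the single consumer of
# out_port(0), forward=False following the producer of in_port(0). Illustration, not the real helper.
def skip_nodes_by_condition_sketch(node, condition, forward=False):
    while condition(node):
        node = node.out_port(0).get_destination().node if forward else node.in_port(0).get_source().node
    return node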
def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 Keras RNN output concatenation into use of the axis attribute
    for the output port of the Loop node
    :param external_match: a match used for handling a part of the main graph responsible for output concatenation
    :param internal_match: a match used for handling a part of the body graph responsible for output concatenation
    """
    loop_node = external_match['while']
    stack_node = external_match['stack']
    list_reserve_node = external_match['reserve']
    body_graph = loop_node['body']

    tensor_list_set_item_node = internal_match['concatenation']
    tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get('name', tensor_list_set_item_node.id)
    list_result_node = internal_match['concatenation_result']

    # replace TensorListSetItem with Unsqueeze and use the axis attribute of the corresponding Result node
    # to concatenate results from different iterations
    unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)},
                                                         {'name': 'TensorListSetItemUnsqueeze'})
    tensor_list_set_item_node.in_port(2).get_connection().set_destination(unsqueeze_list_element.in_port(0))
    tensor_list_set_item_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0))
    rename_nodes([(tensor_list_set_item_node, tensor_list_set_item_node_name + '/AbandonedName'),
                  (unsqueeze_list_element, tensor_list_set_item_node_name)])

    list_result_node_layer_id = list_result_node.internal_layer_id
    Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id,
                                   'axis', 0)

    # remove TensorListStack by by-passing the node since the result from the Loop node is already concatenated
    stack_node.out_port(0).get_connection().set_source(stack_node.in_port(0).get_connection().get_source())

    # disconnect the ListReserve node because it is no longer needed for the Loop
    list_reserve_node.out_port(0).disconnect()

    # connect the number of iterations, taken from the second input of ListReserve, as the trip count.
    # Also create a Const with the value True for the execution condition so that IE can ignore the
    # condition and perform trip_count iterations. A known trip count value avoids dynamism.
    loop_node.in_port(1).disconnect()
    list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
    for record in loop_node.output_port_map:
        if 'purpose' in record and record['purpose'] == 'execution_condition':
            exec_cond_layer_id = record['internal_layer_id']
            exec_cond_node = Loop.get_body_node_by_internal_id(loop_node, exec_cond_layer_id)
            const_true = Const(body_graph, {'value': np.array(True, dtype=bool)}).create_node()
            exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
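# A minimal sketch of what Loop.update_port_map_value_ext is assumed to do for the call above: find
# the single port-map record whose attribute (here 'internal_layer_id') matches the given value and
# set another attribute on it (here 'axis' = 0, which makes the Loop concatenate that output across
# iterations). Plain-dict illustration under those assumptions.
def update_port_map_value_ext_sketch(port_map, layer_id_attr, layer_id_value, updated_attr, new_value):
    matched = [record for record in port_map if record.get(layer_id_attr) == layer_id_value]
    assert len(matched) == 1, 'exactly one record is expected to match'
    matched[0][updated_attr] = new_value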
def normalize_loop_node(graph: Graph, loop_node: Node):
    loop_name = loop_node.soft_get('name', loop_node.id)

    # disconnect the current iteration from external port #0 and move the trip count to this port
    loop_node.in_port(0).disconnect()
    loop_node.in_port(1).get_connection().add_destination(loop_node.in_port(0))
    Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', 1, 0)

    # connect the execution condition port
    exec_cond_node = Const(graph, {'name': loop_name + '/ExecutionConditionValue',
                                   'value': np.array(True, dtype=bool)}).create_node()
    loop_node.in_port(1).get_connection().set_source(exec_cond_node.out_port(0))

    loop_node.body.clean_up()
    Loop.normalize_input_output_ports(loop_node)
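# Similarly, a sketch of Loop.update_port_map_value as used above: rewrite a single attribute in place
# (here moving 'external_port_id' from 1 to 0 after the trip count is rewired to input port 0).
# Plain-dict illustration under the same assumed record layout.
def update_port_map_value_sketch(port_map, attr, original_value, new_value):
    matched = [record for record in port_map if record.get(attr) == original_value]
    assert len(matched) == 1, 'exactly one record is expected to match'
    matched[0][attr] = new_value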
def loop_infer(step_node, port_num):
    out_port_map = step_node.output_port_map
    int_layer_id = None
    iterations_count = Loop.iterations_count(step_node)
    for record in out_port_map:
        if record['external_port_id'] == port_num:
            int_layer_id = record['internal_layer_id']

    ti_set_output_port_shape(step_node, int_layer_id, port_num, iterations_count, 0)
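# A hedged sketch of what Loop.iterations_count is assumed to provide for the shape inference above:
# the trip count when it is a constant known at conversion time, otherwise None as a dynamic marker.
# The real helper also has to account for the execution condition; the API calls here are assumptions
# for illustration only.
def iterations_count_sketch(loop_node):
    trip_count = loop_node.in_port(0).data.get_value()  # None when the value is not constant-folded
    if trip_count is None or int(trip_count) == -1:
        return None  # the number of iterations is not known statically
    return int(trip_count)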
def find_and_replace_pattern(self, graph: Graph):
    cleanup_called_once = False

    # walk through all Loop nodes and find Const inputs
    for loop_node in graph.get_op_nodes(op='Loop'):
        # call clean-up only once; it performs constant folding
        if not cleanup_called_once:
            graph.clean_up()
            cleanup_called_once = True

        # move constant nodes into the body graph and remove the corresponding body Parameter nodes
        Loop.pull_constant_inputs_into_body(loop_node)

        # some input ports can be removed after pulling constants, so normalization of the Loop node is required
        Loop.normalize_input_output_ports(loop_node)

        # perform shape inference for the Loop node again since new constants can appear
        # and constant folding can help the weights path to a Convolution node inside the body graph
        loop_node['need_shape_inference'] = True
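# A minimal sketch (assuming the MO port API used elsewhere in this section) of the precondition
# pull_constant_inputs_into_body relies on: an input port qualifies for pulling only when its source
# node is a Const, which is why graph.clean_up() (constant folding) runs once before the loop above.
# The helper name below is hypothetical and only illustrates the check.
def const_fed_input_ports_sketch(loop_node):
    const_ports = []
    for port_idx, in_port in loop_node.in_ports().items():
        if not in_port.disconnected() and in_port.get_source().node.soft_get('op') == 'Const':
            const_ports.append(port_idx)
    return const_ports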
def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 Keras RNN input slicing into use of the axis attribute
    for the input port of the Loop node
    :param external_match: a match used for handling a part of the main graph responsible for input slicing
    :param internal_match: a match used for handling a part of the body graph responsible for input slicing
    """
    loop_node = external_match['while']
    unstack_node = external_match['unstack']
    body_graph = loop_node['body']

    tensor_list_get_item_node = internal_match['slicing']
    unstack_placeholder = internal_match['tensor_list']
    tensor_list_get_item_node_name = tensor_list_get_item_node.soft_get('name', tensor_list_get_item_node.id)

    # 1. process the body graph to avoid unsupported operations: TensorListGetItem and TensorListSetItem
    # replace TensorListGetItem with a Squeeze node and iterate through slices using the axis of the input port
    squeeze_list_element = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)},
                                                       {'name': 'TensorListGetItemSqueeze'})
    tensor_list_get_item_node.in_port(0).get_connection().set_destination(squeeze_list_element.in_port(0))
    tensor_list_get_item_node.out_port(0).get_connection().set_source(squeeze_list_element.out_port(0))
    rename_nodes([(tensor_list_get_item_node, tensor_list_get_item_node_name + '/AbandonedName'),
                  (squeeze_list_element, tensor_list_get_item_node_name)])
    unstack_placeholder_layer_id = unstack_placeholder.internal_layer_id
    Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id', unstack_placeholder_layer_id,
                                   'axis', 0)

    # 2. process the neighborhood of the Loop node in the main graph to avoid unsupported operations:
    # TensorListFromTensor, TensorListReserve, and TensorListStack
    # remove TensorListFromTensor and pass the tensor to the Loop as is
    unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
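# A small numpy illustration (not MO code) of why the rewrite above is sound: TensorListGetItem
# extracts element i of the unstacked tensor, which equals taking the per-iteration slice the Loop
# feeds along axis 0 and squeezing that axis, which is exactly what the inserted Squeeze node does
# once the input port carries axis=0.
import numpy as np

tensor = np.random.rand(4, 2, 3)       # [time, batch, features]
i = 1
per_iteration_slice = tensor[i:i + 1]  # what the Loop feeds the body at iteration i, shape [1, 2, 3]
assert np.array_equal(np.squeeze(per_iteration_slice, axis=0), tensor[i])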
def replace_sub_graph(self, graph: Graph, external_match: dict):
    loop_node = external_match['while']
    body_graph = loop_node['body']
    body_pattern = KerasRNNInputSlicing.get_body_pattern()
    internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

    # a case of multiple matches is not handled since it is not clear how to select the corresponding match
    if len(internal_matches) == 1:
        internal_match = internal_matches[0]
        unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
        # check that back edges connect the correct Parameter and Result nodes in the body
        # and that body input ports are connected to the expected external input ports of the Loop node;
        # the sub-graph is transformed only if the inter-graph match passes
        if Loop.back_edge_exists(loop_node.back_edges,
                                 internal_match['increment_iteration_result'].internal_layer_id,
                                 internal_match['current_iteration'].internal_layer_id) and \
                Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
                                       internal_match['tensor_list'].internal_layer_id):
            KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)
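# A standalone sketch of compute_input_port_idx as used in the match above: find which input port of
# the Loop node is fed by the given external producer (e.g. TensorListFromTensor). It relies on the
# same port API used throughout this section; an illustration, not necessarily the actual helper.
def compute_input_port_idx_sketch(producer_node, loop_node):
    for destination in producer_node.out_port(0).get_destinations():
        if destination.node.id == loop_node.id:
            return destination.idx
    return None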
def extract(cls, loop_node):
    Loop.update_node_stat(loop_node, {})
    loop_name = loop_node.soft_get('name', loop_node.id)

    # check that the required body and condition functions exist in the graph library
    main_graph = loop_node.graph
    body_graph_name = loop_node.pb.attr['body'].func.name
    cond_graph_name = loop_node.pb.attr['cond'].func.name
    assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
                                          'by node with name "{}".'.format(loop_name)
    library_graph = main_graph.graph['library']

    assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                             'that is required by node ' \
                                             'with name "{}".'.format(body_graph_name, loop_name)
    body_graph_proto = library_graph[body_graph_name]

    assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                             'that is required by node ' \
                                             'with name "{}".'.format(cond_graph_name, loop_name)
    cond_graph_proto = library_graph[cond_graph_name]

    body_graph = Graph()
    # fill the body graph
    for attr_key in main_graph.graph.keys():
        if attr_key != 'library':
            body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
        else:
            # it is sufficient to have a link to the library
            body_graph.graph['library'] = main_graph.graph['library']
    loop_node['body'] = body_graph

    # create Parameter nodes for the body graph
    body_parameters = []
    body_parameter_names = []
    for idx, pb_node in enumerate(body_graph_proto['input_arg']):
        param_id = body_graph.unique_id(pb_node.name)
        body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
        parameter_node = Node(body_graph, pb_node.name)
        Parameter.update_node_stat(parameter_node,
                                   {'data_type': tf_dtype_extractor(pb_node.type),
                                    'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])})
        body_parameters.append(parameter_node)
        body_parameter_names.append(param_id)

    # update the loop body graph with the body function graph
    body_results = []
    update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)

    # update the loop body graph with the condition function graph
    update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)

    # add the 'internal_layer_id' attribute which is a mandatory attribute for loop body nodes
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    body_graph.stage = 'front'

    # Currently,
    # Loop Inputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # Body Inputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # Body Outputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # Loop Outputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # so inputs must be reordered and an execution condition must be created in the front transformation
    # to be aligned with the specification

    # connect external input ports with body parameter nodes except the current iteration
    # since it must be disconnected from the external port
    for idx in range(1, len(body_parameters)):
        Loop.connect_body_input(loop_node, idx, body_parameters[idx])

    # mark the current iteration input Parameter node and the execution condition Result node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
    Loop.mark_execution_condition_result_node(loop_node, body_results[-1])

    # connect back edges in the body except the current iteration
    for idx in range(1, len(body_parameters)):
        Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])

    # connect body outputs with Loop operation output ports except the execution condition result
    for idx in range(len(body_results) - 1):
        Loop.connect_body_output(loop_node, idx, body_results[idx])

    # run the function to parse body node attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
    return cls.enabled
def extract(cls, loop_node):
    Loop.update_node_stat(loop_node, {})

    body_graph_proto = onnx_attr(loop_node, 'body', 'g', None)
    main_graph = loop_node.graph

    # create a Graph object for the body and take graph attributes from the main graph
    body_graph = Graph()
    main_graph_attrs_copy = copy.deepcopy(main_graph.graph)
    del main_graph_attrs_copy['tensor_mapping']
    body_graph.graph.update(main_graph_attrs_copy)
    loop_node['body'] = body_graph

    # maps a tensor name to the node that produced it and the node port: str -> (node_id, node_port)
    data_nodes_map = {}
    body_graph.graph['tensor_mapping'] = data_nodes_map  # save the mapping for a possible Loop inside the Loop

    body_parameters = add_initializers_and_inputs_to_graph(body_graph, body_graph_proto, data_nodes_map)

    external_edges = []  # (src_node, src_out_port), dest_body_parameter_node
    additional_params = {}  # (src_node, src_out_port) -> parameter_node (for manually added Parameters)
    # go through all nodes in the original model order because data nodes are defined on-the-fly and order matters
    for pb_node in body_graph_proto.node:
        # create an NX node
        id = body_graph.unique_id(node_id(pb_node))
        body_graph.add_node(id, pb=pb_node, kind='op')

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            # should add edge inp --> id
            if inp not in data_nodes_map:
                if inp == '':
                    # the input is omitted; most likely it corresponds to an optional input of the operator
                    continue
                elif inp in main_graph.graph['tensor_mapping']:
                    log.debug('The edge between outer and inner graphs detected: {} -> {}'.format(inp, id))
                    if main_graph.graph['tensor_mapping'][inp] not in additional_params:
                        # create a new Parameter body node and use it to connect the body node with the outer graph
                        param_id = str(inp)
                        body_graph.add_node(param_id, kind='op', op='Parameter', name=param_id, pb=None, shape=None)
                        parameter_node = Node(body_graph, param_id)
                        # the necessary attrs must be updated manually because the extractor is not called
                        # for this node since it does not have a .pb attribute
                        Parameter.update_node_stat(parameter_node, {})
                        external_edges.append((main_graph.graph['tensor_mapping'][inp], parameter_node))
                        src_id, src_port = param_id, 0
                        additional_params[main_graph.graph['tensor_mapping'][inp]] = parameter_node
                    else:
                        src_id, src_port = additional_params[main_graph.graph['tensor_mapping'][inp]].id, 0
                else:
                    raise Error('Reference to "{}" is not satisfied. A node refers to a non-existing data tensor. '
                                'The ONNX model is not consistent. Protobuf fragment: {}', inp, pb_node)
            else:
                src_id, src_port = data_nodes_map[inp]

            assert body_graph.has_node(src_id)
            edge_attrs = {
                'out': src_port,
                'in': dst_port,
                'name': inp,
                'fw_tensor_debug_info': [(inp, inp)],
                'in_attrs': ['in', 'name'],
                'out_attrs': ['out', 'name'],
                'data_attrs': ['fw_tensor_debug_info']
            }
            body_graph.add_edge(src_id, id, **edge_attrs)

        # add outgoing edges to data_nodes_map
        for src_port, out in enumerate(pb_node.output):
            if out in data_nodes_map:
                log.debug("Detected reuse of blob {}.".format(out))
            data_nodes_map[out] = (id, src_port)

    body_results = []
    for output in body_graph_proto.output:
        tensor_name = str(output.name)
        node_name, output_port = data_nodes_map[tensor_name]
        assert body_graph.has_node(node_name), \
            'The body graph does not contain output with name "{}"'.format(node_name)
        body_results.append(Node(body_graph, add_opoutput(body_graph, node_name, output_port, False)))

    # add the 'internal_layer_id' attribute which is a mandatory attribute for loop body nodes
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    loop_carried_dependencies_count = len(body_graph_proto.input) - 2
    scan_outputs_count = len(body_graph_proto.output) - 1 - loop_carried_dependencies_count

    # Loop inputs:
    #   0 - trip count
    #   1 - execution condition
    #   2 .. - loop carried dependencies
    # Loop outputs:
    #   0 .. loop_carried_dependencies_count - 1 - loop carried dependencies
    #   loop_carried_dependencies_count .. - scan outputs

    # Body inputs:
    #   0 - iteration number
    #   1 - execution condition
    #   2 .. - loop carried dependencies
    # Body outputs:
    #   0 - execution condition
    #   1 .. loop_carried_dependencies_count - loop carried dependencies
    #   loop_carried_dependencies_count + 1 .. - scan outputs

    body_graph.stage = 'front'
    # some of the inputs/outputs may not be connected but the normalization transformation will take care of it

    # connect Loop body nodes with external input edges
    next_loop_input_port_idx = sorted(loop_node.in_edges().keys())[-1] + 1
    for (src_node, src_port), body_node in external_edges:
        main_graph.add_edge(src_node, loop_node.id, **{'out': src_port,
                                                       'in': next_loop_input_port_idx,
                                                       'name': src_node,
                                                       'fw_tensor_debug_info': [(src_node, src_node)],
                                                       'in_attrs': ['in', 'name'],
                                                       'out_attrs': ['out', 'name'],
                                                       'data_attrs': ['fw_tensor_debug_info']})
        connect_body_input(loop_node, next_loop_input_port_idx, body_node)
        next_loop_input_port_idx += 1

    # mark the current iteration input Parameter node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])

    # connect the initial value for the "execution condition" input of the loop
    connect_body_input(loop_node, 1, body_parameters[1])
    # add a back edge with the "execution condition"
    Loop.add_back_edge(loop_node, body_parameters[1], body_results[0])
    # mark the "execution condition" Result node
    Loop.mark_execution_condition_result_node(loop_node, body_results[0])

    # connect initial values for the "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        connect_body_input(loop_node, idx + 2, body_parameters[idx + 2])
    # add back edges for the "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.add_back_edge(loop_node, body_parameters[idx + 2], body_results[idx + 1])
    # connect final values for the "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        connect_body_output(loop_node, idx, body_results[idx + 1])

    # connect "scan outputs" and mark the axis for concatenation
    for idx in range(loop_carried_dependencies_count, loop_carried_dependencies_count + scan_outputs_count):
        connect_body_output(loop_node, idx, body_results[idx + 1], axis=0)

    # run the function to parse body node attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
    return cls.enabled
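# A tiny standalone illustration (plain tuples, not ONNX protobufs) of how data_nodes_map is built and
# consumed above: each output registers (node_id, out_port) under its tensor name, and later inputs
# resolve those names back into edges, which is why nodes must be visited in original model order.
data_nodes_map_demo = {}
data_nodes_map_demo['sum_out'] = ('add_0', 0)            # registered while visiting the producer's outputs
demo_src_id, demo_src_port = data_nodes_map_demo['sum_out']  # resolved while visiting a consumer's inputs
assert (demo_src_id, demo_src_port) == ('add_0', 0)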
def extract(cls, loop_node):
    Loop.update_node_stat(loop_node, {})

    # check that the required body and condition functions exist in the graph library
    main_graph = loop_node.graph
    body_graph_proto = get_graph_proto(main_graph, 'body', loop_node)
    cond_graph_proto = get_graph_proto(main_graph, 'cond', loop_node)

    body_graph = create_internal_graph(main_graph)
    loop_node['body'] = body_graph

    # create Parameter nodes for the body graph
    body_parameters, body_parameter_names = convert_graph_inputs_to_parameters(body_graph, body_graph_proto)

    # update the loop body graph with the body function graph
    body_results = []
    update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)

    # update the loop body graph with the condition function graph
    update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)

    # add the 'internal_layer_id' attribute which is a mandatory attribute for loop body nodes
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    body_graph.stage = 'front'

    # Currently,
    # Loop Inputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # Body Inputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # Body Outputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # Loop Outputs Order:
    #   0    - current iteration
    #   1    - trip count
    #   2..  - "loop carried" dependencies variables
    #
    # so inputs must be reordered and an execution condition must be created in the front transformation
    # to be aligned with the specification

    # connect external input ports with body parameter nodes except the current iteration
    # since it must be disconnected from the external port
    for idx in range(1, len(body_parameters)):
        Loop.connect_body_input(loop_node, idx, body_parameters[idx])

    # mark the current iteration input Parameter node and the execution condition Result node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
    Loop.mark_execution_condition_result_node(loop_node, body_results[-1])

    # connect back edges in the body except the current iteration
    for idx in range(1, len(body_parameters)):
        Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])

    # connect body outputs with Loop operation output ports except the execution condition result
    for idx in range(len(body_results) - 1):
        Loop.connect_body_output(loop_node, idx, body_results[idx])

    # run the function to parse body node attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
    return cls.enabled
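# A hedged sketch of the record Loop.connect_body_input is assumed to append for each call above: a
# port-map entry binding external Loop input port idx to the body Parameter's internal layer id
# (axis=None for plain loop-carried values). Loop.connect_body_output mirrors this for
# output_port_map. This is a minimal assumed layout; the real records likely carry more attributes.
def connect_body_input_sketch(loop_node, external_port_idx, body_parameter, axis=None):
    loop_node.input_port_map.append({'external_port_id': external_port_idx,
                                     'internal_layer_id': body_parameter.internal_layer_id,
                                     'axis': axis})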
def extract(cls, loop_node):
    Loop.update_node_stat(loop_node, {})

    body_graph_proto = onnx_attr(loop_node, 'body', 'g', None)
    main_graph = loop_node.graph

    # create a Graph object for the body and take graph attributes from the main graph
    body_graph = Graph()
    main_graph_attrs_copy = {}
    for attr_key, attr_value in main_graph.graph.items():
        if attr_key not in ['tensor_mapping', 'parent_node']:
            main_graph_attrs_copy[attr_key] = copy.deepcopy(attr_value)
    body_graph.graph.update(main_graph_attrs_copy)
    loop_node['body'] = body_graph
    # save the parent node for nested loops to know which node contains the body
    # (and which graph is on the upper level)
    body_graph.graph['parent_node'] = loop_node

    # maps a tensor name to the node that produced it and the node port: str -> (node_id, node_port)
    data_nodes_map = {}
    body_graph.graph['tensor_mapping'] = data_nodes_map  # save the mapping for a possible Loop inside the Loop

    body_parameters = add_initializers_and_inputs_to_graph(body_graph, body_graph_proto, data_nodes_map)

    external_edges = []  # (src_node, src_out_port), dest_body_parameter_node
    # save additional edge information for the graph on each level; the first one is the deepest
    additional_params = []  # (src_node, src_out_port) -> parameter_node (for manually added Parameters)
    # go through all nodes in the original model order because data nodes are defined on-the-fly and order matters
    for pb_node in body_graph_proto.node:
        # create an NX node
        id = body_graph.unique_id(node_id(pb_node))
        body_graph.add_node(id, pb=pb_node, kind='op')
        if hasattr(body_graph, 'op_names_statistic') and hasattr(pb_node, 'op_type'):
            body_graph.op_names_statistic[pb_node.op_type] += 1

        # add incoming edges based on data_nodes_map
        for dst_port, inp in enumerate(pb_node.input):
            # should add edge src_internal_id --> dst_id
            if inp not in data_nodes_map:
                if inp == '':
                    # the input is omitted; most likely it corresponds to an optional input of the operator
                    continue
                else:
                    is_finished = create_cross_body_edge(body_graph, external_edges, additional_params,
                                                         inp, id, dst_port)
                    if not is_finished:
                        raise Error(
                            'Reference to "{}" is not satisfied. A node refers to a non-existing data tensor. '
                            'The ONNX model is not consistent. Protobuf fragment: {}', inp, pb_node)
            else:
                src_id, src_port = data_nodes_map[inp]
                create_edge_with_attrs(body_graph, inp, src_id, src_port, id, dst_port)

        # add outgoing edges to data_nodes_map
        for src_port, out in enumerate(pb_node.output):
            if out in data_nodes_map:
                log.debug("Detected reuse of blob {}.".format(out))
            data_nodes_map[out] = (id, src_port)

    body_results = []
    for output in body_graph_proto.output:
        tensor_name = str(output.name)
        node_name, output_port = data_nodes_map[tensor_name]
        assert body_graph.has_node(node_name), \
            'The body graph does not contain output with name "{}"'.format(node_name)
        body_results.append(Node(body_graph, add_opoutput(body_graph, node_name, output_port, False)))

    # add the 'internal_layer_id' attribute which is a mandatory attribute for loop body nodes
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    loop_carried_dependencies_count = len(body_graph_proto.input) - 2
    scan_outputs_count = len(body_graph_proto.output) - 1 - loop_carried_dependencies_count

    # Loop inputs:
    #   0 - trip count
    #   1 - execution condition
    #   2 .. - loop carried dependencies
    # Loop outputs:
    #   0 .. loop_carried_dependencies_count - 1 - loop carried dependencies
    #   loop_carried_dependencies_count .. - scan outputs

    # Body inputs:
    #   0 - iteration number
    #   1 - execution condition
    #   2 .. - loop carried dependencies
    # Body outputs:
    #   0 - execution condition
    #   1 .. loop_carried_dependencies_count - loop carried dependencies
    #   loop_carried_dependencies_count + 1 .. - scan outputs

    # some of the inputs/outputs may not be connected but the normalization transformation will take care of it

    # connect Loop body nodes with external input edges
    next_loop_input_port_idx = sorted(loop_node.in_edges().keys())[-1] + 1
    cur_graph = body_graph
    for external_edges_subg in external_edges:
        if 'parent_node' not in cur_graph.graph:
            continue
        cur_loop_node = cur_graph.graph['parent_node']
        parent_graph = cur_loop_node.graph
        for (src_node, src_port), body_node, tensor_name in external_edges_subg:
            create_edge_with_attrs(parent_graph, tensor_name, src_node, src_port,
                                   cur_loop_node.id, next_loop_input_port_idx)

            Loop.connect_body_input(cur_loop_node, next_loop_input_port_idx, body_node)
            next_loop_input_port_idx += 1
        cur_graph = parent_graph

    # mark the current iteration input Parameter node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])

    # connect the initial value for the "execution condition" input of the loop
    Loop.connect_body_input(loop_node, 1, body_parameters[1])
    # add a back edge with the "execution condition"
    Loop.add_back_edge(loop_node, body_parameters[1], body_results[0])
    # mark the "execution condition" Result node
    Loop.mark_execution_condition_result_node(loop_node, body_results[0])

    # connect initial values for the "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.connect_body_input(loop_node, idx + 2, body_parameters[idx + 2])
    # add back edges for the "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.add_back_edge(loop_node, body_parameters[idx + 2], body_results[idx + 1])
    # connect final values for the "loop carried" dependencies variables
    for idx in range(loop_carried_dependencies_count):
        Loop.connect_body_output(loop_node, idx, body_results[idx + 1])

    # connect "scan outputs" and mark the axis for concatenation
    for idx in range(loop_carried_dependencies_count, loop_carried_dependencies_count + scan_outputs_count):
        Loop.connect_body_output(loop_node, idx, body_results[idx + 1], axis=0)

    # run the function to parse body node attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
    return cls.enabled
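# A conceptual sketch of the nested-Loop walk above, assuming each body graph stores its parent Loop
# node under the 'parent_node' graph attribute (as set earlier in this extractor): external edges are
# collected per nesting level, deepest first, so the walk climbs parent graphs and wires each captured
# tensor through a fresh input port at every level.
def climb_parent_graphs_sketch(body_graph):
    levels = []
    cur_graph = body_graph
    while 'parent_node' in cur_graph.graph:
        parent_loop = cur_graph.graph['parent_node']
        levels.append(parent_loop)
        cur_graph = parent_loop.graph
    return levels  # Loop nodes ordered from innermost to outermost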