def decode_name_with_port(input_model: InputModel, node_name: str, framework=""):
    """
    Decode name with optional port specification w/o traversing all the nodes in the graph
    TODO: in future node_name can specify input/output port groups as well as indices (58562)

    Interpretations of node_name that are tried:
      * "name"      - the whole string is a tensor name
      * "name:port" - output port `port` of operation `name`
      * "port:name" - input port `port` of operation `name`
    Every interpretation that resolves is collected; all of them must point to the same data.

    :param input_model: Input Model
    :param node_name: user provided node name
    :param framework: source framework name; when it is "onnx", the `name` part may also be
        a tensor name, in which case the operation producing that tensor is used
    :return: decoded place in the graph
    """
    found_nodes = []
    found_node_names = []

    def try_get_node(model, name, framework):
        # Operation lookup by name; for ONNX fall back to the producer of a same-named tensor.
        node = model.get_place_by_operation_name(name)
        if node:
            return node
        if framework == "onnx":
            tensor = model.get_place_by_tensor_name(name)
            if tensor:
                return tensor.get_producing_operation()
        return None

    # "name": the whole string is a tensor name
    node = input_model.get_place_by_tensor_name(node_name)
    if node:
        found_node_names.append('Tensor:' + node_name)
        found_nodes.append(node)

    # fullmatch (not search): the entire string must have the "name:port" / "port:name"
    # shape. re.search would also match a substring, e.g. read "Conv_1:0" as "1:0"
    # (input port 1 of "0") and then report a spurious name collision.
    match_post = re.fullmatch(r'(.+):(\d+)', node_name)
    if match_post:
        node_post = try_get_node(input_model, match_post.group(1), framework)
        if node_post:
            node_post = node_post.get_output_port(
                output_port_index=int(match_post.group(2)))
            if node_post:
                found_node_names.append(match_post.group(1))
                found_nodes.append(node_post)

    match_pre = re.fullmatch(r'(\d+):(.+)', node_name)
    if match_pre:
        node_pre = try_get_node(input_model, match_pre.group(2), framework)
        if node_pre:
            node_pre = node_pre.get_input_port(
                input_port_index=int(match_pre.group(1)))
            if node_pre:
                found_node_names.append(match_pre.group(2))
                found_nodes.append(node_pre)

    if not found_nodes:
        raise_no_node(node_name)

    # Check that there is no collision, all found places shall point to same data
    if not all(n.is_equal_data(found_nodes[0]) for n in found_nodes):
        raise_node_name_collision(node_name, found_node_names)

    # TODO: Add support for input/output group name and port index here (58562)
    # For new frontends logic shall be extended to additionally support input and output group names
    return found_nodes[0]
def decode_name_with_port(input_model: InputModel, node_name: str):
    """
    Decode name with optional port specification w/o traversing all the nodes in the graph
    TODO: in future node_name can specify input/output port groups as well as indices (58562)

    Interpretations of node_name that are tried:
      * "name"      - the whole string is a tensor name
      * "name:port" - output port `port` of operation `name`
      * "port:name" - input port `port` of operation `name`
    Every interpretation that resolves is collected; all of them must point to the same data.

    :param input_model: Input Model
    :param node_name: user provided node name
    :return: decoded place in the graph
    """
    found_nodes = []
    found_node_names = []

    # "name": the whole string is a tensor name
    node = input_model.get_place_by_tensor_name(node_name)
    if node:
        found_node_names.append('Tensor:' + node_name)
        found_nodes.append(node)

    # fullmatch (not search): the entire string must have the "name:port" / "port:name"
    # shape. re.search would also match a substring, e.g. read "Conv_1:0" as "1:0"
    # (input port 1 of "0") and then report a spurious name collision.
    match_post = re.fullmatch(r'(.+):(\d+)', node_name)
    if match_post:
        node_post = input_model.get_place_by_operation_name(match_post.group(1))
        if node_post:
            # keyword names match the other callers of the Place port API in this file
            node_post = node_post.get_output_port(
                output_port_index=int(match_post.group(2)))
            if node_post:
                found_node_names.append(match_post.group(1))
                found_nodes.append(node_post)

    match_pre = re.fullmatch(r'(\d+):(.+)', node_name)
    if match_pre:
        node_pre = input_model.get_place_by_operation_name(match_pre.group(2))
        if node_pre:
            node_pre = node_pre.get_input_port(
                input_port_index=int(match_pre.group(1)))
            if node_pre:
                found_node_names.append(match_pre.group(2))
                found_nodes.append(node_pre)

    if not found_nodes:
        raise_no_node(node_name)

    # Check that there is no collision, all found places shall point to same data
    if not all(n.is_equal_data(found_nodes[0]) for n in found_nodes):
        raise_node_name_collision(node_name, found_node_names)

    # TODO: ONNX specific (59408)
    # To comply with legacy behavior, for ONNX-only there shall be considered additional 2 possibilities
    # 1) "abc:1" - get_place_by_tensor_name("abc").get_producing_operation().get_output_port(1)
    # 2) "1:abc" - get_place_by_tensor_name("abc").get_producing_operation().get_input_port(1)
    # This logic is not going to work with other frontends

    # TODO: Add support for input/output group name and port index here (58562)
    # For new frontends logic shall be extended to additionally support input and output group names
    return found_nodes[0]
def decode_name_with_port(input_model: InputModel, node_name: str, framework=""):
    """
    Decode name with optional port specification w/o traversing all the nodes in the graph
    TODO: in future node_name can specify input/output port groups as well as indices (58562)

    Interpretations of node_name that are tried:
      * "name"      - a tensor name, or (as a last resort) an operation whose output 0 is used
      * "port:name" - input port `port` of operation `name` (plus the tensor feeding it)
      * "name:port" - output port `port` of operation `name` (plus the tensor it produces)
    Every interpretation that resolves is collected; all of them must point to the same data.

    :param input_model: Input Model
    :param node_name: user provided node name
    :param framework: source framework name (not used by this implementation)
    :return: decoded place in the graph; a tensor place is preferred when one was found
    """
    def tensor_label(tensor):
        # Label used in collision messages and to recognise tensor places at the end
        return 'Tensor:' + tensor.get_names()[0]

    def extract_nodes(input_model, name, port, match_type, search_tensor=True):
        # Resolve `name` + `port`; returns (places, labels, updated search_tensor flag).
        nodes = []
        node_names = []
        port_index = int(port)
        node = input_model.get_place_by_operation_name(name)
        if not node and search_tensor:
            # no operation with this name - the name part itself may be a tensor name
            tensor = input_model.get_place_by_tensor_name(name)
            if tensor:
                node_names.append(tensor_label(tensor))
                nodes.append(tensor)
                search_tensor = False
        if node:
            # operation exists: take its input port (PRE) or output port (POST)
            if match_type == MatchType.PRE:
                new_node = node.get_input_port(input_port_index=port_index)
            elif match_type == MatchType.POST:
                new_node = node.get_output_port(output_port_index=port_index)
            if new_node:
                node_names.append(name)
                nodes.append(new_node)
            # if we are still looking for a tensor, add the one attached to the SAME port;
            # using a fixed port here would yield different data than new_node and
            # trigger a false collision for any port other than that one
            if search_tensor:
                if match_type == MatchType.PRE:
                    tensor = node.get_source_tensor(input_port_index=port_index)
                elif match_type == MatchType.POST:
                    tensor = node.get_target_tensor(output_port_index=port_index)
                if tensor:
                    node_names.append(tensor_label(tensor))
                    nodes.append(tensor)
                    search_tensor = False
        return nodes, node_names, search_tensor

    def try_get_nodes(input_model, node_name):
        # Passed node_name can be in several forms:
        # (1) name (2) port:name (3) name:port
        found_nodes = []
        found_node_names = []
        # once a tensor is found there is no need to keep searching for one
        search_tensor = True

        # check if there is a tensor with given node_name
        tensor = input_model.get_place_by_tensor_name(node_name)
        if tensor:
            found_node_names.append(tensor_label(tensor))
            found_nodes.append(tensor)
            search_tensor = False

        # fullmatch (not search): the entire string must have the "port:name" /
        # "name:port" shape; re.search would e.g. read "Conv_1:0" as "1:0"
        match_pre = re.fullmatch(r'(\d+):(.+)', node_name)
        if match_pre:
            nodes, node_names, search_tensor = extract_nodes(
                input_model, match_pre.group(2), match_pre.group(1), MatchType.PRE, search_tensor)
            found_nodes += nodes
            found_node_names += node_names

        match_post = re.fullmatch(r'(.+):(\d+)', node_name)
        if match_post:
            nodes, node_names, search_tensor = extract_nodes(
                input_model, match_post.group(1), match_post.group(2), MatchType.POST, search_tensor)
            found_nodes += nodes
            found_node_names += node_names

        # nothing found yet: treat node_name as an operation and use its output 0 tensor
        if not found_nodes and search_tensor:
            node = input_model.get_place_by_operation_name(node_name)
            if node:
                tensor = node.get_target_tensor(output_port_index=0)
                if tensor:
                    found_node_names.append(tensor_label(tensor))
                    found_nodes.append(tensor)

        return found_node_names, found_nodes

    found_node_names, found_nodes = try_get_nodes(input_model, node_name)
    if not found_nodes:
        raise_no_node(node_name)

    # Check that there is no collision, all found places shall point to same data
    if not all(n.is_equal_data(found_nodes[0]) for n in found_nodes):
        raise_node_name_collision(node_name, found_node_names)

    # TODO: Add support for input/output group name and port index here (58562)
    # For new frontends logic shall be extended to additionally support input and output group names

    # Prefer a tensor place. Match the label prefix exactly: a substring test would also
    # pick an operation whose own name merely contains "Tensor".
    idx = next(
        (i for i, label in enumerate(found_node_names) if label.startswith('Tensor:')), 0)
    return found_nodes[idx]