示例#1
0
def decode_name_with_port(input_model: InputModel,
                          node_name: str,
                          framework=""):
    """
    Decode name with optional port specification w/o traversing all the nodes in the graph
    TODO: in future node_name can specify input/output port groups as well as indices (58562)

    Accepted forms of node_name:
      * "<name>"        - looked up as a tensor name
      * "<name>:<port>" - output port <port> of operation <name>
      * "<port>:<name>" - input port <port> of operation <name>
    The port forms are matched against the WHOLE string, so digits that merely
    appear inside a name (e.g. "conv2:0") are not mistaken for a leading port.

    :param input_model: Input Model
    :param node_name: user provided node name
    :param framework: source framework name; for "onnx" the <name> part of the
        port forms may also be a tensor name, resolved to its producing operation
    :return: decoded place in the graph (the first one found)
    Calls raise_no_node(node_name) when nothing matches and
    raise_node_name_collision(...) when the matches do not all refer to the
    same data (presumably both raise - confirm).
    """
    found_nodes = []
    found_node_names = []

    # The whole string may itself be a tensor name.
    node = input_model.get_place_by_tensor_name(node_name)
    if node:
        found_node_names.append('Tensor:' + node_name)
        found_nodes.append(node)

    def try_get_node(model, name, framework):
        # Look up operation `name`; for ONNX fall back to the operation that
        # produces tensor `name`. Returns None when neither exists.
        node = model.get_place_by_operation_name(name)
        if node:
            return node
        if framework == "onnx":
            tensor = model.get_place_by_tensor_name(name)
            if tensor:
                return tensor.get_producing_operation()
        return None

    # "<name>:<port>" -> output port of the operation. fullmatch (not search)
    # so the port must be a trailing suffix of the whole string.
    match_post = re.fullmatch(r'(.+):(\d+)', node_name)
    if match_post:
        node_post = try_get_node(input_model, match_post.group(1), framework)
        if node_post:
            node_post = node_post.get_output_port(
                output_port_index=int(match_post.group(2)))
            if node_post:
                found_node_names.append(match_post.group(1))
                found_nodes.append(node_post)

    # "<port>:<name>" -> input port of the operation. fullmatch so the port
    # must be a leading prefix, not digits found somewhere inside a name.
    match_pre = re.fullmatch(r'(\d+):(.+)', node_name)
    if match_pre:
        node_pre = try_get_node(input_model, match_pre.group(2), framework)
        if node_pre:
            node_pre = node_pre.get_input_port(
                input_port_index=int(match_pre.group(1)))
            if node_pre:
                found_node_names.append(match_pre.group(2))
                found_nodes.append(node_pre)

    if not found_nodes:
        raise_no_node(node_name)

    # Check that there is no collision, all found places shall point to same data
    if not all(n.is_equal_data(found_nodes[0]) for n in found_nodes):
        raise_node_name_collision(node_name, found_node_names)

    # TODO: Add support for input/output group name and port index here (58562)
    # For new frontends logic shall be extended to additionally support input and output group names
    return found_nodes[0]
示例#2
0
def decode_name_with_port(input_model: InputModel, node_name: str):
    """
    Decode name with optional port specification w/o traversing all the nodes in the graph
    TODO: in future node_name can specify input/output port groups as well as indices (58562)

    Accepted forms of node_name:
      * "<name>"        - looked up as a tensor name
      * "<name>:<port>" - output port <port> of operation <name>
      * "<port>:<name>" - input port <port> of operation <name>
    The port forms are matched against the WHOLE string, so digits that merely
    appear inside a name (e.g. "conv2:0") are not mistaken for a leading port.

    :param input_model: Input Model
    :param node_name: user provided node name
    :return: decoded place in the graph (the first one found)
    Calls raise_no_node(node_name) when nothing matches and
    raise_node_name_collision(...) when the matches do not all refer to the
    same data (presumably both raise - confirm).
    """
    found_nodes = []
    found_node_names = []

    # The whole string may itself be a tensor name.
    node = input_model.get_place_by_tensor_name(node_name)
    if node:
        found_node_names.append('Tensor:' + node_name)
        found_nodes.append(node)

    # NOTE(review): the port keyword names below are camelCase
    # (outputPortIndex / inputPortIndex); confirm they match the installed
    # frontend Place API, which elsewhere is called with snake_case names.

    # "<name>:<port>" -> output port of the operation. fullmatch (not search)
    # so the port must be a trailing suffix of the whole string.
    match_post = re.fullmatch(r'(.+):(\d+)', node_name)
    if match_post:
        node_post = input_model.get_place_by_operation_name(
            match_post.group(1))
        if node_post:
            node_post = node_post.get_output_port(
                outputPortIndex=int(match_post.group(2)))
            if node_post:
                found_node_names.append(match_post.group(1))
                found_nodes.append(node_post)

    # "<port>:<name>" -> input port of the operation. fullmatch so the port
    # must be a leading prefix, not digits found somewhere inside a name.
    match_pre = re.fullmatch(r'(\d+):(.+)', node_name)
    if match_pre:
        node_pre = input_model.get_place_by_operation_name(match_pre.group(2))
        if node_pre:
            node_pre = node_pre.get_input_port(
                inputPortIndex=int(match_pre.group(1)))
            if node_pre:
                found_node_names.append(match_pre.group(2))
                found_nodes.append(node_pre)

    if not found_nodes:
        raise_no_node(node_name)

    # Check that there is no collision, all found places shall point to same data
    if not all(n.is_equal_data(found_nodes[0]) for n in found_nodes):
        raise_node_name_collision(node_name, found_node_names)

    # TODO: ONNX specific (59408)
    # To comply with legacy behavior, for ONNX-only there shall be considered additional 2 possibilities
    # 1) "abc:1" - get_place_by_tensor_name("abc").get_producing_operation().get_output_port(1)
    # 2) "1:abc" - get_place_by_tensor_name("abc").get_producing_operation().get_input_port(1)
    # This logic is not going to work with other frontends

    # TODO: Add support for input/output group name and port index here (58562)
    # For new frontends logic shall be extended to additionally support input and output group names
    return found_nodes[0]
示例#3
0
def decode_name_with_port(input_model: InputModel,
                          node_name: str,
                          framework=""):
    """
    Decode name with optional port specification w/o traversing all the nodes in the graph
    TODO: in future node_name can specify input/output port groups as well as indices (58562)

    node_name may have one of three forms: (1) name (2) port:name (3) name:port.
    The port forms are matched against the WHOLE string. Every interpretation
    that succeeds is collected; they must all refer to the same data. A tensor
    place is returned in preference to a port place when one was found.

    :param input_model: Input Model
    :param node_name: user provided node name
    :param framework: accepted for interface compatibility; not used by this body
    :return: decoded place in the graph
    Calls raise_no_node(node_name) when nothing matches and
    raise_node_name_collision(...) when the matches do not all refer to the
    same data (presumably both raise - confirm).
    """
    def extract_nodes(input_model, name, port, match_type, search_tensor=True):
        # Resolve `name` + `port` (from a "port:name" or "name:port" match)
        # into places. Returns (places, display names, search_tensor), where
        # search_tensor becomes False once a tensor place has been collected.
        nodes = []
        node_names = []
        # `port` comes from a (\d+) regex group, so int() cannot fail here.
        port_index = int(port)
        node = input_model.get_place_by_operation_name(name)
        if not node and search_tensor:
            # No operation with this name: it may be a tensor name instead.
            tensor = input_model.get_place_by_tensor_name(name)
            if tensor:
                node_names.append('Tensor:' + tensor.get_names()[0])
                nodes.append(tensor)
                search_tensor = False
        if node:
            # Operation found: add its input (PRE) or output (POST) port.
            new_node = None
            if match_type == MatchType.PRE:
                new_node = node.get_input_port(input_port_index=port_index)
            elif match_type == MatchType.POST:
                new_node = node.get_output_port(output_port_index=port_index)
            if new_node:
                node_names.append(name)
                nodes.append(new_node)
            # If we are still looking for a tensor, add the one attached to
            # the same port so it describes the same data as the port above.
            if search_tensor:
                tensor = None
                if match_type == MatchType.PRE:
                    tensor = node.get_source_tensor(input_port_index=port_index)
                elif match_type == MatchType.POST:
                    # Use the requested output port (previously always 0, which
                    # made "op:1" collide with the tensor of output port 0).
                    tensor = node.get_target_tensor(output_port_index=port_index)
                if tensor:
                    node_names.append('Tensor:' + tensor.get_names()[0])
                    nodes.append(tensor)
                    search_tensor = False
        return nodes, node_names, search_tensor

    def try_get_nodes(input_model, node_name):
        # Collect every place node_name can denote. Returns (names, places).
        # Passed node_name can be in several forms:
        # (1) name (2) port:name (3) name:port
        found_nodes = []
        found_node_names = []
        # Once a tensor place is found there is no need to look for another one.
        search_tensor = True
        # (1) node_name may itself be a tensor name.
        tensor = input_model.get_place_by_tensor_name(node_name)
        if tensor:
            found_node_names.append('Tensor:' + tensor.get_names()[0])
            found_nodes.append(tensor)
            search_tensor = False

        # (2) port:name - fullmatch so the port must be a leading prefix,
        # not digits that happen to appear inside a name.
        match_pre = re.fullmatch(r'(\d+):(.+)', node_name)
        if match_pre:
            nodes, node_names, search_tensor = extract_nodes(
                input_model, match_pre.group(2), match_pre.group(1),
                MatchType.PRE, search_tensor)
            found_nodes += nodes
            found_node_names += node_names

        # (3) name:port - fullmatch so the port must be a trailing suffix.
        match_post = re.fullmatch(r'(.+):(\d+)', node_name)
        if match_post:
            nodes, node_names, search_tensor = extract_nodes(
                input_model, match_post.group(1), match_post.group(2),
                MatchType.POST, search_tensor)
            found_nodes += nodes
            found_node_names += node_names

        # Nothing found yet: treat node_name as an operation name and take
        # the tensor on its output port 0.
        if not found_nodes and search_tensor:
            node = input_model.get_place_by_operation_name(node_name)
            if node:
                tensor = node.get_target_tensor(output_port_index=0)
                if tensor:
                    found_node_names.append('Tensor:' + tensor.get_names()[0])
                    found_nodes.append(tensor)
        return found_node_names, found_nodes

    found_node_names, found_nodes = try_get_nodes(input_model, node_name)
    if not found_nodes:
        raise_no_node(node_name)

    # Check that there is no collision, all found places shall point to same data
    if not all(n.is_equal_data(found_nodes[0]) for n in found_nodes):
        raise_node_name_collision(node_name, found_node_names)

    # TODO: Add support for input/output group name and port index here (58562)
    # For new frontends logic shall be extended to additionally support input and output group names

    # Prefer a tensor place. Check the 'Tensor:' prefix this function adds, so
    # that operation names merely containing "Tensor" are not picked.
    idx = next(
        (i for i, name in enumerate(found_node_names)
         if name.startswith('Tensor:')),
        0)
    return found_nodes[idx]