def get_tf_edges(node: Node):
    """
    Build the list of incoming edges for a TF/NX node from its protobuf inputs.

    One edge is produced per entry of ``node.pb.input``, in order. Each edge is a
    3-tuple ``(source_node_name, node.id, attrs)``, so the direction follows the data
    flow from the producing op into this node.

    Every input string is split into a node name and an output port by
    ``get_tf_node_port``. If the node name begins with '^', the '^' is removed and
    the edge is flagged with ``control_flow_edge=True``.

    Edge attributes:
      'in'  -- position of the input in ``node.pb.input`` (input port of this node);
      'out' -- output port of the producing node, as returned by ``get_tf_node_port``;
      'fw_tensor_debug_info' -- ``[(raw_input_string, out_port)]``, a debug anchor for
               the framework tensor name and port (the raw string keeps any '^');
      'control_flow_edge' -- whether the input had the '^' prefix.

    An empty parsed node name raises IndexError at the '^' check.

    :param node: node whose ``pb.input`` list is read.
    :return: list of edge tuples, empty when the node has no inputs.
    """
    def _edge_for_input(port_idx, raw_input):
        producer, producer_port = get_tf_node_port(raw_input)
        is_control = producer[0] == '^'
        if is_control:
            producer = producer[1:]
        # Fresh dict/lists per edge, so edges never share mutable attribute containers.
        attrs = {
            'in': port_idx,
            'out': producer_port,
            # debug anchor for a framework tensor name and port
            'fw_tensor_debug_info': [(raw_input, producer_port)],
            'in_attrs': ['in', 'control_flow_edge', 'permutation'],
            'out_attrs': ['out', 'permutation'],
            'data_attrs': ['fw_tensor_debug_info'],
            'control_flow_edge': is_control,
        }
        return producer, node.id, attrs

    return [
        _edge_for_input(port_idx, raw_input)
        for port_idx, raw_input in enumerate(node.pb.input)
    ]
Example #2
0
def add_placeholders_to_subgraph(node: Node):
    """
    Add placeholder protobufs to ``node['pbs']`` for each data input of the subgraph node.

    The inputs are visited in ``node.get_sorted_inputs()`` order, and control-flow edges are
    skipped. For every other input:

      * the source tensor name is taken from the edge attribute ``'internal_input_node_name'``
        when present, otherwise from ``node['pb'].input`` at the same position;
      * a placeholder name is derived via ``placeholder_name_for_node`` and stored on the
        edge as ``'placeholder_name'``;
      * if ``node['pbs']`` has no entry with that name yet, a ``tf_v1.placeholder`` is built
        with the input node's data type and shape and added to the subgraph as an input.

    Afterwards, the inputs of the subgraph protobufs that referenced a replaced tensor are
    rewritten to the new placeholder name via ``update_input_in_pbs``.

    :param node: the node to add placeholders to.
    :return: None
    """
    replaced_tensors = []

    for position, (_, attrs) in enumerate(node.get_sorted_inputs()):
        if attrs.get('control_flow_edge'):
            # only data inputs get placeholders
            continue

        # Conditional expression (not dict.get with a default) so that node['pb'].input
        # is only indexed when the edge carries no explicit internal name.
        source_tensor = (attrs['internal_input_node_name']
                         if 'internal_input_node_name' in attrs
                         else node['pb'].input[position])

        source_name, source_port = get_tf_node_port(source_tensor)
        ph_name = placeholder_name_for_node(source_name, source_port)
        attrs['placeholder_name'] = ph_name

        # NOTE(review): the input node is looked up by enumeration position rather than by
        # reusing the node from get_sorted_inputs(); these agree only if that position matches
        # the input port number -- confirm against the Node API.
        feeding_node = node.in_node(position)
        assert feeding_node.shape is not None

        if ph_name in node['pbs']:
            # a placeholder with this name already exists in the subgraph
            continue

        placeholder = tf_v1.placeholder(determine_data_type(feeding_node),
                                        feeding_node.shape, ph_name)
        replaced_tensors.append((source_tensor, ph_name))
        add_node_def_to_subgraph(node, placeholder.op.node_def, is_input=True)
        log.debug("Added placeholder with name '{}'".format(ph_name))

    # point consumers of the replaced tensors at the new placeholders
    for original_tensor, ph_name in replaced_tensors:
        update_input_in_pbs(node, original_tensor, ph_name)
Example #3
0
def create_tf_edge(src_node_id: str, dst_node_id: str, in_port: int):
    """
    Build one edge from the producer named by ``src_node_id`` into ``dst_node_id``.

    ``src_node_id`` is split into a node name and an output port by
    ``get_tf_node_port``. The tensor name ``"<name>:<port>"`` is formed *before* the '^'
    prefix is removed, so for a '^'-prefixed input it keeps the '^'. If the name begins
    with '^', the prefix is stripped from the returned source node and
    ``control_flow_edge`` is set to True. An empty parsed name raises IndexError.

    :param src_node_id: raw input string of the producer (optionally '^'-prefixed).
    :param dst_node_id: id of the consuming node.
    :param in_port: input port index on the consuming node.
    :return: tuple ``(source_node, dst_node_id, attrs)`` where ``attrs`` has keys
             'in', 'out', 'fw_tensor_debug_info' (``[(src_node_id, port, tensor_name)]``),
             'in_attrs', 'out_attrs', 'data_attrs' and 'control_flow_edge'.
    """
    producer, producer_port = get_tf_node_port(src_node_id)
    full_tensor_name = ":".join((producer, str(producer_port)))

    is_control = producer[0] == '^'
    source = producer[1:] if is_control else producer

    attrs = {
        'in': in_port,
        'out': producer_port,
        # debug anchor for a framework name, out port and tensor name
        'fw_tensor_debug_info': [(src_node_id, producer_port, full_tensor_name)],
        'in_attrs': ['in', 'control_flow_edge', 'permutation'],
        'out_attrs': ['out', 'permutation'],
        'data_attrs': ['fw_tensor_debug_info'],
        'control_flow_edge': is_control,
    }
    return source, dst_node_id, attrs