def get_tf_edges(node: Node):
    """
    Build the list of incoming edges for a TF/NX graph node.

    One edge is produced per entry of ``node.pb.input``, in order. Each edge is a
    tuple ``(source_node_name, node.id, attrs)``. Its direction follows the data
    flow, from the producing op to ``node``. An input whose parsed source name
    starts with ``^`` is a control dependency. The ``^`` is removed from the
    source name and ``attrs['control_flow_edge']`` is set to True.

    Edge attributes:
        'in'  - position of the input within ``node.pb.input`` (input port of ``node``);
        'out' - output port of the producing node, as parsed by ``get_tf_node_port``;
        'fw_tensor_debug_info' - ``[(raw input string, out port)]``, a debug anchor
                                 back to the framework tensor name and port.

    :param node: node whose protobuf inputs are converted to edges.
    :return: list of ``(src, dst, attrs)`` edge tuples.
    """
    def _edge_for(port_idx, raw_input):
        producer, producer_port = get_tf_node_port(raw_input)
        is_control_dep = producer[0] == '^'
        if is_control_dep:
            # '^name' marks a control dependency; keep only the bare node name
            producer = producer[1:]
        return (producer, node.id, {
            'in': port_idx,
            'out': producer_port,
            'fw_tensor_debug_info': [(raw_input, producer_port)],
            'in_attrs': ['in', 'control_flow_edge', 'permutation'],
            'out_attrs': ['out', 'permutation'],
            'data_attrs': ['fw_tensor_debug_info'],
            'control_flow_edge': is_control_dep,
        })

    return [_edge_for(port_idx, raw_input) for port_idx, raw_input in enumerate(node.pb.input)]
def add_placeholders_to_subgraph(node: Node):
    """
    Add placeholders to the node's protobuf subgraph for each of its data inputs.

    For every input edge of ``node`` that is not a control-flow edge:

    * The producing tensor name is taken from the ``internal_input_node_name`` edge
      attribute if present. Otherwise it is the corresponding entry of ``node['pb'].input``.
    * A placeholder name is derived from it with ``placeholder_name_for_node``. That
      name is recorded in the edge attribute ``placeholder_name``.
    * If ``node['pbs']`` has no entry with that name yet, a ``tf_v1.placeholder`` is
      created using the input node's data type and shape. It is then added to the
      subgraph as an input.

    Finally, ``update_input_in_pbs`` is called for every newly created placeholder.
    This updates the input tensors of nodes that consumed the outputs now replaced
    by placeholders.

    :param node: the node to add placeholders to.
    :return: None
    :raises AssertionError: if an input node of ``node`` has an undefined (None) shape.
    """
    inputs_replacements = []
    # NOTE(review): `index` is the position within the sorted inputs and is used as the
    # input port below; this assumes the data input ports are contiguous from 0 — confirm.
    for index, (_, edge_attrs) in enumerate(node.get_sorted_inputs()):
        # control-flow dependencies carry no data and never get a placeholder
        if edge_attrs.get('control_flow_edge', False):
            continue

        if 'internal_input_node_name' in edge_attrs:
            input_tensor_name = edge_attrs['internal_input_node_name']
        else:
            input_tensor_name = node['pb'].input[index]

        input_node_name, port = get_tf_node_port(input_tensor_name)

        placeholder_name = placeholder_name_for_node(input_node_name, port)
        edge_attrs['placeholder_name'] = placeholder_name
        in_node = node.in_node(index)

        # Raised explicitly instead of via `assert` so the check is not stripped under
        # `python -O`; the AssertionError type is kept for backward compatibility.
        if in_node.shape is None:
            raise AssertionError(
                "Shape of input {} of node '{}' is undefined; cannot create placeholder '{}'".format(
                    index, node.id, placeholder_name))

        if placeholder_name not in node['pbs']:
            placeholder = tf_v1.placeholder(determine_data_type(in_node), in_node.shape, placeholder_name)
            inputs_replacements.append((input_tensor_name, placeholder_name))
            add_node_def_to_subgraph(node, placeholder.op.node_def, is_input=True)
            log.debug("Added placeholder with name '%s'", placeholder_name)

    # replace references to the original input tensors with the new placeholder names
    for old_input_tensor_name, new_name in inputs_replacements:
        update_input_in_pbs(node, old_input_tensor_name, new_name)
def create_tf_edge(src_node_id: str, dst_node_id: str, in_port: int):
    """
    Create one TF graph edge from the tensor ``src_node_id`` to input ``in_port`` of ``dst_node_id``.

    ``src_node_id`` is parsed with ``get_tf_node_port`` into a source node name and
    an output port. A source name starting with ``^`` denotes a control dependency.
    The ``^`` is removed and ``control_flow_edge`` is set to True.

    The debug anchor ``fw_tensor_debug_info`` holds
    ``(src_node_id, out port, '<name>:<port>')``. The ``'<name>:<port>'`` string is
    built before the ``^`` is removed, so for control dependencies it keeps the ``^``.

    :param src_node_id: framework tensor/input string of the producer.
    :param dst_node_id: id of the consuming node.
    :param in_port: input port index on the consuming node.
    :return: tuple ``(src_node_name, dst_node_id, attrs)``.
    """
    producer, out_port = get_tf_node_port(src_node_id)
    debug_tensor_name = ':'.join((producer, str(out_port)))

    is_control_dep = producer[0] == '^'
    if is_control_dep:
        producer = producer[1:]

    edge_attrs = {
        'in': in_port,
        'out': out_port,
        # debug anchor for a framework name, out port and tensor name
        'fw_tensor_debug_info': [(src_node_id, out_port, debug_tensor_name)],
        'in_attrs': ['in', 'control_flow_edge', 'permutation'],
        'out_attrs': ['out', 'permutation'],
        'data_attrs': ['fw_tensor_debug_info'],
        'control_flow_edge': is_control_dep,
    }
    return producer, dst_node_id, edge_attrs