def _rename_edges_helper(
    internal_node: NodeProto,
    rename_helper: Callable[[Text], Text],
    attribute_map: Dict[Text, AttributeProto],
    prefix: Text,
) -> NodeProto:
    """Return a copy of ``internal_node`` with its edges and attributes rewritten.

    Args:
        internal_node: Node to copy. It is read but never modified.
        rename_helper: Maps each input/output name of the node to its new name.
        attribute_map: Attributes keyed by name, used to resolve attributes
            that carry a ``ref_attr_name``.
        prefix: String prepended to the names declared by a GRAPH attribute
            (its inputs, outputs, initializers and sparse initializers).

    Returns:
        A new ``NodeProto``. Attributes whose ``ref_attr_name`` is not a key of
        ``attribute_map`` are left out of the result.
    """
    renamed_node = NodeProto()
    renamed_node.CopyFrom(internal_node)
    # The repeated fields are rebuilt from scratch below.
    for field_name in ("input", "output", "attribute"):
        renamed_node.ClearField(field_name)

    renamed_node.input.extend(
        [rename_helper(edge) for edge in internal_node.input])
    renamed_node.output.extend(
        [rename_helper(edge) for edge in internal_node.output])

    for attr in internal_node.attribute:
        if attr.HasField("ref_attr_name"):
            if attr.ref_attr_name not in attribute_map:
                # No value to substitute for the reference: skip the attribute.
                continue
            resolved_attr = AttributeProto()
            resolved_attr.CopyFrom(
                attribute_map[attr.ref_attr_name])  # type: ignore
            # The substituted value keeps the name used inside the node.
            resolved_attr.name = attr.name
            renamed_node.attribute.extend([resolved_attr])
            continue

        copied_attr = AttributeProto()
        copied_attr.CopyFrom(attr)
        if attr.type == AttributeProto.GRAPH:
            subgraph = copied_attr.g
            # Old name -> prefixed name, for every name the subgraph declares.
            local_names = {}

            def add_prefix(named: Any) -> None:
                # Read the old name first: it is the lookup key for the
                # nodes of the subgraph.
                old_name = named.name
                scoped_name = prefix + old_name
                local_names[old_name] = scoped_name
                named.name = scoped_name

            for graph_input in subgraph.input:
                add_prefix(graph_input)
            for graph_output in subgraph.output:
                add_prefix(graph_output)
            for initializer in subgraph.initializer:
                add_prefix(initializer)
            for sparse_initializer in subgraph.sparse_initializer:
                add_prefix(sparse_initializer.values)
                add_prefix(sparse_initializer.indices)

            def subgraph_rename_helper(name: Text) -> Any:
                # Names declared by the subgraph win; anything else is
                # renamed by the enclosing scope's helper.
                if name not in local_names:
                    return rename_helper(name)
                return local_names[name]

            rewritten_nodes = []
            for inner_node in subgraph.node:
                rewritten_nodes.append(
                    _rename_edges_helper(
                        inner_node, subgraph_rename_helper, attribute_map,
                        prefix))
            subgraph.ClearField("node")
            subgraph.node.extend(rewritten_nodes)
        renamed_node.attribute.extend([copied_attr])

    return renamed_node
def function_expand_helper(
    node,  # type: NodeProto
    function_proto,  # type: FunctionProto
    op_prefix  # type: Text
):  # type: (...) -> List[NodeProto]
    """Expand ``node`` into renamed copies of the nodes of ``function_proto``.

    Args:
        node: The node that invokes the function. Its inputs, outputs and
            attributes supply the actual names and attribute values.
        function_proto: The function whose body is expanded.
        op_prefix: Prefix given to every name of the function body that is not
            bound to one of ``node``'s inputs or outputs.

    Returns:
        A list with one new node per node of ``function_proto.node``, in the
        same order. Each node is produced by ``_rename_edges_helper``, so
        ``ref_attr_name`` attributes and GRAPH attributes are rewritten the
        same way as there.
    """
    attribute_map = {attr.name: attr for attr in node.attribute}

    io_names_map = {}
    actual_inputs = node.input
    for idx, formal_name in enumerate(function_proto.input):
        # A formal input with no matching actual input is an omitted optional
        # input, which is represented by the empty name.
        io_names_map[formal_name] = (
            actual_inputs[idx] if idx < len(actual_inputs) else "")

    actual_outputs = node.output
    for idx, formal_name in enumerate(function_proto.output):
        # Even if the node has been created with optional outputs missing, we
        # can't assume that the function body handles this correctly, such as
        # in the case that the output is also an intermediate value. So a
        # mapping is only added when the output is present; a missing output
        # that is used gets a generated internal name like any other internal
        # tensor.
        if idx < len(actual_outputs) and actual_outputs[idx] != "":
            io_names_map[formal_name] = actual_outputs[idx]

    def rename_helper(internal_name):
        # type: (Text) -> Text
        if internal_name in io_names_map:
            return io_names_map[internal_name]
        if internal_name == "":
            # An empty name marks an omitted optional input/output inside the
            # body; prefixing it would invent a tensor that nothing produces.
            return ""
        return op_prefix + internal_name

    return [
        _rename_edges_helper(
            internal_node, rename_helper, attribute_map, op_prefix)
        for internal_node in function_proto.node
    ]
def function_expand_helper(
    node,  # type: NodeProto
    function_proto,  # type: FunctionProto
    op_prefix  # type: Text
):  # type: (...) -> List[NodeProto]
    """Expand ``node`` into renamed copies of the nodes of ``function_proto``.

    Args:
        node: The node that invokes the function. Its inputs, outputs and
            attributes supply the actual names and attribute values.
        function_proto: The function whose body is expanded.
        op_prefix: Prefix given to every name of the function body that is not
            bound to one of ``node``'s inputs or outputs.

    Returns:
        A list with one new node per node of ``function_proto.node``, in the
        same order. Each node is produced by ``_rename_edges_helper``, so
        ``ref_attr_name`` attributes and GRAPH attributes are rewritten the
        same way as there.
    """
    # Referenced attributes are looked up by name, so index the caller's
    # attributes by name rather than using the repeated field directly.
    attribute_map = {attr.name: attr for attr in node.attribute}

    # One map for inputs and outputs: a function output may also be consumed
    # by a later node of the body, and both ends of that edge must receive
    # the same name.
    io_names_map = {}
    caller_inputs = node.input
    for position, formal_name in enumerate(function_proto.input):
        # A formal input the caller did not supply is an omitted optional
        # input, which is represented by the empty name.
        io_names_map[formal_name] = (
            caller_inputs[position] if position < len(caller_inputs) else "")

    caller_outputs = node.output
    for position, formal_name in enumerate(function_proto.output):
        # Only bind outputs the caller actually names. An absent output may
        # still be an intermediate value of the body, in which case it gets a
        # prefixed internal name like any other internal tensor.
        if position < len(caller_outputs) and caller_outputs[position] != "":
            io_names_map[formal_name] = caller_outputs[position]

    def rename_helper(internal_name):
        # type: (Text) -> Text
        if internal_name in io_names_map:
            return io_names_map[internal_name]
        if internal_name == "":
            # An empty name marks an omitted optional input/output inside the
            # body; prefixing it would invent a tensor that nothing produces.
            return ""
        return op_prefix + internal_name

    return [
        _rename_edges_helper(
            internal_node, rename_helper, attribute_map, op_prefix)
        for internal_node in function_proto.node
    ]