def _batch_prepend_name_scope(self, to_rename, new_name_scope):
    """
    Build a new GraphItem in which every op in `to_rename` sits under `new_name_scope`.

    The GraphDef of `self.graph_item` is walked node by node. A node whose op is in
    `to_rename` gets the scope prepended to its name; on every node, inputs and
    colocation (`_class`) entries that refer to an op in `to_rename` are rewritten
    the same way. Control flow contexts are captured from the source ops and
    re-attached by op name on the rebuilt graph.

    NOTE(review): contexts are stored under the op's name *before* renaming and looked
    up by the name in the new graph, so ops that were renamed do not get a context
    re-attached. Confirm that this is intended.

    Args:
        to_rename (set): Collection of ops to rename
        new_name_scope (str): The new name scope to prepend to all ops

    Returns:
        GraphItem
    """
    source_graph = self.graph_item.graph
    source_graph_def = source_graph.as_graph_def()

    renamed_graph_def = graph_pb2.GraphDef()
    renamed_graph_def.library.Clear()
    renamed_graph_def.library.CopyFrom(source_graph_def.library)

    def _marked_for_rename(op_name):
        # Resolve the name against the source graph and test membership in `to_rename`.
        return source_graph.get_operation_by_name(op_name) in to_rename

    saved_contexts = {}
    for node_def in source_graph_def.node:
        source_op = source_graph.get_operation_by_name(node_def.name)

        # The control flow context is not restored from the attrs stored in the
        # graph def, so remember it here and put it back once the graph is rebuilt.
        flow_ctx = source_op._get_control_flow_context()
        if flow_ctx:
            saved_contexts[source_op.name] = flow_ctx

        if source_op in to_rename:
            node_def.name = ops.prepend_name_scope(node_def.name, new_name_scope)

        # Point inputs at the new names of renamed producers.
        for pos, tensor_name in enumerate(node_def.input):
            if _marked_for_rename(get_op_name(tensor_name)):
                node_def.input[pos] = ops.prepend_name_scope(tensor_name, new_name_scope)

        # Keep colocation constraints in sync with the renamed ops.
        colocation_entries = node_def.attr['_class'].list.s
        for pos, entry in enumerate(colocation_entries):
            peer_name = entry[len(COLOCATION_PREFIX):].decode('utf-8')
            if _marked_for_rename(peer_name):
                scoped_peer = ops.prepend_name_scope(peer_name, new_name_scope)
                colocation_entries[pos] = COLOCATION_PREFIX + as_bytes(scoped_peer)

        renamed_graph_def.node.append(node_def)

    # Re-attach the saved control flow contexts on the rebuilt graph.
    renamed_item = GraphItem(graph_def=renamed_graph_def)
    for rebuilt_op in renamed_item.graph.get_operations():
        # Only truthy contexts were stored, so a missing key is the only way to get None.
        restored_ctx = saved_contexts.get(rebuilt_op.name)
        if restored_ctx is not None:
            rebuilt_op._set_control_flow_context(restored_ctx)
    return renamed_item
def _delete_marked_ops(self, graph_item, name_scope):
    """
    Constructs a new GraphItem with all ops under `name_scope` removed.

    Nodes whose parsed name scope starts with `name_scope` are dropped. On the nodes
    that are kept, inputs referring to dropped ops are blanked and then handed to
    `_prune_graphdef_node_inputs`, and colocation (`_class`) entries referring to
    dropped ops are removed. Control flow contexts of kept ops are captured from
    `graph_item` and re-attached by op name on the rebuilt graph.

    Args:
        graph_item (GraphItem): The current GraphItem.
        name_scope (str): The name scope to remove.

    Returns:
        GraphItem
    """
    graph_def = graph_item.graph.as_graph_def()
    new_graph_def = graph_pb2.GraphDef()
    new_graph_def.library.Clear()
    new_graph_def.library.CopyFrom(graph_def.library)

    def _is_marked(name):
        # True when `name` lives under the name scope being deleted.
        return parse_name_scope(name).startswith(name_scope)

    control_flow_contexts = {}
    for node in graph_def.node:
        if _is_marked(node.name):
            continue

        # Save control flow context to add it back later,
        # since it is not automatically set based on the attrs in the graph def.
        op = graph_item.graph.get_operation_by_name(node.name)
        ctx = op._get_control_flow_context()
        if ctx:
            control_flow_contexts[op.name] = ctx

        # Blank out inputs coming from deleted ops; the node is passed to
        # _prune_graphdef_node_inputs below (presumably to strip the blanks).
        for idx, input_name in enumerate(node.input):
            if _is_marked(input_name):
                node.input[idx] = ""

        # Drop colocation constraints that point at deleted ops. The surviving
        # entries are computed first and the field is then replaced in place:
        # removing elements while iterating the same repeated field skips the
        # entry that follows each removal and can leave stale constraints behind.
        colocations = node.attr['_class'].list.s
        kept = [
            s for s in colocations
            if not _is_marked(s[len(COLOCATION_PREFIX):].decode('utf-8'))
        ]
        if len(kept) != len(colocations):
            del colocations[:]
            colocations.extend(kept)

        self._prune_graphdef_node_inputs(node)
        new_graph_def.node.append(node)

    # Re-add control flow contexts
    new_graph_item = GraphItem(graph_def=new_graph_def)
    for op in new_graph_item.graph.get_operations():
        if op.name in control_flow_contexts:
            op._set_control_flow_context(control_flow_contexts[op.name])
    return new_graph_item
def _initialize_graph(self):
    """
    Postpone the initialization of the member original_graph_item to the scoping time.

    Wraps the current default graph in a GraphItem and stores it on
    `self._original_graph_item`.

    Raises:
        AssertionError: If eager execution is enabled when this is called.
    """
    # Raise explicitly instead of using `assert`: assert statements are removed when
    # Python runs with -O, which would silently skip this graph-mode check.
    if tf_context.executing_eagerly():
        raise AssertionError(
            '_initialize_graph requires graph mode, but eager execution is enabled.')
    self._original_graph_item = GraphItem(graph=ops.get_default_graph())
def replicate(self, graph_item):
    """
    Replicate the entire graph as many times as num_replica.

    A fresh graph is created and the GraphDef of `graph_item` is imported into it once
    per local replica (`self._num_local_replicas` times), each copy under the name
    scope `replica_prefix(i)` and on the device given by `self._replica_device_placer`.
    While-loop contexts, savers, gradient/target name pairs, variables and table
    initializers recorded on `graph_item` are re-registered on the new item with the
    replica prefix applied.

    Args:
        graph_item: the original graph item

    Returns:
        The new graph item
    """
    item = GraphItem(graph=ops.Graph())
    # Collected from the source graph before anything is imported into the new one.
    fwd_ctx, bwd_ctx = self._collect_while_context(graph_item.graph)
    with item.graph.as_default():
        gdef = graph_item.graph.as_graph_def()
        for i in range(self._num_local_replicas):
            # Replicate ops
            with ops.device(self._replica_device_placer(replica_id=i)):
                import_graph_def(gdef, name=replica_prefix(i))

            # Replicate while_loop context (control_flow) if needed.
            # The order matters -- We must replicate bwd context first, then forward context.
            # TODO(Zeya): To handle cases when there are nested while loops, in which we must replicate
            # parent context first and then child context. See:
            # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L938
            # The constructed WhileContext objects are discarded on purpose; only the
            # effects of constructing them under the replica's import scope are wanted.
            if bwd_ctx:
                for ctx in bwd_ctx:
                    _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                     import_scope=replica_prefix(i))
            if fwd_ctx:
                for ctx in fwd_ctx:
                    _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                     import_scope=replica_prefix(i))

        # update saver
        # Savers are re-created for a single replica only (replica 0), not for every copy.
        master_replica = 0
        if graph_item.info.savers:
            item.info.update_savers(
                [Saver.from_proto(proto, import_scope=replica_prefix(master_replica)).to_proto()
                 for proto in graph_item.info.savers],
                replace=False
            )

        # update gradient info
        for i in range(self._num_local_replicas):
            for g_name, t_name in graph_item.grad_target_name_pairs.items():
                if isinstance(g_name, tuple):
                    # A tuple key carries three names, each of which gets the prefix.
                    # NOTE(review): presumably the components of a sparse gradient -- confirm.
                    new_g_name = (
                        ops.prepend_name_scope(g_name[0], replica_prefix(i)),
                        ops.prepend_name_scope(g_name[1], replica_prefix(i)),
                        ops.prepend_name_scope(g_name[2], replica_prefix(i)))
                else:
                    new_g_name = ops.prepend_name_scope(g_name, replica_prefix(i))
                new_t_name = ops.prepend_name_scope(t_name, replica_prefix(i))
                item.extend_gradient_info_by_names(
                    grads=[new_g_name],
                    targets=[new_t_name]
                )
            # Register this replica's copies of the variables and table initializers;
            # replace=False is passed so entries from earlier replicas are not replaced.
            item.info.update_variables(
                [_from_proto_fn(proto, import_scope=replica_prefix(i)).to_proto()
                 for proto in graph_item.info.variables],
                replace=False
            )
            item.info.update_table_initializers(
                [ops.prepend_name_scope(tb_init, replica_prefix(i))
                 for tb_init in graph_item.info.table_initializers],
                replace=False
            )
    return item