def _from_proto_fn(v, import_scope=None):
    """
    Rebuild a variable object from its serialized proto.

    Thin wrapper over ``resource_variable_ops._from_proto_fn`` that refuses to
    run while a distribution strategy is active.

    Args:
        v: the variable proto to deserialize (callers pass ``var_def`` entries).
        import_scope: optional name scope forwarded to
            ``resource_variable_ops._from_proto_fn``.

    Returns:
        The object returned by ``resource_variable_ops._from_proto_fn``.

    Raises:
        NotImplementedError: if ``has_distribution_strategy()`` is true.
    """
    if has_distribution_strategy():
        raise NotImplementedError(
            "Deserialization of variables is not yet supported when using "
            "distributed strategies.")
    # The result must be returned: callers store it or chain ``.to_proto()`` on
    # it, and an implicit ``None`` would break them.
    return resource_variable_ops._from_proto_fn(v, import_scope=import_scope)
def trainable_var_op_to_var(self):
    """
    Build a lookup from each trainable variable's op (e.g. a VarHandleOp) to its Variable.

    Returns:
        Dict mapping the graph operation found for every entry of
        ``self.info.trainable_variables`` to the variable rebuilt from that entry.
    """
    op_to_var = {}
    with self.graph.as_default():
        for var_def in self.info.trainable_variables:
            # Resolve the op first, then rebuild the variable (same order as a
            # dict comprehension evaluates key then value).
            var_op = self.graph.get_operation_by_name(get_op_name(var_def.variable_name))
            op_to_var[var_op] = _from_proto_fn(var_def)
    return op_to_var
def replicate(self, graph_item):
    """
    Replicate the entire graph as many times as num_replica.

    For every local replica ``i`` the original graph def is imported under the
    name scope ``replica_prefix(i)``, on the device chosen by
    ``self._replica_device_placer``. While-loop contexts, gradient info,
    variables and table initializers are re-created under the same scope.
    Savers are taken from replica 0 only.

    Args:
        graph_item: the original graph item

    Returns:
        The new graph item
    """

    def scope_name(name, prefix):
        # A gradient name is either a single tensor name or a tuple of tensor
        # names; scope every element of a tuple rather than assuming a fixed
        # length.
        if isinstance(name, tuple):
            return tuple(ops.prepend_name_scope(n, prefix) for n in name)
        return ops.prepend_name_scope(name, prefix)

    item = GraphItem(graph=ops.Graph())
    fwd_ctx, bwd_ctx = self._collect_while_context(graph_item.graph)
    with item.graph.as_default():
        gdef = graph_item.graph.as_graph_def()
        for i in range(self._num_local_replicas):
            prefix = replica_prefix(i)
            # Replicate ops
            with ops.device(self._replica_device_placer(replica_id=i)):
                import_graph_def(gdef, name=prefix)
            # Replicate while_loop context (control_flow) if needed.
            # The order matters -- We must replicate bwd context first, then forward context.
            # WhileContext is constructed only for its side effects; the object is discarded.
            # TODO(Zeya): To handle cases when there are nested while loops, in which we must replicate
            #  parent context first and then child context. See:
            #  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L938
            for ctx in bwd_ctx or ():
                WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                             import_scope=prefix)
            for ctx in fwd_ctx or ():
                WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                             import_scope=prefix)

        # update saver: savers are restored from the first replica's scope only
        master_replica = 0
        if graph_item.info.savers:
            item.info.update_savers(
                [Saver.from_proto(proto, import_scope=replica_prefix(master_replica)).to_proto()
                 for proto in graph_item.info.savers],
                replace=False
            )

        # update gradient info, variables and table initializers per replica
        for i in range(self._num_local_replicas):
            prefix = replica_prefix(i)
            for g_name, t_name in graph_item.grad_target_name_pairs.items():
                item.extend_gradient_info_by_names(
                    grads=[scope_name(g_name, prefix)],
                    targets=[ops.prepend_name_scope(t_name, prefix)]
                )
            item.info.update_variables(
                [_from_proto_fn(proto, import_scope=prefix).to_proto()
                 for proto in graph_item.info.variables],
                replace=False
            )
            item.info.update_table_initializers(
                [ops.prepend_name_scope(tb_init, prefix)
                 for tb_init in graph_item.info.table_initializers],
                replace=False
            )
    return item
def get_all_variables(self):
    """
    Return every variable of this graph item.

    Each entry of ``self.info.variables`` is rebuilt via ``_from_proto_fn``
    while ``self.graph`` is the default graph.
    """
    with self.graph.as_default():
        return list(map(_from_proto_fn, self.info.variables))