Exemplo n.º 1
0
    def device_policy(self):
        """Return the device placement policy currently in effect.

        If the native context handle has already been created, the policy is
        queried from it. Otherwise no handle exists to ask, so the value
        stored locally in ``_device_policy`` is returned instead.
        """
        if self._context_handle is None:
            # Context not initialized yet: fall back to the recorded policy
            # rather than forcing initialization just to read it.
            return self._device_policy
        return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
            self._handle)
    def __init__(self, dist, coord, replica_id, device_map,
                 variable_creator_fn, fn, args, kwargs):
        """Prepare a replica thread that will run `fn(*args, **kwargs)`.

        Args:
          dist: stored as `distribution`.
          coord: stored as `coord`.
          replica_id: index of this replica; any id greater than 0 appends a
            `replica_<id>/` component to the captured name scope.
          device_map: stored as `device_map`.
          variable_creator_fn: stored as `variable_creator_fn`.
          fn: the callable this thread runs, stored as `main_fn`.
          args: positional arguments for `fn`, stored as `main_args`.
          kwargs: keyword arguments for `fn`, stored as `main_kwargs`.
        """
        super(_MirroredReplicaThread, self).__init__()

        # Constructor arguments, kept as given.
        self.coord = coord
        self.distribution = dist
        self.device_map = device_map
        self.replica_id = replica_id
        self.variable_creator_fn = variable_creator_fn

        # The function to execute and the slots that receive its outcome.
        self.main_fn, self.main_args, self.main_kwargs = fn, args, kwargs
        self.main_result = None
        self.done = False

        # Bookkeeping for the next merge_call() (if any) requested through
        # ReplicaContext; everything starts out empty.
        self.merge_fn = self.merge_args = self.merge_kwargs = None
        self.merge_result = None
        self.captured_name_scope = self.captured_var_scope = None

        # Control hand-off between the main thread and this one:
        #   * `should_run` - set by the main thread when this thread may run;
        #   * `has_paused` - set by this thread when it hands control back,
        #     either at `get_replica_context().merge_call` or when `fn`
        #     returns.
        # Both begin cleared. Whoever waits calls wait() and then clears the
        # event immediately afterwards.
        self.should_run = threading.Event()
        self.has_paused = threading.Event()

        # Capture contexts from the parent thread so this thread can inherit
        # them.
        # pylint: disable=protected-access
        eager_ctx = context.context()
        self.in_eager = eager_ctx.executing_eagerly()
        self.record_thread_local_context_fields()
        if not eager_ctx._context_handle:
            # The native handle must exist before its policy can be queried.
            eager_ctx._initialize_handle_and_devices()
        self.context_device_policy = (
            pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
                eager_ctx._context_handle))

        self.graph = ops.get_default_graph()
        with ops.init_scope():
            self._init_in_eager = context.executing_eagerly()
            self._init_graph = ops.get_default_graph()

        # Shallow copy of the graph's variable-creator stack at this moment.
        self._variable_creator_stack = self.graph._variable_creator_stack[:]
        self._var_scope = variable_scope.get_variable_scope()

        # A trailing "/" lets the scope be re-entered later; every replica
        # except replica 0 gets its own "replica_<id>/" sub-scope.
        scope = self.graph.get_name_scope()
        if scope:
            scope += "/"
        if self.replica_id > 0:
            scope = (scope or "") + "replica_%d/" % self.replica_id
        self._name_scope = scope
Exemplo n.º 3
0
 def device_policy(self, policy):
     """Temporarily set this thread's device placement policy to `policy`.

     This is a generator body; the wrapping decorator, if any, is outside this
     view (presumably a contextmanager-style wrapper — confirm). The policy
     reported by the context on entry is reinstated on exit. Restoration
     happens in a `finally`, so it also runs when the managed block raises.
     """
     ctx_handle = self._handle

     def _set_thread_policy(value):
         # Apply `value` as the thread-local placement policy on this handle.
         pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
             ctx_handle, value)

     previous = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
         ctx_handle)
     _set_thread_policy(policy)
     try:
         yield
     finally:
         # NOTE(review): the policy read on entry is restored via the
         # thread-local setter; whether that pins it for this thread after
         # exit depends on what the getter returns — confirm intended.
         _set_thread_policy(previous)
Exemplo n.º 4
0
 def device_policy(self, policy):
     """Run the managed block with this thread's placement policy set to `policy`.

     This is a generator body; any wrapping decorator is outside this view.
     The native context handle is created on demand, so the policy in effect
     on entry can always be read. That policy is put back when the block
     exits, whether normally or via an exception.
     """
     if not self._context_handle:
         # Lazily build the handle so there is something to query.
         self._initialize_handle_and_devices()

     def _apply(new_policy):
         # `self._handle` is read on every call, matching separate lookups
         # for the set and restore steps.
         pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
             self._handle, new_policy)

     policy_on_entry = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
         self._context_handle)
     _apply(policy)
     try:
         yield
     finally:
         _apply(policy_on_entry)