def testCancelDuringParamResolution(self, collective_op, device,
                                    communication):
  dev0 = '/device:%s:0' % device
  dev1 = '/device:%s:1' % device
  group_size = 2
  group_key = 100
  instance_key = 100
  in_tensor = constant_op.constant([1.])
  t1_cancellation_manager = cancellation.CancellationManager()
  t2_cancellation_manager = cancellation.CancellationManager()

  @def_function.function
  def _collective_fn(x):
    # Run an assertion to crash one of the two function executions running
    # collectives. We explicitly cancel the other in response.
    assert_op = check_ops.assert_equal(x, in_tensor)
    with ops.control_dependencies([assert_op]):
      return collective_op(
          in_tensor,
          group_size,
          group_key,
          instance_key,
          communication_hint=communication)

  collective_concrete = _collective_fn.get_concrete_function(in_tensor)

  finish_mu = threading.Lock()
  finishes = 0

  def _placement_wrapper(device, x, my_cancellation, other_cancellation):
    try:
      with ops.device(device):
        cancelable_collective = my_cancellation.get_cancelable_function(
            collective_concrete)
        return cancelable_collective(x)
    except errors.InvalidArgumentError:
      # `assert_equal` failed for this execution of the function. The other
      # function would deadlock without cancellation.
      other_cancellation.start_cancel()
    except errors.CancelledError:
      pass
    nonlocal finishes
    with finish_mu:
      finishes += 1

  t1 = threading.Thread(
      target=_placement_wrapper,
      args=(dev0, constant_op.constant([1.]), t1_cancellation_manager,
            t2_cancellation_manager))
  t2 = threading.Thread(
      target=_placement_wrapper,
      # Will cause the assertion to fail
      args=(dev1, constant_op.constant([2.]), t2_cancellation_manager,
            t1_cancellation_manager))
  t1.start()
  t2.start()
  t1.join()
  t2.join()
  self.assertEqual(finishes, 2)
def testCancelRemoteFunctionDuringExecution(self):
  remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
  default_streaming = os.environ.get(remote_async_env_var)
  os.environ[remote_async_env_var] = str(False)

  q = data_flow_ops.FIFOQueue(1, dtypes.int32)

  @def_function.function
  def f():
    return q.dequeue()

  c_mgr = cancellation.CancellationManager()
  cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

  def cancel_thread():
    time.sleep(0.5)
    c_mgr.start_cancel()

  t = self.checkedThread(cancel_thread)
  t.start()
  with self.assertRaises(errors.CancelledError):
    with ops.device('/job:worker/replica:0/task:1'):
      cancelable_func()
  t.join()

  if default_streaming is None:
    del os.environ[remote_async_env_var]
  else:
    os.environ[remote_async_env_var] = default_streaming
def testCancelGetNextWithDevice(self, cls):
  ping = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)
  pong = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)

  @def_function.function
  def map_fn(v):
    ball = ping.dequeue()
    with ops.control_dependencies([pong.enqueue(ball)]):
      return v + ping.dequeue()

  dataset = dataset_ops.Dataset.range(10)
  dataset = dataset.map(map_fn)

  # We need to set prefetch_buffer_size=0 so that we can cancel the
  # MultiDeviceIteratorGetNextFromShardOp from eager. If
  # prefetch_buffer_size>0, that op runs in the background threads of the
  # prefetch and can only be cancelled by deleting the iterator.
  multi_device_iterator = cls(
      dataset, [self._devices[1], self._devices[2]], prefetch_buffer_size=0)

  @def_function.function
  def get_next_device1():
    return multi_device_iterator.get_next(self._devices[1])

  async_executor = executor.new_executor(enable_async=True)
  with context.executor_scope(async_executor):
    cancel_mgr = cancellation.CancellationManager()
    cancel_mgr.get_cancelable_function(
        get_next_device1.get_concrete_function())()
  # Make sure we cancel in the middle of get_next.
  ping.enqueue(0)
  pong.dequeue()
  cancel_mgr.start_cancel()
  with self.assertRaises(errors.CancelledError):
    async_executor.wait()
  # Note that fetching from the upstream iterator is not cancelled with the
  # cancellation of get_next.
  ping.enqueue(0)

  # Cancelling a get_next on one device shouldn't cancel the
  # multi_device_iterator and iterators on other devices.
  ping.enqueue(0)
  ping.enqueue(0)
  self.assertEqual(1,
                   multi_device_iterator.get_next(self._devices[2]).numpy())

  # FIXME(b/209534797): Workaround an asan error caused by this test.
  # Remove the dangling reference from tf.function to ensure queue objects
  # are not freed before they are flushed.
  import gc  # pylint: disable=g-import-not-at-top
  del get_next_device1
  gc.collect()
def __init__(self):
  # `self._inflight_closure_count` only tracks the number of inflight closures
  # that are "in generation". Once an error occurs, the error generation is
  # incremented and all subsequently arriving closures (from inflight) are
  # considered "out of generation".
  self._inflight_closure_count = 0

  self._queue_lock = threading.Lock()

  # Condition indicating that all pending closures (either queued or inflight)
  # have been processed, failed, or cancelled.
  self._stop_waiting_condition = threading.Condition(self._queue_lock)

  # Condition indicating that an item becomes available in the queue (not
  # empty).
  self._closures_queued_condition = threading.Condition(self._queue_lock)

  # Condition indicating that a queue slot becomes available (not full).
  # Note that even with an "infinite" queue size, there is still a "practical"
  # size limit for the queue depending on host memory capacity, and thus the
  # queue will eventually become full with a lot of enqueued closures.
  self._queue_free_slot_condition = threading.Condition(self._queue_lock)

  # Condition indicating that there are no inflight closures.
  self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

  # Used to cancel in-flight closures.
  self._cancellation_mgr = cancellation.CancellationManager()

  if _CLOSURE_QUEUE_MAX_SIZE <= 0:
    logging.warning(
        "In a `Client`, creating an infinite closure queue can "
        "consume a significant amount of memory and even lead to OOM.")
  self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
  self._error = None

  # The following is a lock to make sure that when `wait` is called, and
  # before it returns, no `put` can be executed, because `wait` would not
  # know what to do with newly put closures. This lock adds a cutoff for
  # `wait` so that closures put into the queue while waiting are not taken
  # as the responsibility of this `wait`.
  #
  # We cannot reuse `self._queue_lock` since when `wait` waits for a
  # condition, `self._queue_lock` is released.
  #
  # We don't use a reader/writer lock on purpose, to reduce the complexity
  # of the code.
  self._put_wait_lock = threading.Lock()
def testProcessAtLeaseOnce(self):
  closure_queue = coordinator_lib._CoordinatedClosureQueue()
  labels = ['A', 'B', 'C', 'D', 'E']
  processed_count = collections.defaultdict(int)

  coord = coordinator.Coordinator(clean_stop_exception_types=[])

  def process_queue():
    with coord.stop_on_exception():
      has_been_put_back = False
      while True:
        closure = closure_queue.get(timeout=30)
        if closure is None:
          break
        if not has_been_put_back:
          has_been_put_back = True
          closure_queue.put_back(closure)
          continue
        closure._function()
        closure_queue.mark_finished()

  def get_func(label):

    def func():
      time.sleep(3)
      processed_count[label] += 1

    return func

  cm = cancellation.CancellationManager()
  for label in labels:
    closure_queue.put(coordinator_lib.Closure(get_func(label), cm))
  t1 = threading.Thread(target=process_queue, daemon=True)
  t1.start()
  t2 = threading.Thread(target=process_queue, daemon=True)
  t2.start()

  # Make sure multiple wait() calls are fine.
  closure_queue.wait()
  closure_queue.wait()
  closure_queue.wait()
  closure_queue.wait()

  self.assertEqual(processed_count, collections.Counter(labels))

  coord.join([t1, t2])
def testRemoteFunctionCancellation(self):
  context._reset_context()
  logical_devices = []
  logical_devices.append(context.LogicalDeviceConfiguration())
  logical_devices.append(context.LogicalDeviceConfiguration())
  framework_config.set_logical_device_configuration(
      framework_config.list_physical_devices("CPU")[0], logical_devices)

  @function.Defun(dtypes.float32)
  def _remote_fn(v):
    # We run two collectives here to make sure we cancel in the middle of the
    # RemoteCall. The second one should never finish.
    anchor = collective_ops.all_reduce_v2(
        v, group_size=2, group_key=1, instance_key=1)
    with ops.control_dependencies([anchor]):
      return collective_ops.all_reduce_v2(
          v, group_size=2, group_key=1, instance_key=2)

  @eager_def_function.function
  def run():
    with ops.device("/cpu:0"):
      return functional_ops.remote_call(
          args=[constant_op.constant([1.])],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/cpu:1")[0]

  async_executor = executor.new_executor(enable_async=True)
  cancel_mgr = cancellation.CancellationManager()
  with context.executor_scope(async_executor):
    # This should never finish.
    cancel_mgr.get_cancelable_function(run.get_concrete_function())()
  with ops.device("/cpu:0"):
    collective_ops.all_reduce_v2(
        [1.], group_size=2, group_key=1, instance_key=1)
  cancel_mgr.start_cancel()
  with self.assertRaises(errors.CancelledError):
    async_executor.wait()
def _cancel_all_closures(self):
  """Clears the queue and sets a cancelled error on the remaining closures.

  This method expects `self._queue_lock` to be held prior to entry.
  """
  self._cancellation_mgr.start_cancel()
  while self._inflight_closure_count > 0:
    self._no_inflight_closure_condition.wait()
  while True:
    try:
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      closure._set_output_remote_values_cancelled()  # pylint: disable=protected-access
    except queue.Empty:
      break
  # The cancellation manager cannot be reused once cancelled. After all
  # closures (queued or inflight) are cleaned up, recreate the cancellation
  # manager with clean state.
  # Note on thread-safety: this is triggered when one of these client APIs
  # is called: `schedule`, `wait`, or `done`. At the same time, no new
  # closures can be constructed (which would read `_cancellation_mgr` to get
  # cancellable functions).
  self._cancellation_mgr = cancellation.CancellationManager()
def __init__(self, cluster_resolver, client_name="chief"): """Initializes the cluster instance and connect to the remote cluster.""" if client_name in ["worker", "ps"]: raise ValueError("Client name should not be 'worker' or 'ps'.") cluster_spec = cluster_resolver.cluster_spec() self._num_workers = len(cluster_spec.as_dict().get("worker", ())) self._num_ps = len(cluster_spec.as_dict().get("ps", ())) device_filters = server_lib.ClusterDeviceFilters() # For any worker, only the devices on PS and chief nodes are visible for i in range(self._num_workers): device_filters.set_device_filters( "worker", i, ["/job:ps", "/job:%s" % client_name]) # Similarly for any ps, only the devices on workers and chief are visible for i in range(self._num_ps): device_filters.set_device_filters( "ps", i, ["/job:worker", "/job:%s" % client_name]) context.context().mirroring_policy = context.MIRRORING_ALL # Allow at most one outstanding RPC for each worker at a certain time. This # is to simplify worker failure handling in the runtime os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False" remote.connect_to_cluster(cluster_spec, job_name=client_name, protocol=cluster_resolver.rpc_layer, cluster_device_filters=device_filters) self._cancellation_mgr = cancellation.CancellationManager() self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr) self.failure_handler = WorkerPreemptionHandler(context.get_server_def()) worker_device_strings = [ "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers) ] self.workers = [ Worker(i, w, self) for i, w in enumerate(worker_device_strings) ]
def testStartCancel(self):
  manager = cancellation.CancellationManager()
  self.assertFalse(manager.is_cancelled)
  manager.start_cancel()
  self.assertTrue(manager.is_cancelled)
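
# --- Illustrative sketch (not part of the tests above) ---
# A minimal, hedged example of the cancellation pattern these tests exercise,
# assuming the same modules they import (`cancellation`, `def_function`,
# `errors`, `data_flow_ops`, `dtypes`, `threading`, `time`). A blocking
# dequeue is wrapped with `get_cancelable_function`; cancelling the manager
# from another thread should make the pending call raise
# `errors.CancelledError`, as in testCancelRemoteFunctionDuringExecution but
# without the remote device placement. The helper name
# `_example_cancel_blocking_dequeue` is hypothetical.
def _example_cancel_blocking_dequeue():
  q = data_flow_ops.FIFOQueue(1, dtypes.int32)

  @def_function.function
  def blocked():
    # The queue is never fed, so this dequeue blocks until cancelled.
    return q.dequeue()

  c_mgr = cancellation.CancellationManager()
  cancelable = c_mgr.get_cancelable_function(blocked.get_concrete_function())

  def _cancel_later():
    # Give the blocking call a moment to start, mirroring the test above.
    time.sleep(0.5)
    c_mgr.start_cancel()

  t = threading.Thread(target=_cancel_later)
  t.start()
  try:
    cancelable()
  except errors.CancelledError:
    pass  # Expected once the cancellation lands.
  t.join()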