def testCancelDuringParamResolution(self, collective_op, device,
                                      communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])
    t1_cancellation_manager = cancellation.CancellationManager()
    t2_cancellation_manager = cancellation.CancellationManager()

    @def_function.function
    def _collective_fn(x):
      # Run an assertion to crash one of the two function executions running
      # collectives. We explicitly cancel the other in response.
      assert_op = check_ops.assert_equal(x, in_tensor)
      with ops.control_dependencies([assert_op]):
        return collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    collective_concrete = _collective_fn.get_concrete_function(in_tensor)

    finish_mu = threading.Lock()
    finishes = 0

    def _placement_wrapper(device, x, my_cancellation, other_cancellation):
      try:
        with ops.device(device):
          cancelable_collective = my_cancellation.get_cancelable_function(
              collective_concrete)
          return cancelable_collective(x)
      except errors.InvalidArgumentError:
        # `assert_equal` failed for this execution of the function. The other
        # function would deadlock without cancellation.
        other_cancellation.start_cancel()
      except errors.CancelledError:
        pass
      nonlocal finishes
      with finish_mu:
        finishes += 1

    t1 = threading.Thread(
        target=_placement_wrapper,
        args=(dev0, constant_op.constant([1.]), t1_cancellation_manager,
              t2_cancellation_manager))
    t2 = threading.Thread(
        target=_placement_wrapper,
        # Will cause the assertion to fail
        args=(dev1, constant_op.constant([2.]), t2_cancellation_manager,
              t1_cancellation_manager))
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    self.assertEqual(finishes, 2)
Example #2
    def testCancelRemoteFunctionDuringExecution(self):
        remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
        default_streaming = os.environ.get(remote_async_env_var)
        os.environ[remote_async_env_var] = str(False)

        q = data_flow_ops.FIFOQueue(1, dtypes.int32)

        @def_function.function
        def f():
            return q.dequeue()

        c_mgr = cancellation.CancellationManager()
        cancelable_func = c_mgr.get_cancelable_function(
            f.get_concrete_function())

        def cancel_thread():
            time.sleep(0.5)
            c_mgr.start_cancel()

        t = self.checkedThread(cancel_thread)
        t.start()
        with self.assertRaises(errors.CancelledError):
            with ops.device('/job:worker/replica:0/task:1'):
                cancelable_func()
        t.join()

        if default_streaming is None:
            del os.environ[remote_async_env_var]
        else:
            os.environ[remote_async_env_var] = default_streaming
Example #3
  def testCancelGetNextWithDevice(self, cls):
    ping = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)
    pong = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)

    @def_function.function
    def map_fn(v):
      ball = ping.dequeue()
      with ops.control_dependencies([pong.enqueue(ball)]):
        return v + ping.dequeue()

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.map(map_fn)

    # We need to set prefetch_buffer_size=0 so that we can cancel the
    # MultiDeviceIteratorGetNextFromShardOp from eager. If
    # prefetch_buffer_size>0, that op runs in the background threads of the
    # prefetch and can only be cancelled by deleting the iterator.
    multi_device_iterator = cls(
        dataset, [self._devices[1], self._devices[2]], prefetch_buffer_size=0)

    @def_function.function
    def get_next_device1():
      return multi_device_iterator.get_next(self._devices[1])

    async_executor = executor.new_executor(enable_async=True)
    with context.executor_scope(async_executor):
      cancel_mgr = cancellation.CancellationManager()
      cancel_mgr.get_cancelable_function(
          get_next_device1.get_concrete_function())()
    # Make sure we cancel in the middle of get_next.
    ping.enqueue(0)
    pong.dequeue()
    cancel_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      async_executor.wait()
    # Note that fetching from the upstream iterator is not cancelled by the
    # cancellation of get_next.
    ping.enqueue(0)

    # Cancelling a get_next on one device shouldn't cancel the
    # multi_device_iterator and iterators on other devices.
    ping.enqueue(0)
    ping.enqueue(0)
    self.assertEqual(1,
                     multi_device_iterator.get_next(self._devices[2]).numpy())
    # FIXME(b/209534797): Workaround an asan error caused by this test.
    # Remove the dangling reference from tf.function to ensure queue objects
    # are not freed before they are flushed.
    import gc  # pylint: disable=g-import-not-at-top
    del get_next_device1
    gc.collect()
Example #4
  def __init__(self):
    # `self._inflight_closure_count` only tracks the number of inflight closures
    # that are "in generation". Once an error occurs, error generation is
    # incremented and all subsequent arriving closures (from inflight) are
    # considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or inflight)
    # have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)

    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with "infinite" queue size, there is still a "practical"
    # size limit for the queue depending on host memory capacity, and thus the
    # queue will eventually become full with a lot of enqueued closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating there are no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Used to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
      logging.warning(
          "In a `Client`, creating an infinite closure queue can "
          "consume a significant amount of memory and even lead to OOM.")
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following is a lock to make sure that no `put` can be executed while a
    # `wait` call is in progress, because `wait` would not know what to do with
    # newly put closures. This lock adds a cutoff for `wait` so that closures put
    # into the queue while waiting are not the responsibility of this `wait`.
    #
    # We cannot reuse the `self._queue_lock` since when `wait` waits for a
    # condition, the `self._queue_lock` will be released.
    #
    # We don't use a reader/writer's lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()
Example #5
  def testProcessAtLeastOnce(self):
    closure_queue = coordinator_lib._CoordinatedClosureQueue()
    labels = ['A', 'B', 'C', 'D', 'E']
    processed_count = collections.defaultdict(int)

    coord = coordinator.Coordinator(clean_stop_exception_types=[])

    def process_queue():
      with coord.stop_on_exception():
        has_been_put_back = False
        while True:
          closure = closure_queue.get(timeout=30)
          if closure is None:
            break
          if not has_been_put_back:
            has_been_put_back = True
            closure_queue.put_back(closure)
            continue
          closure._function()
          closure_queue.mark_finished()

    def get_func(label):

      def func():
        time.sleep(3)
        processed_count[label] += 1

      return func

    cm = cancellation.CancellationManager()
    for label in labels:
      closure_queue.put(coordinator_lib.Closure(get_func(label), cm))
    t1 = threading.Thread(target=process_queue, daemon=True)
    t1.start()
    t2 = threading.Thread(target=process_queue, daemon=True)
    t2.start()

    # Make sure multiple wait() calls are fine.
    closure_queue.wait()
    closure_queue.wait()
    closure_queue.wait()
    closure_queue.wait()

    self.assertEqual(processed_count, collections.Counter(labels))

    coord.join([t1, t2])
Example #6
  def testRemoteFunctionCancellation(self):
    context._reset_context()
    logical_devices = []
    logical_devices.append(context.LogicalDeviceConfiguration())
    logical_devices.append(context.LogicalDeviceConfiguration())
    framework_config.set_logical_device_configuration(
        framework_config.list_physical_devices("CPU")[0], logical_devices)

    @function.Defun(dtypes.float32)
    def _remote_fn(v):
      # We run two collectives here to make sure we cancel in the middle of the
      # RemoteCall. The second one should never finish.
      anchor = collective_ops.all_reduce_v2(
          v, group_size=2, group_key=1, instance_key=1)
      with ops.control_dependencies([anchor]):
        return collective_ops.all_reduce_v2(
            v, group_size=2, group_key=1, instance_key=2)

    @eager_def_function.function
    def run():
      with ops.device("/cpu:0"):
        return functional_ops.remote_call(
            args=[constant_op.constant([1.])],
            Tout=[dtypes.float32],
            f=_remote_fn,
            target="/cpu:1")[0]

    async_executor = executor.new_executor(enable_async=True)
    cancel_mgr = cancellation.CancellationManager()
    with context.executor_scope(async_executor):
      # This should never finish.
      cancel_mgr.get_cancelable_function(run.get_concrete_function())()
    with ops.device("/cpu:0"):
      collective_ops.all_reduce_v2([1.],
                                   group_size=2,
                                   group_key=1,
                                   instance_key=1)
    cancel_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      async_executor.wait()
Example #7
  def _cancel_all_closures(self):
    """Clears the queue and sets remaining closures cancelled error.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure._set_output_remote_values_cancelled()  # pylint: disable=protected-access
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of these client APIs is
    # called: `schedule`, `wait`, or `done`. While it runs, no new closures can
    # be constructed (constructing a closure reads `_cancellation_mgr` to get a
    # cancellable function).
    self._cancellation_mgr = cancellation.CancellationManager()
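
The comment above notes that a `CancellationManager` cannot be reused once it has been cancelled, which is why `_cancel_all_closures` recreates it. Below is a minimal sketch of that behavior, assuming the same `cancellation` module imported by the examples above; `old_mgr` and `new_mgr` are illustrative names.

# Once start_cancel() has been called, is_cancelled stays True, so any function
# wrapped by the old manager would be cancelled immediately; new closures need a
# freshly constructed manager.
old_mgr = cancellation.CancellationManager()
old_mgr.start_cancel()
assert old_mgr.is_cancelled        # permanently cancelled; cannot be reset
new_mgr = cancellation.CancellationManager()
assert not new_mgr.is_cancelled    # subsequent closures are wrapped with this one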
Example #8
  def __init__(self, cluster_resolver, client_name="chief"):
    """Initializes the cluster instance and connect to the remote cluster."""
    if client_name in ["worker", "ps"]:
      raise ValueError("Client name should not be 'worker' or 'ps'.")
    cluster_spec = cluster_resolver.cluster_spec()

    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))
    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on the PS and chief nodes are visible.
    for i in range(self._num_workers):
      device_filters.set_device_filters(
          "worker", i, ["/job:ps", "/job:%s" % client_name])
    # Similarly, for any ps, only the devices on the workers and chief are visible.
    for i in range(self._num_ps):
      device_filters.set_device_filters(
          "ps", i, ["/job:worker", "/job:%s" % client_name])

    context.context().mirroring_policy = context.MIRRORING_ALL
    # Allow at most one outstanding RPC for each worker at any given time. This
    # is to simplify worker failure handling in the runtime.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"
    remote.connect_to_cluster(cluster_spec,
                              job_name=client_name,
                              protocol=cluster_resolver.rpc_layer,
                              cluster_device_filters=device_filters)

    self._cancellation_mgr = cancellation.CancellationManager()
    self._closure_queue = _CoordinatedClosureQueue(self._cancellation_mgr)
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]
Example #9
  def testStartCancel(self):
    manager = cancellation.CancellationManager()
    self.assertFalse(manager.is_cancelled)
    manager.start_cancel()
    self.assertTrue(manager.is_cancelled)
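
The examples above share one pattern: wrap a concrete function with `CancellationManager.get_cancelable_function`, trigger `start_cancel` from another thread or executor, and expect `errors.CancelledError`. Below is a minimal, self-contained sketch of that pattern, assuming the internal TensorFlow modules used in the examples are importable and that a blocking local `FIFOQueue.dequeue` is cancellable in the same way as the remote call in Example #2; `blocking_fn` and `cancel_later` are illustrative names.

import threading
import time

from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import data_flow_ops

# Dequeueing from an empty queue blocks until an item arrives or the op is
# cancelled.
q = data_flow_ops.FIFOQueue(1, dtypes.int32)

@def_function.function
def blocking_fn():
  return q.dequeue()

c_mgr = cancellation.CancellationManager()
cancelable_fn = c_mgr.get_cancelable_function(blocking_fn.get_concrete_function())

def cancel_later():
  # Give the main thread time to block inside the dequeue, then cancel it.
  time.sleep(0.5)
  c_mgr.start_cancel()

t = threading.Thread(target=cancel_later)
t.start()
try:
  cancelable_fn()
except errors.CancelledError:
  print("dequeue was cancelled")
t.join()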