예제 #1
0
    def testPyFunctionAsync(self):
        """Calls an eager_py_func from a traced function under an async executor.

        The wrapped Python function adds 1.0 to its input, so feeding it a
        variable holding 2.0 is expected to produce 3.0.
        """

        def add_one(value):
            return value + constant_op.constant(1.)

        @def_function.function
        def traced_add_one(value):
            return script_ops.eager_py_func(add_one, [value], dtypes.float32)

        async_exec = executor.new_executor(enable_async=True)
        with context.executor_scope(async_exec):
            two = variables.Variable(2.)
            result = traced_add_one(two)
            self.assertAllEqual(result, 3.0)
        # Wait on the async executor only after its scope has been exited.
        async_exec.wait()
    def testCancelGetNextWithDevice(self, cls):
        """Cancels an in-flight get_next on one device and checks the other.

        The dataset's map function is gated on the `ping` queue and echoes the
        first value it receives into `pong`, which lets the test know that a
        get_next call is in progress before it triggers cancellation.

        Args:
            cls: callable used to build the multi-device iterator from
                (dataset, devices, prefetch_buffer_size=...). Presumably
                supplied by a parameterized-test decorator that is not visible
                here -- confirm against the enclosing class.
        """
        ping = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)
        pong = data_flow_ops.FIFOQueue(capacity=2, dtypes=dtypes.int64)

        # Each mapped element consumes two values from `ping`: the first one is
        # echoed to `pong`, the second one is added to the element.
        @def_function.function
        def map_fn(v):
            ball = ping.dequeue()
            with ops.control_dependencies([pong.enqueue(ball)]):
                return v + ping.dequeue()

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.map(map_fn)

        # We need to set prefetch_buffer_size=0 so that we can cancel the
        # MultiDeviceIteratorGetNextFromShardOp from eager. If
        # prefetch_buffer_size>0, that op runs in the background threads of the
        # prefetch and can only be cancelled by deleting the iterator.
        multi_device_iterator = cls(dataset,
                                    [self._devices[1], self._devices[2]],
                                    prefetch_buffer_size=0)

        @def_function.function
        def get_next_device1():
            return multi_device_iterator.get_next(self._devices[1])

        # Dispatch the cancelable get_next through an async executor so the
        # call returns to the test while map_fn is still waiting on `ping`.
        async_executor = executor.new_executor(enable_async=True)
        with context.executor_scope(async_executor):
            cancel_mgr = cancellation.CancellationManager()
            cancel_mgr.get_cancelable_function(
                get_next_device1.get_concrete_function())()
        # Make sure we cancel in the middle of get_next.
        # Once `pong` yields the echoed value, map_fn has consumed the first
        # ping and is waiting on the second one.
        ping.enqueue(0)
        pong.dequeue()
        cancel_mgr.start_cancel()
        with self.assertRaises(errors.CancelledError):
            async_executor.wait()
        # Note that fetching from upstream iterator is not cancelled with the
        # cancellation of get_next.
        # This enqueue supplies the second ping value the pending map_fn call
        # is still waiting for.
        ping.enqueue(0)

        # Cancelling a get_next on one device shouldn't cancel the
        # multi_device_iterator and iterators on other devices.
        # Two ping values feed one more map_fn call; the expected result is
        # element 1 of the range plus the enqueued 0.
        ping.enqueue(0)
        ping.enqueue(0)
        self.assertEqual(
            1,
            multi_device_iterator.get_next(self._devices[2]).numpy())
예제 #3
0
  def testRemoteFunctionCancellation(self):
    """Cancels a remote_call that is blocked inside a collective.

    The remote function runs two size-2 all-reduces back to back. The test
    itself only participates in the first one, so the second can never
    complete and the call has to be ended through the cancellation manager.
    """
    # NOTE(review): presumably the context must be reset before logical
    # devices can be reconfigured -- confirm.
    context._reset_context()
    # Configure two logical devices on the first physical CPU, giving the test
    # the "/cpu:0" and "/cpu:1" devices used below.
    logical_devices = []
    logical_devices.append(context.LogicalDeviceConfiguration())
    logical_devices.append(context.LogicalDeviceConfiguration())
    framework_config.set_logical_device_configuration(
        framework_config.list_physical_devices("CPU")[0], logical_devices)

    @function.Defun(dtypes.float32)
    def _remote_fn(v):
      # We run two collectives here to make sure we cancel in the middle of the
      # RemoteCall. The second one should never finish.
      anchor = collective_ops.all_reduce_v2(
          v, group_size=2, group_key=1, instance_key=1)
      with ops.control_dependencies([anchor]):
        return collective_ops.all_reduce_v2(
            v, group_size=2, group_key=1, instance_key=2)

    # Issues the remote call from "/cpu:0" with "/cpu:1" as the target device.
    @eager_def_function.function
    def run():
      with ops.device("/cpu:0"):
        return functional_ops.remote_call(
            args=[constant_op.constant([1.])],
            Tout=[dtypes.float32],
            f=_remote_fn,
            target="/cpu:1")[0]

    async_executor = executor.new_executor(enable_async=True)
    cancel_mgr = cancellation.CancellationManager()
    with context.executor_scope(async_executor):
      # This should never finish.
      cancel_mgr.get_cancelable_function(run.get_concrete_function())()
    # Join the first collective (same group_key/instance_key as the first
    # all-reduce in _remote_fn). Nothing joins instance_key=2.
    with ops.device("/cpu:0"):
      collective_ops.all_reduce_v2([1.],
                                   group_size=2,
                                   group_key=1,
                                   instance_key=1)
    cancel_mgr.start_cancel()
    # The cancellation is expected to surface when waiting on the executor.
    with self.assertRaises(errors.CancelledError):
      async_executor.wait()
예제 #4
0
 def thread_fn(executor_obj, device, results):
   """Call `worker_fn` `num_calls` times, storing each result by call index.

   Every call runs under the given executor scope and device. The whole loop
   is wrapped in the coordinator's stop_on_exception() context.
   """
   with self._coord.stop_on_exception():
     for call_idx in range(num_calls):
       with context.executor_scope(executor_obj), ops.device(device):
         results[call_idx] = worker_fn()
예제 #5
0
def build_collective_reduce(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            reduction_op='Add',
                            unary_op='Id',
                            communication_hint='AUTO',
                            control_inputs=None,
                            executors=None):
    """Build a subgraph that does one full all-reduce, using the collective Op.

    If called in eager mode, it's required to supply a list of async executors,
    one for each input Tensor.

    Args:
      input_tensors: tensors within a single worker graph that are to be
        reduced together; must be one per device.
      devices: a list of device strings to run the collective on.
      group_size: total number of devices globally that will be doing this
        same reduction.  The reduction will actually include the corresponding
        tensors at all these workers.
      collective_keys: a CollectiveKeys object.
      reduction_op: string naming the reduction op.
      unary_op: string naming the unary final op.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      control_inputs: if not None, add control edges between control_inputs
        and (index-wise) corresponding collective_reduce tensors.
      executors: a list of async executors, one per input tensor. Required for
        eager execution.

    Returns:
      A list of final tensors, one per device, computed by the full reduction.
      When `group_size` is less than 2 there is nothing to reduce and
      `input_tensors` is returned unchanged.

    Raises:
      ValueError: if executing eagerly and `executors` is empty or None, does
        not have one entry per input tensor, or contains a non-async executor;
        or if the number of input tensors differs from the number of devices.
    """
    # Eager mode is queried once here and reused when building each reduce op.
    eager = context.executing_eagerly()
    if eager and (not executors or len(executors) != len(input_tensors)
                  or not all(e.is_async() for e in executors)):
        raise ValueError(
            'collectives requires async executors for each device in eager mode'
        )
    if len(input_tensors) != len(devices):
        raise ValueError(
            'collective requires one input tensor for each device, '
            'len(input_tensors) = %d, len(devices) = %d' %
            (len(input_tensors), len(devices)))

    if group_size < 2:
        return input_tensors
    group_key = collective_keys.get_group_key(devices)
    instance_key = collective_keys.get_op_instance_key()
    subdiv_offsets = [0]  # TODO(tucker): maybe support non-default subdiv spec

    out_tensors = []
    for idx, input_tensor in enumerate(input_tensors):
        # In eager mode each reduce op is issued through its own executor;
        # otherwise no executor scope is needed.
        if eager:
            executor_scope = context.executor_scope(executors[idx])
        else:
            executor_scope = ops.NullContextmanager()
        with executor_scope, \
             ops.device(devices[idx]), \
             ops.control_dependencies(
                 _control_input(devices, control_inputs, idx)):
            out_tensor = collective_ops.all_reduce(input_tensor, group_size,
                                                   group_key, instance_key,
                                                   reduction_op, unary_op,
                                                   subdiv_offsets,
                                                   communication_hint)
        out_tensors.append(out_tensor)
    return out_tensors
예제 #6
0
 def _executor_scope(self):
   """Return the context manager under which collectives should be launched.

   Outside eager mode this is a no-op context manager. In eager mode it is an
   executor scope for `self._executor`.

   Raises:
     ValueError: if executing eagerly and `self._executor` is not set.
   """
   if not context.executing_eagerly():
     return ops.NullContextmanager()
   if not self._executor:
     raise ValueError('collectives requires a async executor in eager mode')
   return context.executor_scope(self._executor)