Example #1
def _remote_fn(v):
  # We run two collectives here to make sure we cancel in the middle of the
  # RemoteCall. The second one should never finish.
  anchor = collective_ops.all_reduce_v2(
      v, group_size=2, group_key=1, instance_key=1)
  with ops.control_dependencies([anchor]):
    return collective_ops.all_reduce_v2(
        v, group_size=2, group_key=1, instance_key=2)
Example #2
def fn(x, y):
    t0 = collective_ops.all_reduce_v2(t=x,
                                      group_size=2,
                                      group_key=1,
                                      instance_key=1)
    t1 = collective_ops.all_reduce_v2(t=y,
                                      group_size=2,
                                      group_key=1,
                                      instance_key=1)
    return t0 + t1
Example #3
  def all_reduce(self,
                 input_tensor,
                 control_input=None,
                 communication_hint='AUTO',
                 timeout=0):
    """All-reduce a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=communication_hint,
          timeout=timeout,
          ordering_token=ordering_token)
Example #4
  def all_reduce(
      self,
      input_tensor: core.TensorLike,
      control_input: Optional[Union[core.TensorLike,
                                    ops.Operation]] = None,
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """All-reduce a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    options = self._options.merge(options)
    ordering_token = self._get_ordering_token(options)
    with ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=options.implementation.value,
          timeout=options.timeout_seconds,
          ordering_token=ordering_token)
Example #5
def all_reduce_sum(v):
    return collective_ops.all_reduce_v2(t=v,
                                        group_size=2,
                                        group_key=1,
                                        instance_key=1,
                                        merge_op='Add',
                                        final_op='Id')
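A usage note, not from the original example: a collective with group_size=2 only completes once two devices have joined it. The sketch below (assuming two logical CPU devices, configured the same way as in Example #8) drives all_reduce_sum from both devices inside one tf.function, so the two calls share group and instance keys and pair up into a single reduction.

import tensorflow as tf
from tensorflow.python.ops import collective_ops

# Split the one physical CPU into two logical devices (see Example #8).
cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration(),
              tf.config.LogicalDeviceConfiguration()])

@tf.function
def run_sum():
  results = []
  for i, value in enumerate([1., 2.]):
    with tf.device('/cpu:%d' % i):
      # Both calls use the group_key=1/instance_key=1 hard-coded in
      # all_reduce_sum above, so the runtime matches them into one
      # group of size 2.
      results.append(all_reduce_sum(tf.constant([value])))
  return results

print(run_sum())  # Expected: both entries are [3.]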
Example #6
def fn(x):
    group_key = collective_ops.assign_group_v2(
        group_assignment=[[0]], device_index=0)
    t0 = collective_ops.all_reduce_v2(t=x,
                                      group_size=1,
                                      group_key=group_key,
                                      instance_key=1)
    return t0
Example #7
def _reduce_tensor(tensor):
    with _COUNTER_LOCK:
        global _COUNTER
        keys = _COUNTER
        _COUNTER += 1
    return collective_ops.all_reduce_v2(t=tensor,
                                        group_size=num_replicas,
                                        merge_op=operation,
                                        group_key=keys,
                                        instance_key=keys)
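The excerpt above closes over several names the page does not show. A hedged reconstruction of the assumed module-level state, inferred from the code itself rather than taken from the original source:

import threading

_COUNTER_LOCK = threading.Lock()
_COUNTER = 1  # Monotonically increasing; each call gets fresh keys.

num_replicas = 2   # Assumed: the number of participating devices.
operation = 'Add'  # Assumed: a merge op the kernel accepts ('Add', 'Mul', ...).

Allocating a new key under a lock for every reduction keeps concurrent calls from colliding on group or instance keys, at the cost of forming a new collective group per call.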
Example #8
  def testRemoteFunctionCancellation(self):
    context._reset_context()
    logical_devices = []
    logical_devices.append(context.LogicalDeviceConfiguration())
    logical_devices.append(context.LogicalDeviceConfiguration())
    framework_config.set_logical_device_configuration(
        framework_config.list_physical_devices("CPU")[0], logical_devices)

    @function.Defun(dtypes.float32)
    def _remote_fn(v):
      # We run two collectives here to make sure we cancel in the middle of the
      # RemoteCall. The second one should never finish.
      anchor = collective_ops.all_reduce_v2(
          v, group_size=2, group_key=1, instance_key=1)
      with ops.control_dependencies([anchor]):
        return collective_ops.all_reduce_v2(
            v, group_size=2, group_key=1, instance_key=2)

    @eager_def_function.function
    def run():
      with ops.device("/cpu:0"):
        return functional_ops.remote_call(
            args=[constant_op.constant([1.])],
            Tout=[dtypes.float32],
            f=_remote_fn,
            target="/cpu:1")[0]

    async_executor = executor.new_executor(enable_async=True)
    cancel_mgr = cancellation.CancellationManager()
    with context.executor_scope(async_executor):
      # This should never finish.
      cancel_mgr.get_cancelable_function(run.get_concrete_function())()
    with ops.device("/cpu:0"):
      collective_ops.all_reduce_v2([1.],
                                   group_size=2,
                                   group_key=1,
                                   instance_key=1)
    cancel_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      async_executor.wait()
Example #9
  def all_reduce(self,
                 input_tensor,
                 control_input=None,
                 communication_hint='AUTO',
                 timeout=0):
    """All-reduce a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with self._executor_scope(), \
         ops.device(self._device), \
         self._control_input(control_input):
      if self._should_use_collective_v2():
        return collective_ops.all_reduce_v2(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout,
            ordering_token=ordering_token)
      else:
        return collective_ops.all_reduce(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout)
Example #10
def all_reduce(t, group_size, group_key, instance_key, *args, **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    return _collective_ops.all_reduce_v2(t, group_size, group_key,
                                         instance_key, *args, **kwargs)
Example #11
def f():
    return _collective_ops.all_reduce_v2([1.], group_size,
                                         group_key, instance_key)
Example #12
def collective_fn(t):
  # Run a dummy collective of group size 1 to test the setup.
  return collective_ops.all_reduce_v2(
      t, group_size=1, group_key=1, instance_key=1)
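Because the group size is 1, the collective has no peers to wait for and completes immediately, which is what makes it a cheap setup probe. A minimal, hedged driver for the function above (not part of the original example):

import tensorflow as tf
from tensorflow.python.ops import collective_ops

@tf.function
def run_check():
  # An all-reduce over a group of one returns its input unchanged.
  return collective_fn(tf.constant([3., 5.]))

print(run_check())  # Expected: [3. 5.]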