Example #1
    def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
        """All-gather a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
        instance_key = self._next_instance_key()
        ordering_token = self._get_ordering_token(communication_hint)
        with self._executor_scope(), ops.device(self._device):
            if self._use_collective_v2():
                return collective_ops.all_gather_v2(
                    input_tensor,
                    self._group_size,
                    self._group_key,
                    instance_key,
                    communication_hint=communication_hint,
                    timeout=timeout,
                    ordering_token=ordering_token)
            else:
                return collective_ops.all_gather(
                    input_tensor,
                    self._group_size,
                    self._group_key,
                    instance_key,
                    communication_hint=communication_hint,
                    timeout=timeout)
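A minimal sketch (not part of the launcher above) of driving the same collective all-gather directly: it runs the V1 op on two logical CPU devices in one process. The device configuration, group/instance keys, and values are illustrative assumptions.

import tensorflow as tf
from tensorflow.python.ops import collective_ops

# Split the single physical CPU into two logical devices so that two
# collective group members can exist in the same process.
cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2)

group_size, group_key, instance_key = 2, 1, 1

@tf.function
def run_all_gather():
    gathered = []
    for i in range(group_size):
        with tf.device('/CPU:%d' % i):
            t = tf.constant([float(i)])
            gathered.append(
                collective_ops.all_gather(t, group_size, group_key,
                                          instance_key))
    return gathered

# Each member should receive the concatenation [0., 1.] along axis 0.
print(run_all_gather())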
Example #2
    def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
        """All-gather a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced tensor.
    """
        instance_key = self._next_instance_key()
        ordering_token = self._get_ordering_token(communication_hint)
        with ops.device(self._device):
            return collective_ops.all_gather_v2(
                input_tensor,
                self._group_size,
                self._group_key,
                instance_key,
                communication_hint=communication_hint,
                timeout=timeout,
                ordering_token=ordering_token)
Example #3
    def _all_gather(self, input_tensor: core.TensorLike,
                    options: Optional[collective_util.Options]) -> core.Tensor:
        """All-gather a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The reduced tensor.
    """
        instance_key = self._next_instance_key()
        options = self._options.merge(options)
        ordering_token = self._get_ordering_token(options)
        with ops.device(self._device):
            return collective_ops.all_gather_v2(
                input_tensor,
                self._group_size,
                self._group_key,
                instance_key,
                communication_hint=options.implementation.value,
                timeout=options.timeout_seconds,
                ordering_token=ordering_token)
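Example #3 replaces the loose communication_hint/timeout arguments with a single options object. A hedged sketch of its public counterpart, tf.distribute.experimental.CommunicationOptions, is shown below; the particular values are illustrative, not defaults.

import tensorflow as tf

options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=50 * 1024 * 1024,
    timeout_seconds=120.0,
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)

In recent TF versions, cross-device APIs such as tf.distribute.StrategyExtended.reduce_to accept such an options argument; internally it is merged with the defaults, mirroring self._options.merge(options) above. The short helper that follows wraps collective_ops.all_gather_v2 behind the V1-style all_gather signature, converting the scalar arguments to tensors first.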
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops as _collective_ops

def all_gather(t, group_size, group_key, instance_key, *args, **kwargs):
    # Convert the scalar arguments to tensors so the V2 op receives them as
    # tensor inputs (the V1 op took them as attributes).
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    return _collective_ops.all_gather_v2(t, group_size, group_key,
                                         instance_key, *args, **kwargs)
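A hedged usage sketch of the wrapper above, assuming the same two-logical-CPU, single-process setup as the earlier sketch; group and instance keys are illustrative.

import tensorflow as tf

group_size, group_key, instance_key = 2, 1, 2

@tf.function
def run_wrapped_all_gather():
    results = []
    for i in range(group_size):
        with tf.device('/CPU:%d' % i):
            t = tf.constant([float(i), float(i)])
            results.append(all_gather(t, group_size, group_key, instance_key))
    return results

# Each member should see both inputs concatenated along axis 0.
print(run_wrapped_all_gather())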