from tensorflow.python.framework import device
from tensorflow.python.ops import gen_collective_ops


def all_gather(t, group_size, group_key, instance_key):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if not device.canonical_name(t.device):
    raise ValueError('Device assignment required for collective ops')
  if group_size <= 1:
    raise ValueError('Parameter group_size to all_gather must be at least 2.')
  dims = t.shape.as_list()
  output_shape = [dims[0] * group_size] + dims[1:]
  return gen_collective_ops.collective_gather(t,
                                              shape=output_shape,
                                              group_size=group_size,
                                              group_key=group_key,
                                              instance_key=instance_key)
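# --- Usage sketch (not part of the original source) ---
# A minimal, hedged example of the signature above: two logical CPU devices in
# a tf.compat.v1 graph each contribute a rank-1 tensor, and every participant
# receives the concatenation along the first dimension. The device names,
# group/instance keys, and session config below are assumptions made for
# illustration, not taken from the original file.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op, ops


def _all_gather_two_cpus_example():
  config = config_pb2.ConfigProto(device_count={'CPU': 2})
  with ops.Graph().as_default(), session.Session(config=config) as sess:
    gathered = []
    for i in range(2):
      # Explicit device placement is required by this version of all_gather.
      with ops.device('/CPU:%d' % i):
        t = constant_op.constant([float(i)])
        gathered.append(
            all_gather(t, group_size=2, group_key=1, instance_key=1))
    run_options = config_pb2.RunOptions()
    # Some graph-mode builds need a collective_graph_key to group the ops.
    run_options.experimental.collective_graph_key = 1
    # Each entry of the result should be [0., 1.].
    return sess.run(gathered, options=run_options)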
def all_gather(t, group_size, group_key, instance_key,
               communication_hint='auto'):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device.  Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_gather must be at least 1.')
  return gen_collective_ops.collective_gather(
      t,
      shape=[0],
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower())
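# --- Usage sketch (not part of the original source) ---
# The only new knob relative to the version above is communication_hint: `nccl`
# is only honoured for GPU tensors, and the runtime may still fall back to a
# ring implementation. The group/instance keys below are illustrative values.
def _all_gather_with_nccl_hint_example(t):
  return all_gather(t, group_size=2, group_key=1, instance_key=1,
                    communication_hint='nccl')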
def all_gather_v2(t,
                  group_size,
                  group_key,
                  instance_key,
                  communication_hint='auto',
                  timeout=0):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: an int32 tensor, the total number of tensors to be collectively
      accumulated.  Each must reside on a different device.  Should be a
      positive integer.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, enables a completion timeout,
      in seconds, used to detect staleness. If the timer expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    An Op implementing the distributed operation.
  """
  return gen_collective_ops.collective_gather_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
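# --- Usage sketch (not part of the original source) ---
# In the v2 variant the group parameters are int32 tensors rather than Python
# ints, so a single traced function can serve different group configurations.
# This sketch only shows the call site; actually running it still requires
# group_size participating devices, and the 10-second timeout is an
# illustrative value.
from tensorflow.python.eager import def_function


@def_function.function
def _all_gather_v2_example(t, group_size, group_key, instance_key):
  return all_gather_v2(t, group_size, group_key, instance_key,
                       communication_hint='ring', timeout=10.0)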
def all_gather(t,
               group_size,
               group_key,
               instance_key,
               communication_hint='auto',
               timeout=0,
               ordering_token=None):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device.  Should be a positive integer.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    communication_hint: preferred collective communication.  The implementation
      may fall back to another mechanism.  Options include `auto`, `ring`, and
      `nccl`.
    timeout: a float. If set to a non-zero value, enables a completion timeout,
      in seconds, used to detect staleness. If the timer expires, a
      DeadlineExceededError is raised. This feature is experimental.
    ordering_token: an optional resource tensor to pass to the op as an input.
      It isn't used by the kernel, but it allows AutoControlDependency to order
      the collectives with control dependencies.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size < 1:
    raise ValueError('Parameter group_size to all_gather must be at least 1.')
  if ordering_token is not None:
    ordering_token = [ordering_token]
  return gen_collective_ops.collective_gather(
      t,
      shape=[0],
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout,
      ordering_token=ordering_token or [])
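# --- Usage sketch (not part of the original source) ---
# A hedged illustration of ordering_token: both collectives below are handed
# the same resource tensor (here, the handle of a throwaway variable), so
# automatic control dependencies inside a traced function serialize them in
# program order on every device. The variable, keys, and 5-second timeout are
# assumptions made for this example; tf.distribute wires up its own tokens.
from tensorflow.python.eager import def_function
from tensorflow.python.ops import resource_variable_ops

_ordering_token = resource_variable_ops.ResourceVariable(0.0).handle


@def_function.function
def _ordered_all_gathers_example(a, b):
  g1 = all_gather(a, group_size=2, group_key=1, instance_key=1,
                  communication_hint='nccl', timeout=5.0,
                  ordering_token=_ordering_token)
  g2 = all_gather(b, group_size=2, group_key=1, instance_key=2,
                  communication_hint='nccl', timeout=5.0,
                  ordering_token=_ordering_token)
  return g1, g2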