Example #1
 def testCollectiveGatherShapeMismatch(self):
     group_key = 1
     instance_key = 1
     t0 = [1, 2, 3, 4]
     t1 = [5, 6, 7, 8]
     t2 = [9, 10]
     # Tests that execute collectives need to be enclosed in a graph or tf.function.
     with ops.Graph().as_default():
         with self.session(config=config_pb2.ConfigProto(
                 device_count={'CPU': 2})) as sess:
             with ops.device('/CPU:0'):
                 in0 = constant_op.constant(t0)
                 c0 = collective_ops.all_gather(in0, 2, group_key,
                                                instance_key)
             with ops.device('/CPU:1'):
                 in1 = constant_op.constant(t1)
                 in2 = constant_op.constant(t2)
                 c1 = collective_ops.all_gather(in1, 2, group_key,
                                                instance_key)
                 c2 = collective_ops.all_gather(in2, 2, group_key,
                                                instance_key)
             run_options = config_pb2.RunOptions()
             run_options.experimental.collective_graph_key = 1
             sess.run([c0, c1], options=run_options)
             with self.assertRaisesRegex(errors.InvalidArgumentError,
                                         'Shape mismatch'):
                 sess.run([c0, c2], options=run_options)
  def testCollectiveGatherPolymorphicShape(self):
    t0 = [0, 1, 2, 3, 4, 5, 6, 7]
    t1 = [10, 11, 12, 13, 14, 15, 16, 17]
    group_size = 2
    group_key = 1
    instance_key = 123
    with self.session(
        config=config_pb2.ConfigProto(
            device_count={'CPU': group_size})) as sess:
      with ops.device('/CPU:0'):
        in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
        c0 = collective_ops.all_gather(in0, group_size, group_key, instance_key)
      with ops.device('/CPU:1'):
        in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
        c1 = collective_ops.all_gather(in1, group_size, group_key, instance_key)

      results = sess.run([c0, c1], feed_dict={in0: t0, in1: t1})
      expected_output = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]
      self.assertAllClose(results[0], expected_output, rtol=1e-5, atol=1e-5)
      self.assertAllClose(results[1], expected_output, rtol=1e-5, atol=1e-5)

      results_ = sess.run([c0, c1], feed_dict={in0: t0[1:], in1: t1[1:]})
      expected_output_ = [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17]
      self.assertAllClose(results_[0], expected_output_, rtol=1e-5, atol=1e-5)
      self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5)
Example #3
 def testCollectiveGatherShapeMismatch(self):
     group_key = 1
     instance_key = 1
     t0 = [1, 2, 3, 4]
     t1 = [5, 6, 7, 8]
     t2 = [9, 10]
     with self.session(config=config_pb2.ConfigProto(
             device_count={'CPU': 2})) as sess:
         with ops.device('/CPU:0'):
             in0 = constant_op.constant(t0)
             colred0 = collective_ops.all_gather(in0, 2, group_key,
                                                 instance_key)
         with ops.device('/CPU:1'):
             in1 = constant_op.constant(t1)
             in2 = constant_op.constant(t2)
             colred1 = collective_ops.all_gather(in1, 2, group_key,
                                                 instance_key)
             colred2 = collective_ops.all_gather(in2, 2, group_key,
                                                 instance_key)
         run_options = config_pb2.RunOptions()
         run_options.experimental.collective_graph_key = 1
         sess.run([colred0, colred1], options=run_options)
          with self.assertRaisesRegex(errors.InternalError,
                                      'Inconsistent output shapes'):
             sess.run([colred0, colred2], options=run_options)
Example #4
 def all_gather():
     """Use all_gather to aggregate `IndexedSlices`."""
     all_values = collective_ops.all_gather(input_slices.values,
                                            group_size, group_key,
                                            gather_values_key,
                                            communication_hint)
     # Add control dependency to order the all-gather.
     control = [all_values] if communication_hint == 'NCCL' else []
     with ops.control_dependencies(control):
         all_indices = collective_ops.all_gather(
             input_slices.indices, group_size, group_key,
             gather_indices_key, communication_hint)
     return ops.IndexedSlices(values=all_values,
                              indices=all_indices,
                              dense_shape=input_slices.dense_shape)
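
The closure above is lifted from a larger IndexedSlices reduction routine (shown in full further down), so `input_slices`, `group_size`, `group_key`, the instance keys, and `communication_hint` all come from its enclosing scope. As a minimal standalone sketch of the same values/indices gathering pattern, assuming a default TF 2.x eager context and TensorFlow's internal `collective_ops` module, one can use a trivial group of size one (which the `testCollectiveGroupSizeOne` examples below show is accepted):

# A minimal sketch, not part of the original snippet: with group_size=1 the
# all-gather is a pass-through, which makes the separate gathers of values
# and indices easy to see in isolation.
import tensorflow  # noqa: F401 -- importing TF enables eager execution
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import collective_ops

group_size, group_key = 1, 42
gather_values_key, gather_indices_key = 1, 2  # distinct instance keys

input_slices = ops.IndexedSlices(
    values=constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
    indices=constant_op.constant([0, 5]),
    dense_shape=constant_op.constant([8, 2]))

all_values = collective_ops.all_gather(
    input_slices.values, group_size, group_key, gather_values_key)
all_indices = collective_ops.all_gather(
    input_slices.indices, group_size, group_key, gather_indices_key)
gathered = ops.IndexedSlices(values=all_values,
                             indices=all_indices,
                             dense_shape=input_slices.dense_shape)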
Example #5
    def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
        """All-gather a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered tensor.
    """
        instance_key = self._next_instance_key()
        ordering_token = self._get_ordering_token(communication_hint)
        with self._executor_scope(), ops.device(self._device):
            if self._use_collective_v2():
                return collective_ops.all_gather_v2(
                    input_tensor,
                    self._group_size,
                    self._group_key,
                    instance_key,
                    communication_hint=communication_hint,
                    timeout=timeout,
                    ordering_token=ordering_token)
            else:
                return collective_ops.all_gather(
                    input_tensor,
                    self._group_size,
                    self._group_key,
                    instance_key,
                    communication_hint=communication_hint,
                    timeout=timeout)
    def testBasicNcclAllGather(self):
        inputs = [[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                  [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]]
        expected = [
            0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 0.3, 1.3, 2.3, 3.3, 4.3,
            5.3, 6.3, 7.3
        ]
        group_size = len(inputs)
        group_key = 1
        instance_key = 1
        devices = ['/GPU:{}'.format(i) for i in range(group_size)]

        with self.session(config=self._configure(group_size)) as sess:
            if not test_util.is_gpu_available(cuda_only=True):
                self.skipTest('No GPU available')
            collectives = []
            for i in range(group_size):
                with ops.device(devices[i]):
                    t = constant_op.constant(inputs[i])
                    collectives.append(
                        collective_ops.all_gather(t, group_size, group_key,
                                                  instance_key))
            results = sess.run(collectives)
        for result in results:
            self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
Example #7
 def _collect_sparse_gradients(self, graph_item, var_op_name):
     """Append collective ops after the gradient is calculated."""
     if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
          raise NotImplementedError(
              'Currently the collective NCCL AllGather is not supported in '
              'the TensorFlow release. Please choose another strategy.')
     conf = {}
     if self._spec:
         conf = {'communication_hint': self._spec}
     if self._compressor_type:
          logging.warning(
              'AllGather currently does not support the AutoDist compressor, '
              'so compression is skipped.')
     if self.num_replicas * self.num_workers <= 1:
         raise ValueError(
             'CollectiveOps requires collective group size > 1')
     for i in range(0, self.num_replicas):
         op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
         graph_item.updated = True
         grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
         # TODO (Tairui): (3) Merge of reduction for performance
         indices_c_ops = grad.indices.consumers()
         indices_cc_ops = get_control_consumers(grad.indices.op)
         values_c_ops = grad.values.consumers()
         values_cc_ops = get_control_consumers(grad.values.op)
         with ops.name_scope(replica_prefix(i)):
             with ops.colocate_with(grad.indices.op):
                 new_indices = collective_ops.all_gather(
                     grad.indices, self.num_replicas * self.num_workers,
                     get_collective_keys().get_group_key(
                         self.all_canonical_replica_devices),
                     get_collective_keys().get_instance_key(var_op_name +
                                                            '-indices'),
                     **conf)
             with ops.colocate_with(grad.values.op):
                 new_values = collective_ops.all_gather(
                     grad.values, self.num_replicas * self.num_workers,
                     get_collective_keys().get_group_key(
                         self.all_canonical_replica_devices),
                     get_collective_keys().get_instance_key(var_op_name +
                                                            '-values'),
                     **conf)
         update_consumers(indices_c_ops, grad.indices, new_indices)
         update_control_consumers(indices_cc_ops, grad.indices.op,
                                  new_indices.op)
         update_consumers(values_c_ops, grad.values, new_values)
         update_control_consumers(values_cc_ops, grad.values.op, new_values)
 def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
   group_key = 1
   instance_key = 1
   with self.session(
       config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
     with ops.device('/CPU:0'):
       in0 = constant_op.constant(t0)
       colred0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
     with ops.device('/CPU:1'):
       in1 = constant_op.constant(t1)
       colred1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
     run_options = config_pb2.RunOptions()
     if set_graph_key:
       run_options.experimental.collective_graph_key = 1
     results = sess.run([colred0, colred1], options=run_options)
   self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
   self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
 def run_basic_nccl_all_gather():
   collectives = []
   for i in range(self._group_size):
     with ops.device(self._devices[i]):
       t = constant_op.constant(inputs[i])
       collectives.append(collective_ops.all_gather(t, self._group_size,
                                                    group_key, instance_key))
   return collectives
Example #10
 def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
     group_key = 1
     instance_key = 1
     with self.session(config=config_pb2.ConfigProto(
             device_count={'CPU': 2})) as sess:
         with ops.device('/CPU:0'):
             in0 = constant_op.constant(t0)
             c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
         with ops.device('/CPU:1'):
             in1 = constant_op.constant(t1)
             c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
         run_options = config_pb2.RunOptions()
         if set_graph_key:
             run_options.experimental.collective_graph_key = 1
         results = sess.run([c0, c1], options=run_options)
     self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
     self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
 def testCollectiveGatherShapeMismatchAcrossDevices(self):
     group_key = 1
     instance_key = 1
     t0 = [1, 2, 3, 4]
     t1 = [5, 6]
     with self.session(config=config_pb2.ConfigProto(
             device_count={'CPU': 2})) as sess:
         with ops.device('/CPU:0'):
             in0 = constant_op.constant(t0)
             c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
         with ops.device('/CPU:1'):
             in1 = constant_op.constant(t1)
             c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
         run_options = config_pb2.RunOptions()
         run_options.experimental.collective_graph_key = 1
          with self.assertRaisesRegex(errors.InvalidArgumentError,
                                      'Shape mismatch'):
             sess.run([c0, c1], options=run_options)
Example #12
 def collective_all_gather():
   """Call collective allgather."""
   assert not context.executing_eagerly()
   out_tensors = []
   for d in range(num_devices):
     with ops.device(devices[d]):
       gather_op = collective_ops.all_gather(input_tensors[d], group_size,
                                             group_key, instance_key)
       out_tensors.append(gather_op)
   return out_tensors
def build_collective_gather(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            communication_hint='AUTO',
                            control_inputs=None,
                            timeout=None):
    """Build a subgraph that does one full all-gather, using the collective Op.

  This method must be called in graph mode or inside a tf.function.

  Args:
    input_tensors: tensors within a single worker graph that are to be gathered
      together; must be one per device.
    devices: a list of device strings to run the collective on.
    group_size: total number of devices globally that will be doing this same
      gathering. The gathering will actually include the corresponding tensors
      at all these workers.
    collective_keys: a CollectiveKeys object.
    communication_hint: string providing hint to runtime for choosing collective
      implementation.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_gather tensors
    timeout: a float or None. The timeout in seconds.

  Returns:
    An array of final tensors, one per device, computed by the full gather.
  """
    assert not context.executing_eagerly(), (
        'build_collective_gather can only be called in graph mode or inside '
        'tf.function')
    if len(input_tensors) != len(devices):
        raise ValueError(
            'collective requires one input tensor for each device, %d != %d' %
            (len(input_tensors), len(devices)))

    if group_size < 2:
        return input_tensors
    group_key = collective_keys.get_group_key(devices)
    instance_key = collective_keys.get_op_instance_key()

    out_tensors = []
    for idx, input_tensor in enumerate(input_tensors):
        with ops.device(devices[idx]):
            with ops.control_dependencies(
                    _control_input(devices, control_inputs, idx)):
                out_tensor = collective_ops.all_gather(input_tensor,
                                                       group_size,
                                                       group_key,
                                                       instance_key,
                                                       communication_hint,
                                                       timeout=timeout)
            out_tensors.append(out_tensor)
    return out_tensors
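
A hedged usage sketch for the builder above. It assumes `build_collective_gather` and a `CollectiveKeys` class exposing the `get_group_key`/`get_op_instance_key` methods used in the body are importable from the surrounding module (in TensorFlow's tree that is `tensorflow/python/distribute/cross_device_utils.py`); the session setup mirrors the CPU tests earlier on this page:

# A minimal sketch, assuming build_collective_gather and CollectiveKeys are
# in scope; two virtual CPU devices each contribute one vector.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

with ops.Graph().as_default():
    devices = ['/CPU:0', '/CPU:1']
    inputs = []
    for i, d in enumerate(devices):
        with ops.device(d):
            inputs.append(constant_op.constant([float(i), float(i) + 1.0]))

    gathered = build_collective_gather(
        inputs, devices, group_size=2, collective_keys=CollectiveKeys())

    run_options = config_pb2.RunOptions()
    run_options.experimental.collective_graph_key = 1
    config = config_pb2.ConfigProto(device_count={'CPU': 2})
    with session.Session(config=config) as sess:
        # Both outputs are the concatenation [0., 1., 1., 2.].
        print(sess.run(gathered, options=run_options))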
  def testCollectiveGroupSizeOne(self):
    group_size = 1
    group_key = 100
    instance_key = 100
    in_value = [1, 2, 3, 4]
    in_tensor = constant_op.constant(in_value)

    reduced_tensor = collective_ops.all_reduce(
        in_tensor, group_size, group_key, instance_key, 'Add', 'Id')
    self.assertAllEqual(in_value, reduced_tensor.numpy())

    gathered_tensor = collective_ops.all_gather(
        in_tensor, group_size, group_key, instance_key)
    self.assertAllEqual(in_value, gathered_tensor.numpy())
def build_collective_gather(input_tensors,
                            num_workers,
                            collective_keys,
                            communication_hint='AUTO',
                            control_inputs=None):
    """Build a subgraph that does one full all-gather, using the collective Op.

  This method must be called in graph mode or inside a tf.function.

  Args:
    input_tensors: tensors within a single worker graph that are to be gathered
      together; must be one per device.
    num_workers: total number of workers with identical independent graphs that
      will be doing this same gathering.  The gathering will actually include
      the corresponding tensors at all these workers.
    collective_keys: a CollectiveKeys object.
    communication_hint: string providing hint to runtime for choosing collective
      implementation.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_gather tensors

  Returns:
    An array of final tensors, one per device, computed by the full gather.
  """
    assert not context.executing_eagerly(), (
        'build_collective_gather can only be called in graph mode or inside '
        'tf.function')

    group_size = len(input_tensors) * num_workers
    if group_size < 2:
        return input_tensors
    group_key = collective_keys.get_group_key_of_tensors(input_tensors)
    instance_key = collective_keys.get_op_instance_key()

    out_tensors = []
    for idx, input_tensor in enumerate(input_tensors):
        with ops.device(input_tensor.device):
            with ops.control_dependencies(
                    _control_input(input_tensors, control_inputs, idx)):
                out_tensor = collective_ops.all_gather(input_tensor,
                                                       group_size, group_key,
                                                       instance_key,
                                                       communication_hint)
            out_tensors.append(out_tensor)
    return out_tensors
  def testCollectiveGroupSizeOne(self):
    self._ensure_context_initialized()

    group_size = 1
    group_key = 100
    instance_key = 100
    in_value = [1., 2., 3., 4.]
    in_tensor = constant_op.constant(in_value)

    with ops.device('/GPU:0'):
      reduced_tensor = collective_ops.all_reduce(
          in_tensor, group_size, group_key, instance_key, 'Add', 'Id',
          communication_hint='nccl')
    self.assertAllEqual(in_value, reduced_tensor.numpy())

    with ops.device('/GPU:0'):
      gathered_tensor = collective_ops.all_gather(
          in_tensor, group_size, group_key, instance_key)
    self.assertAllEqual(in_value, gathered_tensor.numpy())
    def all_reduce_indexed_slices(self,
                                  input_slices,
                                  communication_hint='AUTO',
                                  timeout=0):
        """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
    """
        if context.executing_eagerly():
            raise RuntimeError(
                'all_reduce_indexed_slices in eager mode is not supported')

        gather_length_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        gather_indices_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        gather_values_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        reduce_densified_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)

        # Current CollectiveAllGather implementations require input IndexedSlices to
        # have consistent length across the board, so we handle the reduction of
        # IndexedSlices as follows:
        #   1. Gather the lengths of IndexedSlices from all participants.
        #   2. If they have consistent length, apply all_gather.
        #   3. Otherwise convert IndexedSlices to dense tensors and apply
        #      all_reduce.
        with ops.device(self._device):

            def all_gather():
                """Use all_gather to aggregate `IndexedSlices`."""
                all_values = collective_ops.all_gather(input_slices.values,
                                                       self._group_size,
                                                       self._group_key,
                                                       gather_values_key,
                                                       communication_hint,
                                                       timeout=timeout)
                # Add control dependency to order the all-gather.
                control = [all_values] if communication_hint == 'NCCL' else []
                with ops.control_dependencies(control):
                    all_indices = collective_ops.all_gather(
                        input_slices.indices,
                        self._group_size,
                        self._group_key,
                        gather_indices_key,
                        communication_hint,
                        timeout=timeout)
                return ops.IndexedSlices(values=all_values,
                                         indices=all_indices,
                                         dense_shape=input_slices.dense_shape)

            def densify_and_all_reduce():
                """Use all_reduce to aggregate `IndexedSlices`."""
                densified = ops.convert_to_tensor(input_slices)
                reduced = collective_ops.all_reduce(densified,
                                                    self._group_size,
                                                    self._group_key,
                                                    reduce_densified_key,
                                                    'Add',
                                                    'Id', [0],
                                                    communication_hint,
                                                    timeout=timeout)
                # We have to convert the dense grad to IndexedSlices because all_reduce()
                # and all_gather() must have the same return type as required by
                # control_flow_ops.cond.
                return ops.IndexedSlices(values=reduced,
                                         indices=math_ops.range(
                                             array_ops.shape(reduced)[0]),
                                         dense_shape=input_slices.dense_shape)

            length = array_ops.shape(input_slices.indices)
            all_lengths = collective_ops.all_gather(length,
                                                    self._group_size,
                                                    self._group_key,
                                                    gather_length_key,
                                                    communication_hint,
                                                    timeout=timeout)
            return control_flow_ops.cond(
                math_ops.equal(math_ops.reduce_max(all_lengths),
                               math_ops.reduce_min(all_lengths)), all_gather,
                densify_and_all_reduce)
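
The `cond` above hinges entirely on whether every participant reports the same number of indices. A tiny plain-Python illustration of that branch selection (an added sketch, not TF code):

# Mirrors the reduce_max == reduce_min check on the gathered lengths.
def choose_branch(all_lengths):
    """all_lengths mimics the gathered per-replica index counts."""
    if max(all_lengths) == min(all_lengths):
        return 'all_gather'            # consistent lengths: gather directly
    return 'densify_and_all_reduce'    # ragged lengths: densify, then reduce

assert choose_branch([3, 3]) == 'all_gather'
assert choose_branch([3, 2]) == 'densify_and_all_reduce'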
Example #18
def build_collective_gather_indexed_slices(input_slices_list,
                                           devices,
                                           group_size,
                                           collective_keys,
                                           communication_hint='AUTO',
                                           control_inputs=None):
    """Build a subgraph that all-gathers IndexedSlices using the collective Op.

  This method must be called in graph mode or inside a tf.function.

  Args:
    input_slices_list: a list of IndexedSlices within a single worker graph that
      are to be gathered together; must be one per device.
    devices: a list of device strings to run the collective on.
    group_size: total number of devices globally that will be doing this same
      gathering. The gathering will actually include the corresponding tensors
      at all these workers.
    collective_keys: a CollectiveKeys object.
    communication_hint: string providing hint to runtime for choosing collective
      implementation.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_gather tensors

  Returns:
    An array of final IndexedSlices, one per device, computed by the full
    gather.

  Raises:
    ValueError: if control_inputs is not None and doesn't match the length and
      devices of inputs.
  """
    assert not context.executing_eagerly(), (
        'build_collective_gather_indexed_slices can only be called in graph mode'
        ' or inside tf.function')
    if len(input_slices_list) != len(devices):
        raise ValueError(
            'collective requires one input IndexedSlice for each device, %d != %d'
            % (len(input_slices_list), len(devices)))

    if group_size < 2:
        return input_slices_list

    group_key = collective_keys.get_group_key(devices)
    gather_length_key = collective_keys.get_op_instance_key()
    gather_indices_key = collective_keys.get_op_instance_key()
    gather_values_key = collective_keys.get_op_instance_key()
    reduce_densified_key = collective_keys.get_op_instance_key()

    # Current CollectiveAllGather implementations require input IndexedSlices to
    # have consistent length across the board, so we handle the reduction of
    # IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    out_slices_list = []
    for idx, input_slices in enumerate(input_slices_list):
        # pylint: disable=cell-var-from-loop
        with ops.device(devices[idx]):

            def all_gather():
                """Use all_gather to aggregate `IndexedSlices`."""
                all_values = collective_ops.all_gather(input_slices.values,
                                                       group_size, group_key,
                                                       gather_values_key,
                                                       communication_hint)
                # Add control dependency to order the all-gather.
                control = [all_values] if communication_hint == 'NCCL' else []
                with ops.control_dependencies(control):
                    all_indices = collective_ops.all_gather(
                        input_slices.indices, group_size, group_key,
                        gather_indices_key, communication_hint)
                return ops.IndexedSlices(values=all_values,
                                         indices=all_indices,
                                         dense_shape=input_slices.dense_shape)

            def densify_and_all_reduce():
                """Use all_reduce to aggregate `IndexedSlices`."""
                densified = ops.convert_to_tensor(input_slices)
                reduced = collective_ops.all_reduce(densified, group_size,
                                                    group_key,
                                                    reduce_densified_key,
                                                    'Add', 'Id', [0],
                                                    communication_hint)
                # We have to convert the dense grad to IndexedSlices because all_reduce()
                # and all_gather() must have the same return type as required by
                # control_flow_ops.cond.
                return ops.IndexedSlices(values=reduced,
                                         indices=math_ops.range(
                                             array_ops.shape(reduced)[0]),
                                         dense_shape=input_slices.dense_shape)

            length = array_ops.shape(input_slices.indices)
            with ops.control_dependencies(
                    _control_input(input_slices, control_inputs, idx)):
                all_lengths = collective_ops.all_gather(
                    length, group_size, group_key, gather_length_key,
                    communication_hint)
            out_slices = control_flow_ops.cond(
                math_ops.equal(math_ops.reduce_max(all_lengths),
                               math_ops.reduce_min(all_lengths)), all_gather,
                densify_and_all_reduce)
            out_slices_list.append(out_slices)
        # pylint: enable=cell-var-from-loop
    return out_slices_list
    def all_gather(self,
                   input_tensor,
                   axis,
                   communication_hint='AUTO',
                   timeout=0):
        """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
        if context.executing_eagerly():
            raise RuntimeError('all_gather in eager mode is not supported')

        instance_key_tensor = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        instance_key_shape = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        with ops.device(self._device):
            # 1. Transpose
            # E.g. given an input_tensor with shape [2,2,5,1] and gather axis 3,
            # we use perm_pre=[3,0,1,2] to transpose it to [1,2,2,5], which
            # brings the gather dim to the front; afterwards we use
            # perm_after=[1,2,3,0] to put it back.
            perm_pre = array_ops.concat(
                ([axis], math_ops.range(axis),
                 math_ops.range(axis + 1, array_ops.rank(input_tensor))),
                axis=0)
            input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
            # 2. Pad
            gathered_shape = collective_ops.all_gather(
                array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t),
                                         axis=0),
                self._group_size,
                self._group_key,
                instance_key_shape,
                communication_hint,
                timeout=timeout)
            first_dims = gathered_shape[:, 0]
            full_axis_dim = math_ops.reduce_max(first_dims)
            padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

            # 3. Gather
            gather_padded_out_tensor = collective_ops.all_gather(
                padded_input_tensor,
                self._group_size,
                self._group_key,
                instance_key_tensor,
                communication_hint,
                timeout=timeout)
            # 4. Unpad
            split_tensors = []
            for i in range(first_dims.shape[0]):
                start_pos = i * full_axis_dim
                split_tensors.append(
                    gather_padded_out_tensor[start_pos:start_pos +
                                             first_dims[i]])
            out_tensor_t = array_ops.concat(split_tensors, 0)

            # 5. Transpose back
            perm_after = array_ops.concat(
                (math_ops.range(1, axis + 1), [0],
                 math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
                axis=0)
            return array_ops.transpose(out_tensor_t, perm=perm_after)
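
To make steps 2 through 5 above concrete, here is a small NumPy illustration (an added sketch, not TensorFlow code) of the pad/gather/unpad scheme for two replicas whose gather dimension differs; the all-gather itself is stood in for by a plain concatenation:

import numpy as np

# Two replicas' tensors after step 1 (gather axis already moved to dim 0).
parts = [np.arange(6).reshape(3, 2),       # first dim 3
         np.arange(6, 10).reshape(2, 2)]   # first dim 2
first_dims = [p.shape[0] for p in parts]
full_axis_dim = max(first_dims)            # what reduce_max(first_dims) yields

# Step 2: pad every tensor along dim 0 to the common length.
padded = [np.pad(p, [(0, full_axis_dim - p.shape[0])] + [(0, 0)] * (p.ndim - 1))
          for p in parts]
# Step 3: the all-gather concatenates the padded tensors along dim 0.
gather_padded = np.concatenate(padded, axis=0)   # shape (6, 2), with padding
# Step 4: slice the real rows back out and re-concatenate.
unpadded = [gather_padded[i * full_axis_dim:i * full_axis_dim + d]
            for i, d in enumerate(first_dims)]
out = np.concatenate(unpadded, axis=0)           # shape (5, 2), padding removed
# Step 5 would transpose the gather axis back to its original position.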
def build_collective_gather(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            axis,
                            communication_hint='AUTO',
                            control_inputs=None,
                            timeout=None):
  """Build a subgraph that does one full all-gather, using the collective Op.

  This method must be called in graph mode or inside a tf.function.

  Args:
    input_tensors: tensors within a single worker graph that are to be gathered
      together; must be one per device. Input tensors cannot have rank 0.
    devices: a list of device strings to run the collective on.
    group_size: total number of devices globally that will be doing this same
      gathering. The gathering will actually include the corresponding tensors
      at all these workers.
    collective_keys: a CollectiveKeys object.
    axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
      range [0, rank(value)).
    communication_hint: string providing hint to runtime for choosing collective
      implementation. Available options are `AUTO`, `NCCL`, and `RING`.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_gather tensors
    timeout: a float or None. The timeout in seconds.

  Returns:
    An array of final tensors, one per device, computed by the full gather.
  """
  if len(input_tensors) != len(devices):
    raise ValueError(
        'collective requires one input tensor for each device, %d != %d' %
        (len(input_tensors), len(devices)))

  if group_size < 2:
    return input_tensors
  group_key = collective_keys.get_group_key(devices)
  instance_key_tensor = collective_keys.get_op_instance_key()
  instance_key_shape = collective_keys.get_op_instance_key()

  out_tensors = []
  for idx, input_tensor in enumerate(input_tensors):
    with ops.device(devices[idx]), ops.control_dependencies(
        _control_input(devices, control_inputs, idx)):
      # 1. Transpose
      # E.g. given an input_tensor with shape [2,2,5,1] and gather axis 3,
      # we use perm_pre=[3,0,1,2] to transpose it to [1,2,2,5], which
      # brings the gather dim to the front; afterwards we use
      # perm_after=[1,2,3,0] to put it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      gathered_shape = collective_ops.all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          group_size,
          group_key,
          instance_key_shape,
          communication_hint,
          timeout=timeout)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = collective_ops.all_gather(
          padded_input_tensor,
          group_size,
          group_key,
          instance_key_tensor,
          communication_hint,
          timeout=timeout)
      # 4. Unpad
      split_tensors = []
      for i in range(first_dims.shape[0]):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after)
      out_tensors.append(out_tensor)
  return out_tensors