def testCollectiveGatherShapeMismatch(self):
    """Verify that all_gather reports inputs of differing shape.

    Three all_gather ops sharing one group key and one instance key are
    placed on two CPU devices. The pair fed by the two 4-element constants
    is run first; the pair fed by a 4-element and a 2-element constant is
    then expected to raise `InvalidArgumentError` matching 'Shape mismatch'.
    """
    group_key = 1
    instance_key = 1
    four_a = [1, 2, 3, 4]
    four_b = [5, 6, 7, 8]
    two = [9, 10]
    # Tests that execute collectives need to be enclosed in graph or
    # tf.function
    with ops.Graph().as_default():
        two_cpus = config_pb2.ConfigProto(device_count={'CPU': 2})
        with self.session(config=two_cpus) as sess:
            with ops.device('/CPU:0'):
                const_a = constant_op.constant(four_a)
                gather_a = collective_ops.all_gather(
                    const_a, 2, group_key, instance_key)
            with ops.device('/CPU:1'):
                const_b = constant_op.constant(four_b)
                const_short = constant_op.constant(two)
                gather_b = collective_ops.all_gather(
                    const_b, 2, group_key, instance_key)
                gather_short = collective_ops.all_gather(
                    const_short, 2, group_key, instance_key)
            options = config_pb2.RunOptions()
            options.experimental.collective_graph_key = 1
            # Matching shapes: this run only has to complete.
            sess.run([gather_a, gather_b], options=options)
            with self.assertRaisesRegex(errors.InvalidArgumentError,
                                        'Shape mismatch'):
                sess.run([gather_a, gather_short], options=options)
def testCollectiveGatherPolymorphicShape(self):
    """Run all_gather on placeholders of unknown length with two feed sizes.

    The same pair of gather ops is executed twice in one session: once with
    8-element feeds and once with the 7-element tails of those feeds. Each
    device's result is compared with the concatenation of both feeds.
    """
    t0 = [0, 1, 2, 3, 4, 5, 6, 7]
    t1 = [10, 11, 12, 13, 14, 15, 16, 17]
    group_size = 2
    group_key = 1
    instance_key = 123
    session_config = config_pb2.ConfigProto(device_count={'CPU': group_size})
    with self.session(config=session_config) as sess:
        with ops.device('/CPU:0'):
            in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
            c0 = collective_ops.all_gather(
                in0, group_size, group_key, instance_key)
        with ops.device('/CPU:1'):
            in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
            c1 = collective_ops.all_gather(
                in1, group_size, group_key, instance_key)

        # (feed for in0, feed for in1, expected gathered output)
        scenarios = (
            (t0, t1,
             [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]),
            (t0[1:], t1[1:],
             [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17]),
        )
        for feed0, feed1, expected_output in scenarios:
            per_device = sess.run([c0, c1],
                                  feed_dict={in0: feed0, in1: feed1})
            self.assertAllClose(
                per_device[0], expected_output, rtol=1e-5, atol=1e-5)
            self.assertAllClose(
                per_device[1], expected_output, rtol=1e-5, atol=1e-5)
def testCollectiveGatherShapeMismatch(self):
    """Verify that all_gather reports inconsistent output shapes.

    Three all_gather ops sharing one group key and one instance key are
    placed on two CPU devices. The pair fed by the two 4-element constants
    is run first; the pair fed by a 4-element and a 2-element constant is
    then expected to raise `InternalError` matching
    'Inconsistent output shapes'.
    """
    group_key = 1
    instance_key = 1
    t0 = [1, 2, 3, 4]
    t1 = [5, 6, 7, 8]
    t2 = [9, 10]
    with self.session(
        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
        with ops.device('/CPU:0'):
            in0 = constant_op.constant(t0)
            colred0 = collective_ops.all_gather(
                in0, 2, group_key, instance_key)
        with ops.device('/CPU:1'):
            in1 = constant_op.constant(t1)
            in2 = constant_op.constant(t2)
            colred1 = collective_ops.all_gather(
                in1, 2, group_key, instance_key)
            colred2 = collective_ops.all_gather(
                in2, 2, group_key, instance_key)
        run_options = config_pb2.RunOptions()
        run_options.experimental.collective_graph_key = 1
        # Matching shapes: this run only has to complete.
        sess.run([colred0, colred1], options=run_options)
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(errors.InternalError,
                                    'Inconsistent output shapes'):
            sess.run([colred0, colred2], options=run_options)
def all_gather():
    """Use all_gather to aggregate `IndexedSlices`.

    Gathers the values and the indices of `input_slices` with separate
    collectives and returns them as a new `IndexedSlices` that keeps the
    original dense shape. `input_slices`, the group parameters, the two
    instance keys and `communication_hint` come from the enclosing scope.
    """
    gathered_values = collective_ops.all_gather(
        input_slices.values, group_size, group_key, gather_values_key,
        communication_hint)
    # Add control dependency to order the all-gather: with the 'NCCL' hint
    # the indices gather is made to wait for the values gather.
    if communication_hint == 'NCCL':
        ordering_deps = [gathered_values]
    else:
        ordering_deps = []
    with ops.control_dependencies(ordering_deps):
        gathered_indices = collective_ops.all_gather(
            input_slices.indices, group_size, group_key, gather_indices_key,
            communication_hint)
    return ops.IndexedSlices(
        values=gathered_values,
        indices=gathered_indices,
        dense_shape=input_slices.dense_shape)
def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
    """All-gather a dense tensor.

    This can be called in eager mode if an async executor is supplied when
    creating the launcher.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered tensor.
    """
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with self._executor_scope(), ops.device(self._device):
        if not self._use_collective_v2():
            # The v1 op takes no ordering token.
            return collective_ops.all_gather(
                input_tensor,
                self._group_size,
                self._group_key,
                instance_key,
                communication_hint=communication_hint,
                timeout=timeout)
        return collective_ops.all_gather_v2(
            input_tensor,
            self._group_size,
            self._group_key,
            instance_key,
            communication_hint=communication_hint,
            timeout=timeout,
            ordering_token=ordering_token)
def testBasicNcclAllGather(self):
    """All-gather two float vectors across GPUs and compare every result.

    One all_gather op is built per input row, each on its own '/GPU:i'
    device. The test is skipped when no CUDA GPU is available.
    """
    inputs = [[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
              [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]]
    expected = [
        0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1,
        0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3
    ]
    group_size = len(inputs)
    group_key = 1
    instance_key = 1
    devices = ['/GPU:{}'.format(i) for i in range(group_size)]

    with self.session(config=self._configure(group_size)) as sess:
        if not test_util.is_gpu_available(cuda_only=True):
            self.skipTest('No GPU available')
        gather_ops = []
        # `devices` holds exactly one entry per input row, so pairing them
        # visits the same (device, row) combinations as indexing would.
        for device, row in zip(devices, inputs):
            with ops.device(device):
                local_input = constant_op.constant(row)
                gather_ops.append(
                    collective_ops.all_gather(
                        local_input, group_size, group_key, instance_key))
        gathered = sess.run(gather_ops)
    for per_device in gathered:
        self.assertAllClose(per_device, expected, rtol=1e-5, atol=1e-5)
def _collect_sparse_gradients(self, graph_item, var_op_name):
    """Append collective ops after the gradient is calculated.

    For each replica, the indices and the values of the gradient recorded
    for the replica-prefixed `var_op_name` are each passed through
    `collective_ops.all_gather`, and the data consumers and control
    consumers of the original tensors/ops are redirected to the gathered
    results.

    Args:
        graph_item: object providing `var_op_name_to_grad_info_v2`; its
            `updated` attribute is set to True for each processed replica.
        var_op_name: variable op name without the replica prefix.

    Raises:
        NotImplementedError: if there is more than one worker and
            `ENV.AUTODIST_INTERNAL_TF.value` is falsy.
        ValueError: if `num_replicas * num_workers` is not greater than 1.
    """
    if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
        # The two fragments are concatenated, so the first one must end with
        # a space to keep the sentences apart.
        raise NotImplementedError(
            'Currently the collective NCCL AllGather is not supported in '
            'TensorFlow release. '
            'Please choose another strategy.')
    conf = {}
    if self._spec:
        conf = {'communication_hint': self._spec}
    if self._compressor_type:
        logging.warning(
            'AllGather currently does not support AutoDist compressor so it '
            'skips.')
    group_size = self.num_replicas * self.num_workers
    if group_size <= 1:
        raise ValueError('CollectiveOps requires collective group size > 1')
    for i in range(0, self.num_replicas):
        op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
        # TODO (Tairui): (3) Merge of reduction for performance
        # Consumers are captured before the gather ops below are created;
        # those ops take the gradient tensors as inputs themselves.
        indices_c_ops = grad.indices.consumers()
        indices_cc_ops = get_control_consumers(grad.indices.op)
        values_c_ops = grad.values.consumers()
        values_cc_ops = get_control_consumers(grad.values.op)
        with ops.name_scope(replica_prefix(i)):
            with ops.colocate_with(grad.indices.op):
                new_indices = collective_ops.all_gather(
                    grad.indices,
                    group_size,
                    get_collective_keys().get_group_key(
                        self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(
                        var_op_name + '-indices'),
                    **conf)
            with ops.colocate_with(grad.values.op):
                new_values = collective_ops.all_gather(
                    grad.values,
                    group_size,
                    get_collective_keys().get_group_key(
                        self.all_canonical_replica_devices),
                    get_collective_keys().get_instance_key(
                        var_op_name + '-values'),
                    **conf)
        update_consumers(indices_c_ops, grad.indices, new_indices)
        update_control_consumers(indices_cc_ops, grad.indices.op,
                                 new_indices.op)
        update_consumers(values_c_ops, grad.values, new_values)
        # Control edges connect operations, so pass the op that produces the
        # gathered values (the indices call above does the same).
        update_control_consumers(values_cc_ops, grad.values.op,
                                 new_values.op)
def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
    """Gather `t0` and `t1` across two CPU devices and check both outputs.

    Args:
      t0: value wrapped in a constant on '/CPU:0'.
      t1: value wrapped in a constant on '/CPU:1'.
      expected: value each device's gathered output is compared against.
      set_graph_key: if truthy, `collective_graph_key` is set to 1 in the
        run options.
    """
    group_key = 1
    instance_key = 1
    session_config = config_pb2.ConfigProto(device_count={'CPU': 2})
    with self.session(config=session_config) as sess:
        with ops.device('/CPU:0'):
            first_input = constant_op.constant(t0)
            first_gather = collective_ops.all_gather(
                first_input, 2, group_key, instance_key)
        with ops.device('/CPU:1'):
            second_input = constant_op.constant(t1)
            second_gather = collective_ops.all_gather(
                second_input, 2, group_key, instance_key)
        options = config_pb2.RunOptions()
        if set_graph_key:
            options.experimental.collective_graph_key = 1
        outputs = sess.run([first_gather, second_gather], options=options)
        self.assertAllClose(outputs[0], expected, rtol=1e-5, atol=1e-5)
        self.assertAllClose(outputs[1], expected, rtol=1e-5, atol=1e-5)
def run_basic_nccl_all_gather():
    """Build one all_gather op per device and return them in device order.

    Reads `self`, `inputs`, `group_key` and `instance_key` from the
    enclosing scope.
    """
    gather_ops = []
    for device_index in range(self._group_size):
        with ops.device(self._devices[device_index]):
            local_input = constant_op.constant(inputs[device_index])
            gather_op = collective_ops.all_gather(
                local_input, self._group_size, group_key, instance_key)
            gather_ops.append(gather_op)
    return gather_ops
def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
    """Run a two-device CPU all_gather and compare each output to `expected`.

    Args:
      t0: value turned into a constant on '/CPU:0'.
      t1: value turned into a constant on '/CPU:1'.
      expected: reference value for the gathered output of both devices.
      set_graph_key: when truthy, the run options carry
        `collective_graph_key = 1`.
    """
    group_key = 1
    instance_key = 1
    two_cpus = config_pb2.ConfigProto(device_count={'CPU': 2})
    with self.session(config=two_cpus) as sess:
        with ops.device('/CPU:0'):
            gather_cpu0 = collective_ops.all_gather(
                constant_op.constant(t0), 2, group_key, instance_key)
        with ops.device('/CPU:1'):
            gather_cpu1 = collective_ops.all_gather(
                constant_op.constant(t1), 2, group_key, instance_key)
        run_options = config_pb2.RunOptions()
        if set_graph_key:
            run_options.experimental.collective_graph_key = 1
        out_cpu0, out_cpu1 = sess.run([gather_cpu0, gather_cpu1],
                                      options=run_options)
        self.assertAllClose(out_cpu0, expected, rtol=1e-5, atol=1e-5)
        self.assertAllClose(out_cpu1, expected, rtol=1e-5, atol=1e-5)
def testCollectiveGatherShapeMismatchAcrossDevices(self):
    """Verify that gathering differently shaped inputs across devices fails.

    A 4-element constant on '/CPU:0' and a 2-element constant on '/CPU:1'
    are gathered with the same group key and instance key; running both ops
    is expected to raise `InvalidArgumentError` matching 'Shape mismatch'.
    """
    group_key = 1
    instance_key = 1
    t0 = [1, 2, 3, 4]
    t1 = [5, 6]
    with self.session(
        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
        with ops.device('/CPU:0'):
            in0 = constant_op.constant(t0)
            c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
        with ops.device('/CPU:1'):
            in1 = constant_op.constant(t1)
            c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
        run_options = config_pb2.RunOptions()
        run_options.experimental.collective_graph_key = 1
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'Shape mismatch'):
            sess.run([c0, c1], options=run_options)
def collective_all_gather():
    """Call collective allgather.

    Creates one all_gather op per device index and returns the resulting
    tensors in index order. `num_devices`, `devices`, `input_tensors` and the
    group parameters are read from the enclosing scope. Asserts that eager
    execution is not active.
    """
    assert not context.executing_eagerly()
    gathered = []
    for device_index in range(num_devices):
        with ops.device(devices[device_index]):
            gathered.append(
                collective_ops.all_gather(
                    input_tensors[device_index], group_size, group_key,
                    instance_key))
    return gathered
def build_collective_gather(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            communication_hint='AUTO',
                            control_inputs=None,
                            timeout=None):
    """Build a subgraph that does one full all-gather, using the collective Op.

    This method must be called in graph mode or inside a tf.function.

    Args:
      input_tensors: tensors within a single worker graph that are to be
        gathered together; must be one per device.
      devices: a list of device strings to run the collective on.
      group_size: total number of devices globally that will be doing this
        same gathering. The gathering will actually include the corresponding
        tensors at all these workers.
      collective_keys: a CollectiveKeys object.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      control_inputs: if not None, add control edges between control_inputs
        and (index-wise) corresponding collective_gather tensors
      timeout: a float or None. The timeout in seconds.

    Returns:
      An array of final tensors, one per device, computed by the full gather.
      When `group_size` is below 2, `input_tensors` is returned unchanged.

    Raises:
      ValueError: if the number of input tensors differs from the number of
        devices.
    """
    assert not context.executing_eagerly(), (
        'build_collective_gather can only be called in graph mode or inside '
        'tf.function')

    num_inputs = len(input_tensors)
    num_devices = len(devices)
    if num_inputs != num_devices:
        raise ValueError(
            'collective requires one input tensor for each device, %d != %d' %
            (num_inputs, num_devices))

    # Nothing to gather for a group with fewer than two members.
    if group_size < 2:
        return input_tensors

    group_key = collective_keys.get_group_key(devices)
    instance_key = collective_keys.get_op_instance_key()

    gathered = []
    for position, local_tensor in enumerate(input_tensors):
        with ops.device(devices[position]):
            deps = _control_input(devices, control_inputs, position)
            with ops.control_dependencies(deps):
                gathered.append(
                    collective_ops.all_gather(
                        local_tensor,
                        group_size,
                        group_key,
                        instance_key,
                        communication_hint,
                        timeout=timeout))
    return gathered
def testCollectiveGroupSizeOne(self):
    """Check all_reduce and all_gather with a group of size one.

    Both collectives are applied to the same constant, and each result's
    `.numpy()` value is compared with the original input list.
    """
    group_size = 1
    group_key = 100
    instance_key = 100
    in_value = [1, 2, 3, 4]
    source = constant_op.constant(in_value)

    reduced = collective_ops.all_reduce(
        source, group_size, group_key, instance_key, 'Add', 'Id')
    self.assertAllEqual(in_value, reduced.numpy())

    gathered = collective_ops.all_gather(
        source, group_size, group_key, instance_key)
    self.assertAllEqual(in_value, gathered.numpy())
def build_collective_gather(input_tensors,
                            num_workers,
                            collective_keys,
                            communication_hint='AUTO',
                            control_inputs=None):
    """Build a subgraph that does one full all-gather, using the collective Op.

    This method must be called in graph mode or inside a tf.function.

    Args:
      input_tensors: tensors within a single worker graph that are to be
        gathered together; must be one per device.
      num_workers: total number of workers with identical independent graphs
        that will be doing this same gathering. The group size used for the
        collective is `len(input_tensors) * num_workers`.
      collective_keys: a CollectiveKeys object.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      control_inputs: if not None, add control edges between control_inputs
        and (index-wise) corresponding collective_gather tensors

    Returns:
      An array of final tensors, one per device, computed by the full gather.
      When the group size is below 2, `input_tensors` is returned unchanged.
    """
    assert not context.executing_eagerly(), (
        'build_collective_gather can only be called in graph mode or inside '
        'tf.function')

    group_size = len(input_tensors) * num_workers
    # Nothing to gather for a group with fewer than two members.
    if group_size < 2:
        return input_tensors

    group_key = collective_keys.get_group_key_of_tensors(input_tensors)
    instance_key = collective_keys.get_op_instance_key()

    gathered = []
    for position, local_tensor in enumerate(input_tensors):
        # Each collective runs on the device of the tensor it gathers.
        with ops.device(local_tensor.device):
            deps = _control_input(input_tensors, control_inputs, position)
            with ops.control_dependencies(deps):
                gathered.append(
                    collective_ops.all_gather(
                        local_tensor, group_size, group_key, instance_key,
                        communication_hint))
    return gathered
def testCollectiveGroupSizeOne(self):
    """Check all_reduce and all_gather on '/GPU:0' with a group of size one.

    The all_reduce is issued with the 'nccl' communication hint; the
    all_gather uses the default hint. Each result's `.numpy()` value is
    compared with the original input list.
    """
    self._ensure_context_initialized()

    group_size = 1
    group_key = 100
    instance_key = 100
    in_value = [1., 2., 3., 4.]
    source = constant_op.constant(in_value)

    with ops.device('/GPU:0'):
        reduced = collective_ops.all_reduce(
            source,
            group_size,
            group_key,
            instance_key,
            'Add',
            'Id',
            communication_hint='nccl')
    self.assertAllEqual(in_value, reduced.numpy())

    with ops.device('/GPU:0'):
        gathered = collective_ops.all_gather(
            source, group_size, group_key, instance_key)
    self.assertAllEqual(in_value, gathered.numpy())
def all_reduce_indexed_slices(self,
                              input_slices,
                              communication_hint='AUTO',
                              timeout=0):
    """All-reduce an IndexedSlices.

    This method must be called inside a tf.function.

    Args:
      input_slices: an IndexedSlices.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      timeout: a float. The timeout in seconds.

    Returns:
      The reduced IndexedSlices.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
        raise RuntimeError(
            'all_reduce_indexed_slices in eager mode is not supported')

    def new_instance_key():
        return self._collective_keys.get_instance_key(self._group_key,
                                                      self._device)

    # Keys are requested in a fixed order: lengths, indices, values, dense.
    gather_length_key = new_instance_key()
    gather_indices_key = new_instance_key()
    gather_values_key = new_instance_key()
    reduce_densified_key = new_instance_key()

    # Current CollectiveAllGather implementations require input IndexedSlices
    # to have consistent length across the board, we handle the reduction of
    # IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    with ops.device(self._device):

        def gather_branch():
            """Use all_gather to aggregate `IndexedSlices`."""
            gathered_values = collective_ops.all_gather(
                input_slices.values,
                self._group_size,
                self._group_key,
                gather_values_key,
                communication_hint,
                timeout=timeout)
            # Add control dependency to order the all-gather.
            if communication_hint == 'NCCL':
                ordering_deps = [gathered_values]
            else:
                ordering_deps = []
            with ops.control_dependencies(ordering_deps):
                gathered_indices = collective_ops.all_gather(
                    input_slices.indices,
                    self._group_size,
                    self._group_key,
                    gather_indices_key,
                    communication_hint,
                    timeout=timeout)
            return ops.IndexedSlices(
                values=gathered_values,
                indices=gathered_indices,
                dense_shape=input_slices.dense_shape)

        def densify_branch():
            """Use all_reduce to aggregate `IndexedSlices`."""
            dense = ops.convert_to_tensor(input_slices)
            summed = collective_ops.all_reduce(
                dense,
                self._group_size,
                self._group_key,
                reduce_densified_key,
                'Add',
                'Id', [0],
                communication_hint,
                timeout=timeout)
            # We have to convert dense grad to IndexedSlice because
            # all_reduce() and all_gather() must have the same return type as
            # required by control_flow_ops.cond.
            row_ids = math_ops.range(array_ops.shape(summed)[0])
            return ops.IndexedSlices(
                values=summed,
                indices=row_ids,
                dense_shape=input_slices.dense_shape)

        local_length = array_ops.shape(input_slices.indices)
        all_lengths = collective_ops.all_gather(
            local_length,
            self._group_size,
            self._group_key,
            gather_length_key,
            communication_hint,
            timeout=timeout)
        # Equal max and min means every participant reported the same length.
        longest = math_ops.reduce_max(all_lengths)
        shortest = math_ops.reduce_min(all_lengths)
        return control_flow_ops.cond(
            math_ops.equal(longest, shortest), gather_branch, densify_branch)
def build_collective_gather_indexed_slices(input_slices_list,
                                           devices,
                                           group_size,
                                           collective_keys,
                                           communication_hint='AUTO',
                                           control_inputs=None):
    """Build a subgraph that all-gathers IndexedSlices using the collective Op.

    This method must be called in graph mode or inside a tf.function.

    Args:
      input_slices_list: a list of IndexedSlices within a single worker graph
        that are to be gathered together; must be one per device.
      devices: a list of device strings to run the collective on.
      group_size: total number of devices globally that will be doing this
        same gathering. The gathering will actually include the corresponding
        tensors at all these workers.
      collective_keys: a CollectiveKeys object.
      communication_hint: string providing hint to runtime for choosing
        collective implementation.
      control_inputs: if not None, add control edges between control_inputs
        and (index-wise) the collective that gathers the slice lengths.

    Returns:
      An array of final IndexedSlices, one per device, computed by the full
      gather. When `group_size` is below 2, `input_slices_list` is returned
      unchanged.

    Raises:
      ValueError: if the number of input slices differs from the number of
        devices, or if control_inputs is not None and doesn't match the
        length and devices of inputs.
    """
    assert not context.executing_eagerly(), (
        'build_collective_gather_indexed_slices can only be called in graph '
        'mode or inside tf.function')
    if len(input_slices_list) != len(devices):
        raise ValueError(
            'collective requires one input IndexedSlice for each device, '
            '%d != %d' % (len(input_slices_list), len(devices)))

    # Nothing to gather for a group with fewer than two members.
    if group_size < 2:
        return input_slices_list

    group_key = collective_keys.get_group_key(devices)
    gather_length_key = collective_keys.get_op_instance_key()
    gather_indices_key = collective_keys.get_op_instance_key()
    gather_values_key = collective_keys.get_op_instance_key()
    reduce_densified_key = collective_keys.get_op_instance_key()

    # Current CollectiveAllGather implementations require input IndexedSlices
    # to have consistent length across the board, we handle the reduction of
    # IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise convert IndexedSlices to dense tensors and apply
    #      all_reduce.
    out_slices_list = []
    for idx, input_slices in enumerate(input_slices_list):
        # pylint: disable = cell-var-from-loop
        with ops.device(devices[idx]):

            def all_gather():
                """Use all_gather to aggregate `IndexedSlices`."""
                all_values = collective_ops.all_gather(
                    input_slices.values, group_size, group_key,
                    gather_values_key, communication_hint)
                # Add control dependency to order the all-gather.
                control = [all_values] if communication_hint == 'NCCL' else []
                with ops.control_dependencies(control):
                    all_indices = collective_ops.all_gather(
                        input_slices.indices, group_size, group_key,
                        gather_indices_key, communication_hint)
                return ops.IndexedSlices(
                    values=all_values,
                    indices=all_indices,
                    dense_shape=input_slices.dense_shape)

            def densify_and_all_reduce():
                """Use all_reduce to aggregate `IndexedSlices`."""
                densified = ops.convert_to_tensor(input_slices)
                reduced = collective_ops.all_reduce(
                    densified, group_size, group_key, reduce_densified_key,
                    'Add', 'Id', [0], communication_hint)
                # We have to convert dense grad to IndexedSlice because
                # all_reduce() and all_gather() must have the same return
                # type as required by control_flow_ops.cond.
                return ops.IndexedSlices(
                    values=reduced,
                    indices=math_ops.range(array_ops.shape(reduced)[0]),
                    dense_shape=input_slices.dense_shape)

            length = array_ops.shape(input_slices.indices)
            # Pass the per-device list, as the dense build_collective_gather
            # does, rather than the single IndexedSlices of this iteration.
            # NOTE(review): assumes _control_input takes the list that
            # control_inputs is matched against - confirm at its definition.
            with ops.control_dependencies(
                _control_input(devices, control_inputs, idx)):
                all_lengths = collective_ops.all_gather(
                    length, group_size, group_key, gather_length_key,
                    communication_hint)
            # Equal max and min means every participant reported the same
            # length, so the gather path can be used.
            out_slices = control_flow_ops.cond(
                math_ops.equal(
                    math_ops.reduce_max(all_lengths),
                    math_ops.reduce_min(all_lengths)),
                all_gather, densify_and_all_reduce)
            out_slices_list.append(out_slices)
        # pylint: enable=cell-var-from-loop
    return out_slices_list
def all_gather(self,
               input_tensor,
               axis,
               communication_hint='AUTO',
               timeout=0):
    """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all
        replicas, and dimensions other than `axis` need to be the same as
        well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      timeout: a float. The timeout in seconds.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
        raise RuntimeError('all_gather in eager mode is not supported')

    # The key for the data gather is requested before the key for the shape
    # gather.
    instance_key_tensor = self._collective_keys.get_instance_key(
        self._group_key, self._device)
    instance_key_shape = self._collective_keys.get_instance_key(
        self._group_key, self._device)

    with ops.device(self._device):
        # 1. Transpose
        # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather
        # is 3, we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
        # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
        # place it back.
        dims_before_axis = math_ops.range(axis)
        dims_after_axis = math_ops.range(axis + 1,
                                         array_ops.rank(input_tensor))
        perm_pre = array_ops.concat(
            ([axis], dims_before_axis, dims_after_axis), axis=0)
        axis_first = array_ops.transpose(input_tensor, perm=perm_pre)

        # 2. Pad
        local_shape_row = array_ops.expand_dims_v2(
            array_ops.shape_v2(axis_first), axis=0)
        gathered_shape = collective_ops.all_gather(
            local_shape_row,
            self._group_size,
            self._group_key,
            instance_key_shape,
            communication_hint,
            timeout=timeout)
        first_dims = gathered_shape[:, 0]
        full_axis_dim = math_ops.reduce_max(first_dims)
        padded_local = _pad_util(axis_first, full_axis_dim)

        # 3. Gather
        padded_gathered = collective_ops.all_gather(
            padded_local,
            self._group_size,
            self._group_key,
            instance_key_tensor,
            communication_hint,
            timeout=timeout)

        # 4. Unpad
        def unpadded_piece(member):
            # Each member owns a `full_axis_dim`-sized stripe; only its first
            # `first_dims[member]` rows are kept.
            start_pos = member * full_axis_dim
            return padded_gathered[start_pos:start_pos + first_dims[member]]

        pieces = [
            unpadded_piece(member) for member in range(first_dims.shape[0])
        ]
        gathered_axis_first = array_ops.concat(pieces, 0)

        # 5. Transpose back
        restored_leading = math_ops.range(1, axis + 1)
        restored_trailing = math_ops.range(axis + 1,
                                           array_ops.rank(axis_first))
        perm_after = array_ops.concat(
            (restored_leading, [0], restored_trailing), axis=0)
        return array_ops.transpose(gathered_axis_first, perm=perm_after)
def build_collective_gather(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            axis,
                            communication_hint='AUTO',
                            control_inputs=None,
                            timeout=None):
    """Build a subgraph that does one full all-gather, using the collective Op.

    This method must be called in graph mode or inside a tf.function.

    Args:
      input_tensors: tensors within a single worker graph that are to be
        gathered together; must be one per device. Input tensors cannot have
        rank 0.
      devices: a list of device strings to run the collective on.
      group_size: total number of devices globally that will be doing this
        same gathering. The gathering will actually include the corresponding
        tensors at all these workers.
      collective_keys: a CollectiveKeys object.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      communication_hint: string providing hint to runtime for choosing
        collective implementation. Available options are `AUTO`, `NCCL`, and
        `RING`.
      control_inputs: if not None, add control edges between control_inputs
        and (index-wise) corresponding collective_gather tensors
      timeout: a float or None. The timeout in seconds.

    Returns:
      An array of final tensors, one per device, computed by the full gather.
      When `group_size` is below 2, `input_tensors` is returned unchanged.

    Raises:
      ValueError: if the number of input tensors differs from the number of
        devices.
    """
    num_inputs = len(input_tensors)
    num_devices = len(devices)
    if num_inputs != num_devices:
        raise ValueError(
            'collective requires one input tensor for each device, %d != %d' %
            (num_inputs, num_devices))

    # Nothing to gather for a group with fewer than two members.
    if group_size < 2:
        return input_tensors

    group_key = collective_keys.get_group_key(devices)
    # The key for the data gather is requested before the key for the shape
    # gather; both are shared by every device in the loop below.
    instance_key_tensor = collective_keys.get_op_instance_key()
    instance_key_shape = collective_keys.get_op_instance_key()

    out_tensors = []
    for idx, local_tensor in enumerate(input_tensors):
        with ops.device(devices[idx]):
            with ops.control_dependencies(
                _control_input(devices, control_inputs, idx)):
                # 1. Transpose
                # E.g. Given an input_tensor with shape [2,2,5,1] and axis to
                # gather is 3, we use perm_pre=[3 0 1 2] to reshape it to
                # [1,2,2,5], which brings the 3rd dim first; afterwards we
                # use perm_after=[1,2,3,0] to place it back.
                dims_before_axis = math_ops.range(axis)
                dims_after_axis = math_ops.range(
                    axis + 1, array_ops.rank(local_tensor))
                perm_pre = array_ops.concat(
                    ([axis], dims_before_axis, dims_after_axis), axis=0)
                axis_first = array_ops.transpose(local_tensor, perm=perm_pre)

                # 2. Pad
                local_shape_row = array_ops.expand_dims_v2(
                    array_ops.shape_v2(axis_first), axis=0)
                gathered_shape = collective_ops.all_gather(
                    local_shape_row,
                    group_size,
                    group_key,
                    instance_key_shape,
                    communication_hint,
                    timeout=timeout)
                first_dims = gathered_shape[:, 0]
                full_axis_dim = math_ops.reduce_max(first_dims)
                padded_local = _pad_util(axis_first, full_axis_dim)

                # 3. Gather
                padded_gathered = collective_ops.all_gather(
                    padded_local,
                    group_size,
                    group_key,
                    instance_key_tensor,
                    communication_hint,
                    timeout=timeout)

                # 4. Unpad
                pieces = []
                for member in range(first_dims.shape[0]):
                    # Each member owns a `full_axis_dim`-sized stripe; only
                    # its first `first_dims[member]` rows are kept.
                    start_pos = member * full_axis_dim
                    stop_pos = start_pos + first_dims[member]
                    pieces.append(padded_gathered[start_pos:stop_pos])
                gathered_axis_first = array_ops.concat(pieces, 0)

                # 5. Transpose back
                restored_leading = math_ops.range(1, axis + 1)
                restored_trailing = math_ops.range(
                    axis + 1, array_ops.rank(axis_first))
                perm_after = array_ops.concat(
                    (restored_leading, [0], restored_trailing), axis=0)
                out_tensors.append(
                    array_ops.transpose(gathered_axis_first, perm=perm_after))
    return out_tensors