def _calculate_mean_and_var(self, x, axes, keep_dims):
  with backend.name_scope('moments'):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16.
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    replica_ctx = ds.get_replica_context()
    if replica_ctx:
      local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
      local_squared_sum = math_ops.reduce_sum(
          math_ops.square(y), axis=axes, keepdims=True)
      batch_size = math_ops.cast(
          array_ops.shape_v2(y)[axes[0]], dtypes.float32)
      # TODO(b/163099951): batch the all-reduces once we sort out the ordering
      # issue for NCCL. We don't have a mechanism to launch NCCL in the same
      # order in each replica nowadays, so we limit NCCL to batch all-reduces.
      y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
      y_squared_sum = replica_ctx.all_reduce(
          reduce_util.ReduceOp.SUM, local_squared_sum)
      global_batch_size = replica_ctx.all_reduce(
          reduce_util.ReduceOp.SUM, batch_size)

      axes_vals = [(array_ops.shape_v2(y))[axes[i]]
                   for i in range(1, len(axes))]
      multiplier = math_ops.cast(
          math_ops.reduce_prod(axes_vals), dtypes.float32)
      multiplier = multiplier * global_batch_size

      mean = y_sum / multiplier
      y_squared_mean = y_squared_sum / multiplier
      # var = E(x^2) - E(x)^2
      variance = y_squared_mean - math_ops.square(mean)
    else:
      # Compute true mean while keeping the dims for proper broadcasting.
      mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
      # Sample variance, not unbiased variance.
      # Note: stop_gradient does not change the gradient that gets
      # backpropagated to the mean from the variance calculation,
      # because that gradient is zero.
      variance = math_ops.reduce_mean(
          math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
          axes,
          keepdims=True,
          name='variance')
    if not keep_dims:
      mean = array_ops.squeeze(mean, axes)
      variance = array_ops.squeeze(variance, axes)
    if x.dtype == dtypes.float16:
      return (math_ops.cast(mean, dtypes.float16),
              math_ops.cast(variance, dtypes.float16))
    else:
      return (mean, variance)
def _calculate_mean_and_var(self, x, axes, keep_dims):
  with ops.name_scope('moments', values=[x, axes]):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16.
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    if horovod_enabled():
      num_shards = hvd.size()
    else:
      num_shards = 1

    if num_shards > 1:
      local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
      local_squared_sum = math_ops.reduce_sum(
          math_ops.square(y), axis=axes, keepdims=True)
      batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
      # y_sum, y_squared_sum, global_batch_size = (
      #     replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
      #                            [local_sum, local_squared_sum, batch_size]))
      # hvd_info(f'local_sum {local_sum.shape}, '
      #          f'local_squared_sum {local_squared_sum.shape}')

      y_sum = hvd.allreduce(local_sum, average=False)
      y_squared_sum = hvd.allreduce(local_squared_sum, average=False)
      global_batch_size = batch_size * num_shards

      axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
      multiplier = math_ops.cast(
          math_ops.reduce_prod(axes_vals), dtypes.float32)
      multiplier = multiplier * global_batch_size

      mean = y_sum / multiplier
      y_squared_mean = y_squared_sum / multiplier
      # var = E(x^2) - E(x)^2
      variance = y_squared_mean - math_ops.square(mean)
    else:
      # Compute true mean while keeping the dims for proper broadcasting.
      mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
      # Sample variance, not unbiased variance.
      # Note: stop_gradient does not change the gradient that gets
      # backpropagated to the mean from the variance calculation,
      # because that gradient is zero.
      variance = math_ops.reduce_mean(
          math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
          axes,
          keepdims=True,
          name='variance')
    if not keep_dims:
      mean = array_ops.squeeze(mean, axes)
      variance = array_ops.squeeze(variance, axes)
    if x.dtype == dtypes.float16:
      return (math_ops.cast(mean, dtypes.float16),
              math_ops.cast(variance, dtypes.float16))
    else:
      return (mean, variance)
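# Illustrative sketch (not part of the original sources): the distributed
# branches above only exchange per-replica sums, squared sums and batch sizes,
# relying on var = E(x^2) - E(x)^2. The NumPy check below verifies that this
# bookkeeping reproduces the full-batch moments; the two "shards" are plain
# array slices, no Horovod or tf.distribute involved.
import numpy as np

shard_a = np.random.rand(3, 4)  # local batch of replica 0
shard_b = np.random.rand(5, 4)  # local batch of replica 1

local_sums = [s.sum(axis=0) for s in (shard_a, shard_b)]
local_sq_sums = [(s ** 2).sum(axis=0) for s in (shard_a, shard_b)]
global_batch = sum(s.shape[0] for s in (shard_a, shard_b))

y_sum = sum(local_sums)            # stands in for the SUM all-reduce
y_squared_sum = sum(local_sq_sums)

mean = y_sum / global_batch
variance = y_squared_sum / global_batch - mean ** 2  # var = E(x^2) - E(x)^2

full = np.concatenate([shard_a, shard_b], axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(variance, full.var(axis=0))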
def call(self, inputs):
  if self._channels_first:
    rank = inputs.shape.rank
    if rank and rank > 1:
      # Switch to channels-last format.
      permutation = [0]
      permutation.extend(range(2, rank))
      permutation.append(1)
      inputs = array_ops.transpose(inputs, perm=permutation)

  if context.executing_eagerly():
    # Full static shape is guaranteed to be available.
    # Performance: Using `constant_op` is much faster than passing a list.
    flattened_shape = constant_op.constant([inputs.shape[0], -1])
    return gen_array_ops.reshape(inputs, flattened_shape)
  else:
    input_shape = inputs.shape
    rank = input_shape.rank
    if rank == 1:
      return array_ops.expand_dims_v2(inputs, axis=1)
    else:
      batch_dim = tensor_shape.dimension_value(input_shape[0])
      non_batch_dims = input_shape[1:]
      # Reshape in a way that preserves as much shape info as possible.
      if non_batch_dims.is_fully_defined():
        last_dim = int(functools.reduce(operator.mul, non_batch_dims))
        flattened_shape = constant_op.constant([-1, last_dim])
      elif batch_dim is not None:
        flattened_shape = constant_op.constant([int(batch_dim), -1])
      else:
        flattened_shape = [array_ops.shape_v2(inputs)[0], -1]
      return array_ops.reshape(inputs, flattened_shape)
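# Illustrative sketch (assumption, not from the original file): the
# channels-first branch builds the permutation [0, 2, ..., rank-1, 1] so the
# channel axis moves to the end before flattening, i.e. the flattened features
# come out in channels-last order. A tiny NumPy illustration:
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # NCHW: batch=2, C=3, H=4, W=5
perm = [0] + list(range(2, x.ndim)) + [1]         # -> [0, 2, 3, 1] (NHWC)
flat = np.transpose(x, perm).reshape(x.shape[0], -1)
print(flat.shape)  # (2, 60)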
def call(self, inputs):
  bins = [math_ops.cast(array_ops.squeeze(self.bins), dtypes.float32)]

  def _bucketize_fn(inputs):
    return gen_boosted_trees_ops.BoostedTreesBucketize(
        float_values=[math_ops.cast(inputs, dtypes.float32)],
        bucket_boundaries=bins)[0]

  if tf_utils.is_ragged(inputs):
    integer_buckets = ragged_functional_ops.map_flat_values(
        _bucketize_fn, inputs)
    # Ragged map_flat_values doesn't touch the non-values tensors in the
    # ragged composite tensor. If this op is the only op in a Keras model,
    # this can cause errors in Graph mode, so wrap the tensor in an identity.
    return array_ops.identity(integer_buckets)
  elif isinstance(inputs, sparse_tensor.SparseTensor):
    return sparse_tensor.SparseTensor(
        indices=array_ops.identity(inputs.indices),
        values=_bucketize_fn(inputs.values),
        dense_shape=array_ops.identity(inputs.dense_shape))
  else:
    static_shape = inputs.get_shape()
    if any(dim is None for dim in static_shape.as_list()[1:]):
      raise NotImplementedError(
          "Discretization Layer requires known non-batch shape, "
          "found {}".format(static_shape))
    dynamic_shape = array_ops.shape_v2(inputs)
    # BoostedTreesBucketize only handles rank 1 inputs. We need to flatten our
    # inputs after batch size and vectorized_map over each sample.
    reshaped = array_ops.reshape(inputs, [dynamic_shape[0], -1])
    return array_ops.reshape(
        control_flow_ops.vectorized_map(_bucketize_fn, reshaped),
        dynamic_shape)
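# Illustrative sketch (assumption, not the layer's code path): bucketization
# against fixed boundaries is ordinary interval lookup. np.digitize with the
# default right=False assigns each value the count of boundaries that are
# <= the value, which is the general behaviour the layer exposes; how the raw
# BoostedTreesBucketize op treats values exactly on a boundary is not asserted
# here.
import numpy as np

bins = [0.0, 1.0, 2.0]
values = np.array([-1.5, 0.0, 0.5, 1.0, 3.4])
print(np.digitize(values, bins))  # [0 1 1 2 3]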
def _pad_util(input_tensor, full_axis_dim):
  """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
  missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  tensor_rank = array_ops.rank(input_tensor)
  paddings_axis = [[0, missing_axis_dim]]
  paddings = array_ops.concat(
      [paddings_axis,
       array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)],
      axis=0)
  padded_input_tensor = array_ops.pad(input_tensor, paddings)
  return padded_input_tensor
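# Illustrative sketch (assumption, not from the original module): the same
# padding written against the public tf API, growing a rank-2 tensor from 2
# rows up to 5 along the leading dimension only.
import tensorflow as tf

def pad_first_dim(t, full_axis_dim):
  missing = full_axis_dim - tf.shape(t)[0]
  paddings = tf.concat(
      [[[0, missing]],
       tf.zeros((tf.rank(t) - 1, 2), dtype=tf.int32)], axis=0)
  return tf.pad(t, paddings)

x = tf.reshape(tf.range(6, dtype=tf.float32), (2, 3))
print(pad_first_dim(x, 5).shape)  # (5, 3); rows 2..4 are zeros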
def test_vocab_file_sparse_input(self):
  layer = categorical.CategoryLookup(
      vocabulary=self._wire_vocabulary_file_name, num_oov_tokens=0)
  inp = np.asarray([['omar', ''], ['stringer', 'marlo'], ['marlo', 'omar']])
  indices = array_ops.where_v2(math_ops.not_equal(inp, ''))
  sp_inp = sparse_tensor.SparseTensor(
      indices,
      array_ops.gather_nd_v2(inp, indices),
      dense_shape=array_ops.shape_v2(inp, out_type=dtypes.int64))
  output = layer(sp_inp)
  self.assertIsInstance(output, sparse_tensor.SparseTensor)
  self.assertAllClose(
      np.asarray([[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]), output.indices)
  self.assertAllClose(np.asarray([0, 1, 2, 2, 0]), output.values)
def test_vocab_list_sparse_input(self):
  vocabulary_list = ['A', 'B', 'C', 'D', 'E']
  layer = categorical.CategoryLookup(
      vocabulary=vocabulary_list, num_oov_tokens=0)
  inp = np.asarray([['A', ''], ['E', 'C'], ['D', 'A']])
  indices = array_ops.where_v2(math_ops.not_equal(inp, ''))
  sp_inp = sparse_tensor.SparseTensor(
      indices,
      array_ops.gather_nd_v2(inp, indices),
      dense_shape=array_ops.shape_v2(inp, out_type=dtypes.int64))
  output = layer(sp_inp)
  self.assertIsInstance(output, sparse_tensor.SparseTensor)
  self.assertAllClose(
      np.asarray([[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]), output.indices)
  self.assertAllClose(np.asarray([0, 4, 2, 3, 0]), output.values)
def slice_batch(index, batch):
  flattened_batch = nest.flatten(batch)
  flattened_output = []

  norm_index = math_ops.cast(
      index % self._num_local_devices_per_replica, dtype=dtypes.int32)
  norm_index += self._partition_offset
  coords = self._mesh.coords(norm_index)
  coords = array_ops.reshape(coords, (1, -1))

  for element, shard_counts, idx_matrix in zip(flattened_batch,
                                               self._all_shard_counts,
                                               self._index_matrices):
    indexes = math_ops.matmul(coords, idx_matrix)
    start = array_ops.reshape(indexes, (-1,))
    size = array_ops.shape_v2(element, out_type=dtypes.int32) // shard_counts
    flattened_output.append(
        array_ops.slice(element, begin=start, size=size))

  return nest.pack_sequence_as(batch, flattened_output)
def _subdiv_calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
  # Calculate the sum and the sum of squares over the reduction axes.
  net_sum = math_ops.reduce_sum(
      inputs, axis=reduction_axes, keepdims=keep_dims)
  squared_mean = math_ops.reduce_sum(
      math_ops.square(inputs), axis=reduction_axes, keepdims=keep_dims)

  if self._support_zero_size_input():
    # Keras assumes that batch dimension is the first dimension for Batch
    # Normalization.
    input_batch_size = array_ops.shape(inputs)[0]
  else:
    input_batch_size = None

  # Get the number of elements being averaged over per example; the (local)
  # batch size is handled separately below.
  axes_vals = [(array_ops.shape_v2(inputs))[i]
               for i in range(1, len(reduction_axes))]
  multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), dtypes.float32)
  squared_mean = squared_mean / multiplier
  net_sum = net_sum / multiplier

  if input_batch_size is None:
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
  else:
    batches_ = math_ops.cast(input_batch_size, self._param_dtype)
    mean = net_sum / batches_
    variance = squared_mean / batches_ - math_ops.square(
        array_ops.stop_gradient(mean))

  return mean, net_sum, variance, squared_mean, input_batch_size
def model(in_tensor):
  shape = array_ops.shape_v2(in_tensor)
  fill = array_ops.transpose_v2(array_ops.fill(shape, 1.))
  return math_ops.matmul(fill, in_tensor)
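# Illustrative sketch (the tf.function/TensorSpec wrapping is an assumption,
# not part of the original snippet): `fill` is an all-ones tensor with the
# input's shape, so transpose(fill) @ in_tensor yields a square matrix whose
# rows all equal the column sums of `in_tensor`. Tracing with an unknown batch
# dimension keeps the shape op dynamic.
import tensorflow as tf

def model(in_tensor):
  shape = tf.shape(in_tensor)
  fill = tf.transpose(tf.fill(shape, 1.))
  return tf.matmul(fill, in_tensor)

concrete = tf.function(model).get_concrete_function(
    tf.TensorSpec([None, 3], tf.float32))
print(concrete(tf.ones([4, 3])))  # 3x3 matrix, every entry is 4.0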
def all_gather(self,
               input_tensor,
               axis,
               communication_hint='AUTO',
               timeout=0):
  """All-gather a dense tensor.

  This method must be called inside a tf.function.

  Args:
    input_tensor: a dense tensor. It must have the same rank on all replicas,
      and dimensions other than `axis` need to be the same as well.
    axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
      range [0, rank(value)).
    communication_hint: string providing hint to runtime for choosing
      collective implementation. Available options are `AUTO`, `NCCL`, and
      `RING`.
    timeout: a float. The timeout in seconds.

  Returns:
    The gathered Tensor.

  Raises:
    RuntimeError: if called in eager mode.
  """
  if context.executing_eagerly():
    raise RuntimeError('all_gather in eager mode is not supported')

  with ops.device(self._device), \
       ops.control_dependencies([array_ops.identity(input_tensor)]):
    # 1. Transpose
    # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
    # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
    # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
    # place it back.
    perm_pre = array_ops.concat(
        ([axis], math_ops.range(axis),
         math_ops.range(axis + 1, array_ops.rank(input_tensor))),
        axis=0)
    input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
    # 2. Pad
    gathered_shape = self._all_gather(
        array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
        communication_hint,
        timeout=timeout)
    first_dims = gathered_shape[:, 0]
    full_axis_dim = math_ops.reduce_max(first_dims)
    padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)
    # 3. Gather
    gather_padded_out_tensor = self._all_gather(
        padded_input_tensor, communication_hint, timeout=timeout)
    # 4. Unpad
    split_tensors = []
    for i in range(self._group_size):
      start_pos = i * full_axis_dim
      split_tensors.append(
          gather_padded_out_tensor[start_pos:start_pos + first_dims[i]])
    out_tensor_t = array_ops.concat(split_tensors, 0)
    # 5. Transpose back
    perm_after = array_ops.concat(
        (math_ops.range(1, axis + 1), [0],
         math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
        axis=0)
    return array_ops.transpose(out_tensor_t, perm=perm_after)
def call(self, y_true, y_pred):
  """Invokes the `Loss` instance.

  Args:
    y_true: Ground truth values.
    y_pred: The predicted values.

  Returns:
    Loss values in the form of a Tensor.
  """
  gamma = self.gamma
  from_logits = self.from_logits
  axis = -1

  y_true = tf.cast(y_true, y_pred.dtype)
  y_true = ops.convert_to_tensor_v2(y_true)
  y_pred = ops.convert_to_tensor_v2(y_pred)
  probs = y_pred

  # Reformat y_pred shapes
  if (not from_logits and
      not isinstance(y_pred,
                     (ops.EagerTensor, variables_module.Variable)) and
      y_pred.op.type == 'Softmax') and not hasattr(y_pred, '_keras_history'):
    assert len(y_pred.op.inputs) == 1
    y_pred = y_pred.op.inputs[0]
    from_logits = True

  # Clip y_pred to a minimum and maximum value
  if not from_logits:
    epsilon_ = constant_op.constant(K.epsilon(), y_pred.dtype.base_dtype)
    y_pred = clip_ops.clip_by_value(y_pred, epsilon_, 1 - epsilon_)
    y_pred = math_ops.log(y_pred)

  # Get dimensions of predictions tensor
  if isinstance(y_pred.shape, (tuple, list)):
    output_rank = len(y_pred.shape)
  else:
    output_rank = y_pred.shape.ndims
  if output_rank is not None:
    axis %= output_rank
    if axis != output_rank - 1:
      permutation = list(
          itertools.chain(range(axis), range(axis + 1, output_rank), [axis]))
      y_pred = array_ops.transpose(y_pred, perm=permutation)
  elif axis != -1:
    raise ValueError(
        'Cannot compute sparse categorical crossentropy with `axis={}` on an '
        'output tensor with unknown rank'.format(axis))

  # Reformat y_true shape and data type.
  y_true = cast(y_true, 'int64')
  output_shape = array_ops.shape_v2(y_pred)
  target_rank = y_true.shape.ndims
  update_shape = (
      target_rank is not None and output_rank is not None and
      target_rank != output_rank - 1)
  if update_shape:
    y_true = flatten(y_true)
    y_pred = array_ops.reshape(y_pred, [-1, output_shape[-1]])

  # Calculate cross-entropy loss
  if py_any(_is_symbolic_tensor(v) for v in [y_true, y_pred]):
    with get_graph().as_default():
      loss = nn.sparse_softmax_cross_entropy_with_logits_v2(
          labels=y_true, logits=y_pred)
  else:
    loss = nn.sparse_softmax_cross_entropy_with_logits_v2(
        labels=y_true, logits=y_pred)

  if update_shape and output_rank >= 3:
    loss = array_ops.reshape(loss, output_shape[:-1])

  # Calculate focal modulation to be applied
  gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32)
  scalar_gamma = gamma.shape.rank == 0
  y_true_rank = y_true.shape.rank
  if not scalar_gamma:
    gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank)
  focal_modulation = K.pow(1 - tf.math.reduce_mean(probs, axis=1), gamma)
  focal_modulation = tf.gather(
      focal_modulation, y_true, axis=0, batch_dims=y_true_rank)

  loss = focal_modulation * loss
  return loss
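# Illustrative sketch (assumption, not a line-for-line copy of the modulation
# computed above): the textbook sparse focal loss weights cross-entropy by
# (1 - p_t)**gamma, where p_t is the predicted probability of the true class,
# so confidently correct examples are down-weighted.
import numpy as np

def sparse_focal_ce(y_true, probs, gamma=2.0, eps=1e-7):
  # probs: [batch, num_classes] softmax outputs; y_true: [batch] int labels.
  p_t = probs[np.arange(len(y_true)), y_true]  # probability of the true class
  ce = -np.log(np.clip(p_t, eps, 1.0))         # standard cross-entropy
  return (1.0 - p_t) ** gamma * ce             # focal modulation

probs = np.array([[0.9, 0.05, 0.05],   # confident and correct -> tiny loss
                  [0.3, 0.6, 0.1]])    # low confidence -> larger loss
print(sparse_focal_ce(np.array([0, 0]), probs))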
def all_gather(
    self,
    input_tensor: core.TensorLike,
    axis: core.TensorLike,
    options: Optional[collective_util.Options] = None) -> core.Tensor:
  """All-gather a dense tensor.

  This method must be called inside a tf.function.

  Args:
    input_tensor: a dense tensor. It must have the same rank on all replicas,
      and dimensions other than `axis` need to be the same as well.
    axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
      range [0, rank(value)).
    options: an optional tf.distribute.experimental.CommunicationOptions. If
      provided, it overrides the default options.

  Returns:
    The gathered Tensor.

  Raises:
    RuntimeError: if called in eager mode.
  """
  if context.executing_eagerly():
    raise RuntimeError('all_gather is not supported in eager mode.')

  with ops.device(self._device), \
       ops.control_dependencies([array_ops.identity(input_tensor)]):
    # 1. Transpose
    # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
    # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
    # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
    # place it back.
    perm_pre = array_ops.concat(
        ([axis], math_ops.range(axis),
         math_ops.range(axis + 1, array_ops.rank(input_tensor))),
        axis=0)
    input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
    # 2. Pad
    gathered_shape = self._all_gather(
        array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
        options)
    first_dims = gathered_shape[:, 0]
    full_axis_dim = math_ops.reduce_max(first_dims)
    padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)
    # 3. Gather
    gather_padded_out_tensor = self._all_gather(padded_input_tensor, options)
    # 4. Unpad
    split_tensors = []
    for i in range(self._group_size):
      start_pos = i * full_axis_dim
      split_tensors.append(
          gather_padded_out_tensor[start_pos:start_pos + first_dims[i]])
    out_tensor_t = array_ops.concat(split_tensors, 0)
    # 5. Transpose back
    perm_after = array_ops.concat(
        (math_ops.range(1, axis + 1), [0],
         math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
        axis=0)
    return array_ops.transpose(out_tensor_t, perm=perm_after)
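# Illustrative sketch (assumption, NumPy only): the pad/gather/unpad steps
# exist because the collective all-gather needs equal shapes on every
# participant. Two "replicas" with different leading dimensions are padded to
# the group maximum, concatenated, and then the real rows are sliced back out.
# The transpose steps are omitted since the gather axis is already axis 0.
import numpy as np

replicas = [np.arange(6).reshape(2, 3), np.arange(9).reshape(3, 3)]

# 2. Pad to the group-wide maximum along the gather axis.
first_dims = [r.shape[0] for r in replicas]
full_axis_dim = max(first_dims)
padded = [np.pad(r, [(0, full_axis_dim - r.shape[0]), (0, 0)])
          for r in replicas]

# 3. Gather: the collective concatenates the equally shaped padded tensors.
gathered = np.concatenate(padded, axis=0)

# 4. Unpad: slice each replica's real rows back out and re-concatenate.
pieces = [gathered[i * full_axis_dim:i * full_axis_dim + first_dims[i]]
          for i in range(len(replicas))]
out = np.concatenate(pieces, axis=0)

assert np.array_equal(out, np.concatenate(replicas, axis=0))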
def _subdiv_calculate_mean_and_var(self, x, axes, keep_dims):
  with K.name_scope('moments'):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16.
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    replica_ctx = ds.get_replica_context()

    if replica_ctx:
      # Sums local to this replica.
      local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
      local_squared_sum = math_ops.reduce_sum(
          math_ops.square(y), axis=axes, keepdims=True)
      batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
      # TODO(b/163099951): batch the all-reduces once we sort out the ordering
      # issue for NCCL. We don't have a mechanism to launch NCCL in the same
      # order in each replica nowadays, so we limit NCCL to batch all-reduces.
      # Get the sum across all replicas (converge all devices).
      y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
      # Get the squared sum across all replicas (converge all devices).
      y_squared_sum = replica_ctx.all_reduce(
          reduce_util.ReduceOp.SUM, local_squared_sum)
      # Get the net batch size across all devices (converge all devices).
      input_batch_size = replica_ctx.all_reduce(
          reduce_util.ReduceOp.SUM, batch_size)
      # tf.print(replica_ctx.replica_id_in_sync_group,
      #          replica_ctx.num_replicas_in_sync, batch_size,
      #          self.aggregated_square_sum_batch, axes)

      # Get the number of elements you are averaging over (local), excluding
      # the batch dimension.
      axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
      multiplier_ = math_ops.cast(
          math_ops.reduce_prod(axes_vals), dtypes.float32)
      multiplier = multiplier_ * input_batch_size

      # Convert to mean and variance (locally).
      mean = y_sum / multiplier
      y_squared_mean = y_squared_sum / multiplier
      # var = E(x^2) - E(x)^2
      variance = y_squared_mean - math_ops.square(mean)
      net_sum = y_sum / multiplier_
      squared_mean = y_squared_sum / multiplier_
    else:
      # mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
      # # Sample variance, not unbiased variance.
      # # Note: stop_gradient does not change the gradient that gets
      # # backpropagated to the mean from the variance calculation,
      # # because that gradient is zero.
      # variance = math_ops.reduce_mean(
      #     math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
      #     axes,
      #     keepdims=True,
      #     name='variance')
      net_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
      squared_mean = math_ops.reduce_sum(
          math_ops.square(y), axis=axes, keepdims=True)

      if self._support_zero_size_input():
        # Keras assumes that batch dimension is the first dimension for Batch
        # Normalization.
        input_batch_size = array_ops.shape(y)[0]
      else:
        input_batch_size = None

      # Get the number of elements you are averaging over (local), excluding
      # the batch dimension.
      axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
      multiplier = math_ops.cast(
          math_ops.reduce_prod(axes_vals), dtypes.float32)
      squared_mean = squared_mean / multiplier
      net_sum = net_sum / multiplier

      if input_batch_size is None:
        mean, variance = nn.moments(y, axes, keep_dims=True)
        input_batch_size = 0
      else:
        batches_ = math_ops.cast(input_batch_size, self._param_dtype)
        # # If you only have one replica, don't worry about it.
        # # Compute true mean while keeping the dims for proper broadcasting.
        mean = net_sum / batches_
        variance = squared_mean / batches_ - math_ops.square(mean)

    input_batch_size = math_ops.cast(input_batch_size, dtypes.int32)
    if not keep_dims:
      mean = array_ops.squeeze(mean, axes)
      net_sum = array_ops.squeeze(net_sum, axes)
      variance = array_ops.squeeze(variance, axes)
      squared_mean = array_ops.squeeze(squared_mean, axes)
    if x.dtype == dtypes.float16:
      return (math_ops.cast(mean, dtypes.float16),
              math_ops.cast(net_sum, dtypes.float16),
              math_ops.cast(variance, dtypes.float16),
              math_ops.cast(squared_mean, dtypes.float16), input_batch_size)
    else:
      return (mean, net_sum, variance, squared_mean, input_batch_size)
def build_collective_gather(input_tensors,
                            devices,
                            group_size,
                            collective_keys,
                            axis,
                            communication_hint='AUTO',
                            control_inputs=None,
                            timeout=None):
  """Build a subgraph that does one full all-gather, using the collective Op.

  This method must be called in graph mode or inside a tf.function.

  Args:
    input_tensors: tensors within a single worker graph that are to be gathered
      together; must be one per device. Input tensors cannot have rank 0.
    devices: a list of device strings to run the collective on.
    group_size: total number of devices globally that will be doing this same
      gathering. The gathering will actually include the corresponding tensors
      at all these workers.
    collective_keys: a CollectiveKeys object.
    axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
      range [0, rank(value)).
    communication_hint: string providing hint to runtime for choosing
      collective implementation. Available options are `AUTO`, `NCCL`, and
      `RING`.
    control_inputs: if not None, add control edges between control_inputs and
      (index-wise) corresponding collective_gather tensors.
    timeout: a float or None. The timeout in seconds.

  Returns:
    An array of final tensors, one per device, computed by the full gather.
  """
  if len(input_tensors) != len(devices):
    raise ValueError(
        'collective requires one input tensor for each device, %d != %d' %
        (len(input_tensors), len(devices)))

  if group_size < 2:
    return input_tensors
  group_key = collective_keys.get_group_key(devices)
  instance_key_tensor = collective_keys.get_op_instance_key()
  instance_key_shape = collective_keys.get_op_instance_key()

  out_tensors = []
  for idx, input_tensor in enumerate(input_tensors):
    with ops.device(devices[idx]), ops.control_dependencies(
        _control_input(devices, control_inputs, idx)):
      # 1. Transpose
      # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is
      # 3, we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
      # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
      # place it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      gathered_shape = collective_ops.all_gather(
          array_ops.expand_dims_v2(
              array_ops.shape_v2(input_tensor_t), axis=0),
          group_size,
          group_key,
          instance_key_shape,
          communication_hint,
          timeout=timeout)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)
      # 3. Gather
      gather_padded_out_tensor = collective_ops.all_gather(
          padded_input_tensor,
          group_size,
          group_key,
          instance_key_tensor,
          communication_hint,
          timeout=timeout)
      # 4. Unpad
      split_tensors = []
      for i in range(first_dims.shape[0]):
        start_pos = i * full_axis_dim
        split_tensors.append(
            gather_padded_out_tensor[start_pos:start_pos + first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)
      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after)
      out_tensors.append(out_tensor)
  return out_tensors