def _batch_all_reduce(self, reduce_op, per_replica_values):
  """All-reduce algorithm applied to a batch of per-replica values."""
  logging.log_first_n(
      logging.INFO,
      "distributed batch_all_reduce invoked for batches size = %d with "
      "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
      "and agg_small_grads_max_group = %d" %
      (len(per_replica_values), self._all_reduce_spec, self._num_packs,
       self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)

  destinations = sorted(per_replica_values[0].devices)
  device_grads = _group_value_by_device(per_replica_values)

  # The all-reduce library requires fully defined shapes.
  # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
  # required as well.
  for device_grad in device_grads:
    for grad, _ in device_grad:
      if not grad.shape.is_fully_defined():
        raise ValueError("Shape is unknown for node %r" % grad)

  # Process the gradients range by range: each all-reduce spec handles the
  # tensors up to its size limit, and a negative limit consumes everything
  # that remains.
  remaining_grads = device_grads
  aggregated_grads = []
  for spec_tuple in self._all_reduce_spec:
    if spec_tuple.limit < 0:
      this_grads = remaining_grads
      remaining_grads = []
    else:
      (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
          spec_tuple.limit, remaining_grads)
    if this_grads:
      # Pack this range's tensors, run the all-reduce on the packs, then
      # unpack back into individual gradients.
      device_grad_packs, tensor_packer = _pack_tensors(
          this_grads, self._num_packs, self._agg_small_grads_max_bytes,
          self._agg_small_grads_max_group)
      range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
          self._worker_devices, device_grad_packs, len(self._worker_devices),
          spec_tuple.alg, spec_tuple.shards,
          range(self._num_gpus_per_worker))
      range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)

      # Merge this range's results into the running per-device lists.
      if not aggregated_grads:
        aggregated_grads = range_agg_grads
      else:
        assert len(aggregated_grads) == len(range_agg_grads)
        for i in range(len(aggregated_grads)):
          aggregated_grads[i] += range_agg_grads[i]
  assert not remaining_grads

  return _ungroup_and_make_mirrored(aggregated_grads, destinations,
                                    reduce_op)
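
# --- Illustration only ------------------------------------------------------
# A minimal, self-contained sketch of the split/reduce/accumulate pattern
# implemented by _batch_all_reduce above. Everything here (the function name,
# the plain-list "gradients", the elementwise-sum "all-reduce") is a
# hypothetical simplification for exposition, not part of this module's API.
def _toy_batch_all_reduce(limits, replica_grads):
  """Sums each gradient across replicas, one size range per limit.

  Args:
    limits: size thresholds in spec order; a negative limit consumes
      everything that remains, mirroring `spec_tuple.limit < 0` above.
    replica_grads: one list of gradients (plain lists of floats) per replica;
      every replica holds a copy of the same logical gradients.
  """
  # Group position-wise: grouped[k] holds every replica's copy of grad k.
  grouped = list(zip(*replica_grads))
  remaining, aggregated = grouped, []
  for limit in limits:
    if limit < 0:
      this, remaining = remaining, []
    else:
      this = [g for g in remaining if len(g[0]) <= limit]
      remaining = [g for g in remaining if len(g[0]) > limit]
    # "All-reduce" the range: elementwise sum across replica copies. Results
    # accumulate in spec order, not necessarily the original input order.
    aggregated.extend(
        [sum(vals) for vals in zip(*copies)] for copies in this)
  assert not remaining  # every gradient must fall under some spec's limit
  return aggregated

# Example: two replicas, three gradients; tensors of <= 2 elements go to the
# first spec and the rest to a catch-all negative limit:
#   _toy_batch_all_reduce(
#       [2, -1],
#       [[[1., 2.], [3.], [4., 5., 6.]],
#        [[1., 2.], [3.], [4., 5., 6.]]])
#   => [[2.0, 4.0], [6.0], [8.0, 10.0, 12.0]]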