def testZeroBytesPerPack(self):
  values = [
      array_ops.ones([1], dtype=dtypes.float32),
      array_ops.ones([2], dtype=dtypes.float32),
  ]
  per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
  packs = cross_device_utils.pack_by_size(
      per_replica_values, bytes_per_pack=0)
  self.assertLen(packs, 1)
  self.assertLen(packs[0], 2)
  self.assertShape(packs[0][0], [1])
  self.assertShape(packs[0][1], [2])
def testInconsistentShape(self):
  per_replica_values = [
      value_lib.PerReplica([
          array_ops.ones([10, 10], dtype=dtypes.float32),
          array_ops.ones([10, 10], dtype=dtypes.float32),
      ]),
      value_lib.PerReplica([
          array_ops.ones([10, 10], dtype=dtypes.float32),
          input_layer.Input(
              shape=(10), batch_size=None, dtype=dtypes.float32),
      ]),
  ]
  packs = cross_device_utils.pack_by_size(
      per_replica_values, bytes_per_pack=1)
  self.assertLen(packs, 1)
  self.assertEqual(packs[0], per_replica_values)
def testUnknownShape(self):

  def create_placeholder(shape, dtype):
    # Placeholders require graph mode, so build one inside a temporary graph;
    # its static shape (with the unknown batch dimension) is still visible to
    # pack_by_size.
    with ops.Graph().as_default():
      return array_ops.placeholder(dtype=dtype, shape=shape)

  per_replica_values = [
      value_lib.PerReplica([
          array_ops.ones([10, 10], dtype=dtypes.float32),
          array_ops.ones([10, 10], dtype=dtypes.float32),
      ]),
      value_lib.PerReplica([
          array_ops.ones([10, 10], dtype=dtypes.float32),
          create_placeholder([None, 10], dtype=dtypes.float32),
      ]),
  ]
  packs = cross_device_utils.pack_by_size(
      per_replica_values, bytes_per_pack=1)
  self.assertLen(packs, 1)
  self.assertEqual(packs[0], per_replica_values)
def testPreferLargerPack(self):
  # Each pack except the last one should be equal to or larger than
  # bytes_per_pack.
  values = [
      # size = 2 * 4 * 4 * 4 (bytes) = 128
      array_ops.ones([2, 4, 4], dtype=dtypes.float32),
      # size = 8 * 4 = 32
      array_ops.ones([8], dtype=dtypes.int32),
      # size = 10 * 10 * 8 = 800
      array_ops.ones([10, 10], dtype=dtypes.int64),
      # size = 1 * 4 = 4
      array_ops.ones([1], dtype=dtypes.int32),
  ]
  per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
  packs = cross_device_utils.pack_by_size(
      per_replica_values, bytes_per_pack=200)
  self.assertLen(packs, 2)
  self.assertLen(packs[0], 3)
  self.assertShape(packs[0][0], [2, 4, 4])
  self.assertShape(packs[0][1], [8])
  self.assertShape(packs[0][2], [10, 10])
  self.assertLen(packs[1], 1)
  self.assertShape(packs[1][0], [1])
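# The four tests above pin down the pack_by_size contract: bytes_per_pack=0
# disables packing, unknown or inconsistent static shapes fall back to a
# single pack, and a pack is closed only once it reaches bytes_per_pack.
# Below is a minimal sketch of a greedy packer with that observable behavior.
# The helper name is illustrative only; this is not the actual implementation
# in cross_device_utils, which may differ in its details.
def _pack_by_size_sketch(per_replica_values, bytes_per_pack):
  """Greedily packs PerReplica values until a pack reaches bytes_per_pack."""
  if bytes_per_pack == 0:
    return [per_replica_values]  # Packing disabled: one pack with everything.
  packs, pack, pack_size = [], [], 0
  for value in per_replica_values:
    shapes = set(tuple(v.shape.as_list()) for v in value.values)
    # Fall back to a single pack if the static shape is unknown or differs
    # across replicas; the byte size cannot be computed ahead of time.
    if len(shapes) != 1 or not value.values[0].shape.is_fully_defined():
      return [per_replica_values]
    size = value.values[0].shape.num_elements() * value.values[0].dtype.size
    pack.append(value)
    pack_size += size
    # Close the pack only after it reaches the threshold, so every pack
    # except possibly the last is at least bytes_per_pack bytes.
    if pack_size >= bytes_per_pack:
      packs.append(pack)
      pack, pack_size = [], 0
  if pack:
    packs.append(pack)
  return packs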
def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
                               experimental_hints):
  """All-reduce across all workers in a batch."""
  batch_size = len(per_replica_values)
  # Pass self._communication to the runtime as a communication hint.
  communication = self._communication.value
  # For now, we use NCCL only when batch_size > 1.
  # TODO(b/132575814): switch to NCCL for all collectives when communication
  # is NCCL.
  if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
    communication = CollectiveCommunication.AUTO.value

  # Reverse the lists so that there's a better chance that values follow
  # the order in which they are calculated (e.g. when they're gradients), so
  # as to overlap calculation with communication. However, this may not be
  # optimal for cases like gradients of complicated non-sequential models.
  #
  # Note that we reverse the list before packing so that the first pack won't
  # be too small, since it's more likely for the first few packs to have long
  # queuing time due to concurrent intense computation.
  #
  # TODO(b/147393503): explore solutions for optimal ordering.
  packs = cross_device_utils.pack_by_size(
      list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)

  if batch_size > 1:
    logging.info(
        "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
        "group_size = %d, communication_hint = %s, num_packs = %d",
        batch_size, len(self._devices), self._group_size, communication,
        len(packs))
  else:
    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
        "num_devices = %d, group_size = %d, communication_hint = %s, "
        "num_packs = %d" % (batch_size, len(self._devices), self._group_size,
                            communication, len(packs)), 10)

  reduced_values = []
  with self._lock:
    for pack in packs:
      # By placing all CollectiveReduce ops in a pack under a single name
      # scope, we ensure they will be picked up by the `ScopedAllocator`
      # grappler optimizer and packed into a single all-reduce.
      with ops.name_scope("allreduce"):
        for per_replica in pack:
          # Add control dependencies per device from the last gradients to
          # the current set, in order to serialize NCCL launches.
          if (communication == CollectiveCommunication.NCCL.value and
              reduced_values):
            control_inputs = list(reduced_values[-1])
          else:
            control_inputs = None
          reduced_values.append(
              cross_device_utils.build_collective_reduce(
                  per_replica.values,
                  self._devices,
                  self._group_size,
                  self._collective_keys,
                  "Add",
                  "Id",
                  communication,
                  control_inputs,
                  executors=self._executors,
                  timeout=experimental_hints.timeout_seconds))

  for e in self._executors:
    e.wait()

  mirrored = []
  # Reverse the order of reduced values to recover the order in the input.
  for value in reversed(reduced_values):
    if reduce_op == reduce_util.ReduceOp.MEAN:
      for i, v in enumerate(value):
        with ops.device(v.device):
          value[i] = v / self._group_size
    mirrored.append(
        distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
  return mirrored
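# The reverse-pack-reverse dance in _do_batch_all_reduce_dense is easy to get
# wrong, so here is the ordering in isolation as a pure-Python sketch (names
# are illustrative, not part of this module). Gradients are typically produced
# back to front, so reversing before packing puts the values that finish
# earliest at the front, letting communication start while the remaining
# gradients are still being computed; reversing the results restores the
# input order.
def _ordering_sketch(values, pack_fn, reduce_fn):
  packs = pack_fn(list(reversed(values)))
  # One reduction result per value, appended pack by pack, then un-reversed.
  reduced = [reduce_fn(v) for pack in packs for v in pack]
  return list(reversed(reduced))

# For example, with identity packing and identity reduction the input order
# is preserved:
#   _ordering_sketch([1, 2, 3], lambda vs: [vs], lambda v: v) == [1, 2, 3]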