  def testZeroBytesPerPack(self):
    values = [
        array_ops.ones([1], dtype=dtypes.float32),
        array_ops.ones([2], dtype=dtypes.float32),
    ]
    per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
    packs = cross_device_utils.pack_by_size(
        per_replica_values, bytes_per_pack=0)
    self.assertLen(packs, 1)
    self.assertLen(packs[0], 2)
    self.assertShape(packs[0][0], [1])
    self.assertShape(packs[0][1], [2])
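  # bytes_per_pack=0 above means packing is disabled: pack_by_size is expected
  # to return all values in a single pack, preserving their order, instead of
  # splitting them by size.
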
  def testInconsistentShape(self):
    per_replica_values = [
        value_lib.PerReplica([
            array_ops.ones([10, 10], dtype=dtypes.float32),
            array_ops.ones([10, 10], dtype=dtypes.float32),
        ]),
        value_lib.PerReplica([
            array_ops.ones([10, 10], dtype=dtypes.float32),
            input_layer.Input(
                shape=(10), batch_size=None, dtype=dtypes.float32),
        ]),
    ]
    packs = cross_device_utils.pack_by_size(
        per_replica_values, bytes_per_pack=1)
    self.assertLen(packs, 1)
    self.assertEqual(packs[0], per_replica_values)

  def testUnknownShape(self):

    def create_placeholder(shape, dtype):
      with ops.Graph().as_default():
        return array_ops.placeholder(dtype=dtype, shape=shape)

    per_replica_values = [
        value_lib.PerReplica([
            array_ops.ones([10, 10], dtype=dtypes.float32),
            array_ops.ones([10, 10], dtype=dtypes.float32),
        ]),
        value_lib.PerReplica([
            array_ops.ones([10, 10], dtype=dtypes.float32),
            create_placeholder([None, 10], dtype=dtypes.float32),
        ]),
    ]
    packs = cross_device_utils.pack_by_size(
        per_replica_values, bytes_per_pack=1)
    self.assertLen(packs, 1)
    self.assertEqual(packs[0], per_replica_values)
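  # Packing requires statically known shapes to compute sizes. This test and
  # testInconsistentShape above both exercise the fallback: when any value's
  # shape is not fully defined, pack_by_size returns all values as a single
  # pack even though bytes_per_pack=1.
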
  def testPreferLargerPack(self):
    # Each pack except the last one should be equal to or larger than
    # bytes_per_pack.
    values = [
        # size = 2 * 4 * 4 * 4 = 128 bytes
        array_ops.ones([2, 4, 4], dtype=dtypes.float32),
        # size = 8 * 4 = 32 bytes
        array_ops.ones([8], dtype=dtypes.int32),
        # size = 10 * 10 * 8 = 800 bytes
        array_ops.ones([10, 10], dtype=dtypes.int64),
        # size = 1 * 4 = 4 bytes
        array_ops.ones([1], dtype=dtypes.int32),
    ]
    per_replica_values = [value_lib.PerReplica([v, v]) for v in values]
    packs = cross_device_utils.pack_by_size(
        per_replica_values, bytes_per_pack=200)
    self.assertLen(packs, 2)
    self.assertLen(packs[0], 3)
    self.assertShape(packs[0][0], [2, 4, 4])
    self.assertShape(packs[0][1], [8])
    self.assertShape(packs[0][2], [10, 10])
    self.assertLen(packs[1], 1)
    self.assertShape(packs[1][0], [1])
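
# The tests above pin down pack_by_size's contract: bytes_per_pack=0 disables
# packing, values with unknown or inconsistent shapes fall back to a single
# pack, and otherwise values are packed greedily so that every pack except
# possibly the last is at least bytes_per_pack bytes. Below is a minimal,
# self-contained sketch of that policy, assuming each PerReplica exposes a
# `values` tuple of tensors; the name pack_by_size_sketch is ours, and the
# real implementation in tensorflow/python/distribute/cross_device_utils.py
# may differ in details.
def pack_by_size_sketch(per_replica_values, bytes_per_pack):
  """Greedily packs values into chunks of roughly bytes_per_pack bytes."""
  if bytes_per_pack == 0:
    # Packing disabled: everything comes back as a single pack, in order.
    return [per_replica_values]
  packs = []
  last_pack_size = 0
  for value in per_replica_values:
    num_elements = value.values[0].shape.num_elements()
    if num_elements is None:
      # Unknown shape: sizes can't be computed, so give up on packing.
      return [per_replica_values]
    size = num_elements * value.values[0].dtype.size
    # Open a new pack only once the current one has reached bytes_per_pack,
    # which is why packs end up equal to or larger than the target
    # (testPreferLargerPack above).
    if not packs or last_pack_size >= bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs
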
  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
                                 experimental_hints):
    """All-reduce across all workers in a batch."""

    batch_size = len(per_replica_values)
    # Pass self._communication to the runtime as a communication hint.
    communication = self._communication.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when communication
    # is NCCL.
    if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
      communication = CollectiveCommunication.AUTO.value

    # Reverse the list so that there is a better chance that values follow the
    # order in which they are calculated (e.g. when they're gradients), so as
    # to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack
    # won't be too small, since the first few packs are more likely to have
    # long queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    packs = cross_device_utils.pack_by_size(
        list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)

    if batch_size > 1:
      logging.info(
          "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
          "group_size = %d, communication_hint = %s, num_packs = %d",
          batch_size, len(self._devices), self._group_size, communication,
          len(packs))
    else:
      logging.log_first_n(
          logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
          "num_devices = %d, group_size = %d, communication_hint = %s, "
          "num_packs = %d" %
          (batch_size, len(self._devices), self._group_size, communication,
           len(packs)), 10)

    reduced_values = []
    with self._lock:
      for pack in packs:
        # By placing all CollectiveReduce ops in a pack under a single name
        # scope, we ensure they will be picked up by the `ScopedAllocator`
        # grappler optimizer and packed into a single all-reduce.
        with ops.name_scope("allreduce"):
          for per_replica in pack:
            # Add control dependencies per device from the last gradients to
            # the current set, in order to serialize NCCL launches.
            if (communication == CollectiveCommunication.NCCL.value and
                reduced_values):
              control_inputs = list(reduced_values[-1])
            else:
              control_inputs = None
            reduced_values.append(
                cross_device_utils.build_collective_reduce(
                    per_replica.values,
                    self._devices,
                    self._group_size,
                    self._collective_keys,
                    "Add",
                    "Id",
                    communication,
                    control_inputs,
                    executors=self._executors,
                    timeout=experimental_hints.timeout_seconds))

    for e in self._executors:
      e.wait()

    mirrored = []
    # Reverse the order of reduced values to recover the order of the input.
    for value in reversed(reduced_values):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i] = v / self._group_size
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored
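
# The reversal in _do_batch_all_reduce_dense is easy to misread: inputs are
# reversed before packing, reduced in that reversed order, and the results
# are reversed again so callers see them in the original order. A toy sketch
# of that round trip with plain Python values (the helper name is ours, not
# TensorFlow API):
def reduce_in_reverse_sketch(values, reduce_fn):
  # Reduce in reverse, mirroring reversed(per_replica_values) above.
  reduced = [reduce_fn(v) for v in reversed(values)]
  # Reverse the results, mirroring reversed(reduced_values) above.
  return list(reversed(reduced))

# reduce_in_reverse_sketch([1, 2, 3], lambda v: v * 2) == [2, 4, 6]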