  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
    """All-reduce IndexedSlices across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
        "%d all-reduces, num_workers = %d" %
        (len(per_replica_values), self._num_workers), 10)

    chunked_gv = self._make_gradient_chunks(per_replica_values,
                                            self._all_reduce_merge_scope)

    reduced_gv_list = []
    for chunk in chunked_gv:
      with ops.name_scope("allreduce"):
        for grad_and_vars in chunk:
          # Gradients for the same variable but from different devices.
          scaled_grads = [g for g, _ in grad_and_vars]

          values = [g.values for g in scaled_grads]
          indices = [g.indices for g in scaled_grads]
          assert len(values) == len(indices)

          # Build two separate all-gathers, one for the values and one for the
          # indices.
          gathered_values = cross_device_utils.build_collective_gather(
              values, self._num_workers, self._collective_keys)
          gathered_indices = cross_device_utils.build_collective_gather(
              indices, self._num_workers, self._collective_keys)
          assert len(gathered_values) == len(gathered_indices)

          collective_reduced = []
          for i in range(len(values)):
            reduced = ops.IndexedSlices(
                gathered_values[i],
                gathered_indices[i],
                dense_shape=scaled_grads[i].dense_shape)
            collective_reduced.append(reduced)

          result = []
          for (_, v), g in zip(grad_and_vars, collective_reduced):
            result.append([g, v])
          reduced_gv_list.append(result)

    new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
    return _ungroup_and_make_mirrored(
        new_device_grads,
        per_replica_values[0],
        reduce_op,
        num_between_graph_workers=self._num_workers)
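
The two all-gathers above work because an IndexedSlices value with repeated
indices represents the row-wise sum of its slices: concatenating every
worker's values and indices therefore encodes the summed gradient without
densifying it. A minimal NumPy sketch of that equivalence (the densify helper
is illustrative only, not part of the code above):

import numpy as np

def densify(values, indices, dense_shape):
  # Illustrative stand-in for converting an IndexedSlices to a dense tensor.
  dense = np.zeros(dense_shape)
  for row, idx in zip(values, indices):
    dense[idx] += row  # repeated indices accumulate
  return dense

# Two workers hold sparse gradients for the same 4 x 2 variable.
w0_vals, w0_idx = np.array([[1., 1.], [2., 2.]]), np.array([0, 3])
w1_vals, w1_idx = np.array([[5., 5.]]), np.array([0])

# The collective gather concatenates along axis 0 across workers.
g_vals = np.concatenate([w0_vals, w1_vals])
g_idx = np.concatenate([w0_idx, w1_idx])

summed = (densify(w0_vals, w0_idx, (4, 2)) +
          densify(w1_vals, w1_idx, (4, 2)))
assert np.allclose(densify(g_vals, g_idx, (4, 2)), summed)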
Example #3
def gather_fn():
  # `inputs`, `devices`, `group_size`, `collective_keys` and `strategy` are
  # free variables from the enclosing function.
  # Gather the per-device inputs, concatenating them along axis 0.
  gathered = cross_device_utils.build_collective_gather(
      inputs, devices, group_size, collective_keys, axis=0)
  # Regroup the gathered results into the strategy's per-replica structure.
  return distribute_utils.update_regroup(
      strategy.extended, gathered, group=True)
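
The nested gather_fn above is the kind of builder that tf.distribute wraps
behind its public gather path. As a rough usage sketch, assuming a TF 2.x
release that exposes tf.distribute.Strategy.gather (the internal route may
differ from the snippet above):

import tensorflow as tf

# Assumption: a local MirroredStrategy; with a single device there is one replica.
strategy = tf.distribute.MirroredStrategy()

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Each replica contributes two rows tagged with its replica id.
  replica_id = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
  return tf.fill([2, 3], replica_id)

per_replica = strategy.run(replica_fn)
# gather concatenates the per-replica results along axis 0 on the host,
# mirroring the axis=0 concatenation that build_collective_gather performs.
gathered = strategy.gather(per_replica, axis=0)
print(gathered.shape)  # (2 * num_replicas, 3)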