Example #1
0
    def execute(
        self,
        input_blocks: BlockList,
        output_num_blocks: int,
        clear_input_blocks: bool,
        *,
        map_ray_remote_args: Optional[Dict[str, Any]] = None,
        reduce_ray_remote_args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[BlockList, Dict[str, List[BlockMetadata]]]:
        """Run a two-stage (map, then reduce) shuffle over ``input_blocks``.

        One map task is submitted per input block; each returns one metadata
        object followed by ``output_num_blocks`` outputs. One reduce task is
        then submitted per output index and receives that index's output
        from every map task.

        Args:
            input_blocks: Blocks to shuffle.
            output_num_blocks: Number of reduce tasks, and therefore of
                output blocks.
            clear_input_blocks: If True, ``input_blocks.clear()`` is called
                once all map tasks have been submitted.
            map_ray_remote_args: Extra options applied to every map task.
            reduce_ray_remote_args: Extra options applied to every reduce
                task. ``"scheduling_strategy"`` defaults to ``"SPREAD"``
                when absent; the caller's dict is left untouched.

        Returns:
            The output ``BlockList`` plus a stats dict holding the fetched
            metadata of the ``"map"`` and ``"reduce"`` stages.
        """
        source_refs = input_blocks.get_blocks()
        num_inputs = len(source_refs)

        map_options = {} if map_ray_remote_args is None else map_ray_remote_args
        reduce_options = (
            {} if reduce_ray_remote_args is None else reduce_ray_remote_args
        )
        if "scheduling_strategy" not in reduce_options:
            # Copy first so that the default never leaks into the caller's
            # dict.
            reduce_options = reduce_options.copy()
            reduce_options["scheduling_strategy"] = "SPREAD"

        remote_map = cached_remote_fn(self.map)
        remote_reduce = cached_remote_fn(self.reduce)

        map_bar = ProgressBar("Shuffle Map", total=num_inputs)

        # Submit every map task before touching any of the returned refs.
        map_returns = [
            remote_map.options(
                **map_options,
                num_returns=1 + output_num_blocks,
            ).remote(idx, ref, output_num_blocks, *self._map_args)
            for idx, ref in enumerate(source_refs)
        ]

        # Each map task returns its BlockMetadata first, followed by one
        # output per reduce partition.
        map_metadata = [returned[0] for returned in map_returns]
        map_parts = [returned[1:] for returned in map_returns]
        del map_returns

        # Drop our references to the inputs as early as possible so that
        # the blocks' memory can be released.
        del source_refs
        if clear_input_blocks:
            input_blocks.clear()
        map_metadata = map_bar.fetch_until_complete(map_metadata)
        map_bar.close()

        reduce_bar = ProgressBar("Shuffle Reduce", total=output_num_blocks)
        reduce_returns = []
        for out_idx in range(output_num_blocks):
            # Reduce task ``out_idx`` consumes part ``out_idx`` of every map
            # task's output, in map-task order.
            reduce_returns.append(
                remote_reduce.options(
                    **reduce_options,
                    num_returns=2,
                ).remote(
                    *self._reduce_args,
                    *(parts[out_idx] for parts in map_parts),
                )
            )
        # Likewise release the intermediate map outputs eagerly.
        del map_parts

        out_blocks, reduce_metadata = zip(*reduce_returns)
        reduce_metadata = reduce_bar.fetch_until_complete(list(reduce_metadata))
        reduce_bar.close()

        stats = {
            "map": map_metadata,
            "reduce": reduce_metadata,
        }
        return BlockList(list(out_blocks), list(reduce_metadata)), stats
Example #2
0
    def _apply(
        self,
        fn: Any,
        remote_args: dict,
        block_list: BlockList,
        clear_input_blocks: bool,
        name: Optional[str] = None,
    ) -> BlockList:
        """Apply ``fn`` to every block of ``block_list`` using remote tasks.

        One task is submitted per input block. With block splitting enabled
        each task returns a single object that is fetched and iterated as
        ``(block, metadata)`` pairs; otherwise each task returns a block
        reference and a metadata reference, and only the metadata is
        fetched.

        Args:
            fn: Transform forwarded to the remote map function.
            remote_args: Options passed to ``.options()`` of the remote map
                function.
            block_list: Input blocks.
            clear_input_blocks: If True, drop the local input references and
                call ``block_list.clear()`` once all tasks are submitted.
            name: Progress bar label; defaults to ``"map"`` and is
                title-cased.

        Returns:
            A new ``BlockList`` with the outputs, or ``block_list`` itself
            when ``block_list.initial_num_blocks()`` is 0.

        Raises:
            ray.exceptions.RayTaskError, KeyboardInterrupt: Re-raised after
                every submitted task has been cancelled and awaited.
        """
        context = DatasetContext.get_current()

        # Handle empty datasets.
        if block_list.initial_num_blocks() == 0:
            return block_list

        blocks = block_list.get_blocks_with_metadata()
        if name is None:
            name = "map"
        name = name.title()
        map_bar = ProgressBar(name, total=len(blocks))

        if context.block_splitting_enabled:
            map_block = cached_remote_fn(_map_block_split).options(
                **remote_args)
            refs = [map_block.remote(b, fn, m.input_files) for b, m in blocks]
        else:
            map_block = cached_remote_fn(_map_block_nosplit).options(
                **dict(remote_args, num_returns=2))
            all_refs = [
                map_block.remote(b, fn, m.input_files) for b, m in blocks
            ]
            # First return is the block, second is its metadata; only the
            # metadata refs are waited on and fetched below.
            data_refs = [r[0] for r in all_refs]
            refs = [r[1] for r in all_refs]

        # Release input block references.
        if clear_input_blocks:
            del blocks
            block_list.clear()

        # Common wait for non-data refs.
        try:
            results = map_bar.fetch_until_complete(refs)
        except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e:
            # One or more mapper tasks failed, or we received a SIGINT signal
            # while waiting; either way, we cancel all map tasks.
            for ref in refs:
                ray.cancel(ref)
            # Wait until all tasks have failed or been cancelled.
            for ref in refs:
                try:
                    ray.get(ref)
                except (ray.exceptions.RayTaskError,
                        ray.exceptions.TaskCancelledError):
                    pass
            # Reraise the original task failure exception.
            raise e from None
        finally:
            # Always close the progress bar, on success and on failure
            # alike, instead of leaving it open until garbage collection.
            map_bar.close()

        new_blocks, new_metadata = [], []
        if context.block_splitting_enabled:
            for result in results:
                for block, metadata in result:
                    new_blocks.append(block)
                    new_metadata.append(metadata)
        else:
            for block, metadata in zip(data_refs, results):
                new_blocks.append(block)
                new_metadata.append(metadata)
        return BlockList(list(new_blocks), list(new_metadata))
Example #3
0
    def _apply(
        self,
        fn: Any,
        remote_args: dict,
        block_list: BlockList,
        clear_input_blocks: bool,
        name: Optional[str] = None,
    ) -> BlockList:
        """Apply ``fn`` to every block of ``block_list`` using a pool of actors.

        Note: this is not part of the Dataset public API.

        ``self.min_size`` actors are created up front. Whenever no task
        becomes ready within a short wait and more than 80% of the existing
        actors have reported ready, one more actor is started, up to
        ``self.max_size``. Each actor is given at most
        ``self.max_tasks_in_flight_per_actor`` blocks at a time.

        Args:
            fn: Transform applied to each block inside the actors.
            remote_args: Options passed to ``ray.remote`` when creating the
                actor class. The dict is copied, so the caller's object is
                never modified.
            block_list: Input blocks.
            clear_input_blocks: If True, ``block_list.clear()`` is called
                right after the block references have been read.
            name: Progress bar label; defaults to ``"map"`` and is
                title-cased.

        Returns:
            A ``BlockList`` with the outputs arranged in input order.
        """
        context = DatasetContext.get_current()

        blocks_in = block_list.get_blocks_with_metadata()

        # Early release block references.
        if clear_input_blocks:
            block_list.clear()

        orig_num_blocks = len(blocks_in)
        results = []
        if name is None:
            name = "map"
        name = name.title()
        map_bar = ProgressBar(name, total=orig_num_blocks)

        class BlockWorker:
            def ready(self):
                return "ok"

            def map_block_split(self, block: Block,
                                input_files: List[str]) -> BlockPartition:
                return _map_block_split(block, fn, input_files)

            @ray.method(num_returns=2)
            def map_block_nosplit(
                    self, block: Block,
                    input_files: List[str]) -> Tuple[Block, BlockMetadata]:
                return _map_block_nosplit(block, fn, input_files)

        # Work on a private copy: the defaults filled in below must not leak
        # back into the dict owned by the caller.
        remote_args = remote_args.copy()

        if not remote_args:
            remote_args["num_cpus"] = 1

        remote_args["scheduling_strategy"] = context.scheduling_strategy

        BlockWorker = ray.remote(**remote_args)(BlockWorker)

        workers = [BlockWorker.remote() for _ in range(self.min_size)]
        # Maps every outstanding ref (readiness probe or map task) to the
        # actor that will produce it.
        tasks = {w.ready.remote(): w for w in workers}
        tasks_in_flight = collections.defaultdict(int)
        metadata_mapping = {}
        block_indices = {}
        ready_workers = set()

        while len(results) < orig_num_blocks:
            ready, _ = ray.wait(list(tasks.keys()),
                                timeout=0.01,
                                num_returns=1,
                                fetch_local=False)
            if not ready:
                # Nothing finished within the timeout: scale the pool up if
                # allowed and most existing actors are already ready.
                if (len(workers) < self.max_size
                        and len(ready_workers) / len(workers) > 0.8):
                    w = BlockWorker.remote()
                    workers.append(w)
                    tasks[w.ready.remote()] = w
                    map_bar.set_description(
                        "Map Progress ({} actors {} pending)".format(
                            len(ready_workers),
                            len(workers) - len(ready_workers)))
                continue

            [obj_id] = ready
            worker = tasks.pop(obj_id)

            # Process task result.
            if worker in ready_workers:
                results.append(obj_id)
                tasks_in_flight[worker] -= 1
                map_bar.update(1)
            else:
                # First ref seen from this actor is its readiness probe.
                ready_workers.add(worker)
                map_bar.set_description(
                    "Map Progress ({} actors {} pending)".format(
                        len(ready_workers),
                        len(workers) - len(ready_workers)))

            # Schedule a new task.
            while (blocks_in and tasks_in_flight[worker] <
                   self.max_tasks_in_flight_per_actor):
                block, meta = blocks_in.pop()
                if context.block_splitting_enabled:
                    ref = worker.map_block_split.remote(
                        block, meta.input_files)
                else:
                    ref, meta_ref = worker.map_block_nosplit.remote(
                        block, meta.input_files)
                    metadata_mapping[ref] = meta_ref
                tasks[ref] = worker
                # Blocks are popped from the tail, so the remaining length
                # equals the original position of the block just taken.
                block_indices[ref] = len(blocks_in)
                tasks_in_flight[worker] += 1

        map_bar.close()
        new_blocks, new_metadata = [], []
        # Put blocks in input order.
        results.sort(key=block_indices.get)
        if context.block_splitting_enabled:
            for result in ray.get(results):
                for block, metadata in result:
                    new_blocks.append(block)
                    new_metadata.append(metadata)
        else:
            for block in results:
                new_blocks.append(block)
                new_metadata.append(metadata_mapping[block])
            new_metadata = ray.get(new_metadata)
        return BlockList(new_blocks, new_metadata)
Example #4
0
    def execute(
        self,
        input_blocks: BlockList,
        output_num_blocks: int,
        clear_input_blocks: bool,
        *,
        map_ray_remote_args: Optional[Dict[str, Any]] = None,
        reduce_ray_remote_args: Optional[Dict[str, Any]] = None,
        merge_factor: int = 2,
    ) -> Tuple[BlockList, Dict[str, List[BlockMetadata]]]:
        """Run the experimental push-based shuffle over ``input_blocks``.

        Map and merge tasks are submitted in rounds and pipelined against
        each other; once every round has finished, the reduce stage is
        executed over the accumulated merge results.

        Args:
            input_blocks: Blocks to shuffle.
            output_num_blocks: Number of output blocks to produce.
            clear_input_blocks: If True, ``input_blocks.clear()`` is called
                right after the block references have been read.
            map_ray_remote_args: Extra options applied to every map task.
            reduce_ray_remote_args: Extra options for the reduce stage. Any
                ``"scheduling_strategy"`` entry is dropped from a copy; the
                caller's dict is not modified.
            merge_factor: Forwarded to ``self._compute_shuffle_schedule``.

        Returns:
            The output ``BlockList`` plus a stats dict holding the metadata
            of the ``"map"``, ``"merge"`` and ``"reduce"`` stages.
        """
        logger.info("Using experimental push-based shuffle.")
        # TODO(swang): For jobs whose reduce work is heavier than the map work,
        # we should support fractional merge factors.
        # TODO(swang): For large jobs, we should try to choose the merge factor
        # automatically, e.g., by running one test round of map and merge tasks
        # and comparing their run times.
        # TODO(swang): Add option to automatically reduce write amplification
        # during map-merge stage, by limiting how many partitions can be
        # processed concurrently.
        input_blocks_list = input_blocks.get_blocks()
        # Preemptively clear the blocks list since we will incrementally delete
        # the last remaining references as we submit the dependent map tasks
        # during the map-merge stage.
        if clear_input_blocks:
            input_blocks.clear()

        if map_ray_remote_args is None:
            map_ray_remote_args = {}
        if reduce_ray_remote_args is None:
            reduce_ray_remote_args = {}
        # The placement strategy for reduce tasks is overwritten to colocate
        # them with their inputs from the merge stage, so remove any
        # pre-specified scheduling strategy here.
        reduce_ray_remote_args = reduce_ray_remote_args.copy()
        reduce_ray_remote_args.pop("scheduling_strategy", None)

        map_fn = self._map_partition
        merge_fn = self._merge

        def map_partition(*args, **kwargs):
            return map_fn(self.map, *args, **kwargs)

        def merge(*args, **kwargs):
            return merge_fn(self.reduce, *args, **kwargs)

        shuffle_map = cached_remote_fn(map_partition)
        shuffle_merge = cached_remote_fn(merge)

        # NOTE: both submit helpers close over ``schedule``, which is only
        # assigned further down, before either helper is first called.
        def submit_map_task(arg):
            mapper_idx, block = arg
            # NOTE(swang): Results are shuffled between map and merge tasks, so
            # there is no advantage to colocating specific map and merge tasks.
            # Therefore, we do not specify a node affinity policy for map tasks
            # in case the caller or Ray has a better scheduling strategy, e.g.,
            # based on data locality.
            map_result = shuffle_map.options(
                **map_ray_remote_args,
                num_returns=1 + schedule.num_merge_tasks_per_round,
            ).remote(
                mapper_idx,
                block,
                output_num_blocks,
                schedule,
                *self._map_args,
            )
            # The first return value is the task's metadata.
            metadata_ref = map_result.pop(0)
            return metadata_ref, map_result

        def submit_merge_task(arg):
            merge_idx, map_results = arg
            num_merge_returns = schedule.get_num_reducers_per_merge_idx(merge_idx)
            merge_result = shuffle_merge.options(
                num_returns=1 + num_merge_returns,
                **schedule.get_merge_task_options(merge_idx),
            ).remote(
                *map_results,
                reduce_args=self._reduce_args,
            )
            # The first return value is the task's metadata.
            metadata_ref = merge_result.pop(0)
            return metadata_ref, merge_result

        # Compute all constants used for task scheduling.
        num_cpus_per_node_map = _get_num_cpus_per_node_map()
        schedule = self._compute_shuffle_schedule(
            num_cpus_per_node_map,
            len(input_blocks_list),
            merge_factor,
            output_num_blocks,
        )

        # ObjectRef results from the last round of tasks. Used to add
        # backpressure during pipelining of map and merge tasks.
        last_map_metadata_results = []
        last_merge_metadata_results = []
        # Final outputs from the map-merge stage.
        # This is a map from merge task index to a nested list of merge results
        # (ObjectRefs). Each merge task index corresponds to a partition of P
        # final reduce tasks.
        all_merge_results = [[] for _ in range(schedule.num_merge_tasks_per_round)]
        shuffle_map_metadata = []
        shuffle_merge_metadata = []
        map_bar = ProgressBar("Shuffle Map", position=0, total=len(input_blocks_list))

        # Execute the map-merge stage. This submits tasks in rounds of M map
        # tasks and N merge tasks each. Task execution between map and merge is
        # pipelined, so that while executing merge for one round of inputs, we
        # also execute the map tasks for the following round.
        #
        # The (index, block) pairs are stored in reverse so that each round
        # can take its inputs with O(1) pops from the tail while still
        # visiting the blocks in their original order.
        input_blocks_list = list(enumerate(input_blocks_list))
        input_blocks_list.reverse()
        while input_blocks_list:
            # Execute one round of the map stage.
            # Pop from the inputs so that we can clear the memory ASAP.
            round_size = min(
                schedule.num_map_tasks_per_round, len(input_blocks_list)
            )
            round_input_blocks = [input_blocks_list.pop() for _ in range(round_size)]
            (
                prev_map_metadata,
                last_map_metadata_results,
                map_results,
            ) = _execute_pipelined_stage(
                submit_map_task,
                last_map_metadata_results,
                round_input_blocks,
                progress_bar=map_bar,
            )
            shuffle_map_metadata += prev_map_metadata

            # Shuffle the map results for the merge tasks.
            merge_args = [
                (merge_idx, [map_result.pop(0) for map_result in map_results])
                for merge_idx in range(schedule.num_merge_tasks_per_round)
            ]
            assert all([not map_result for map_result in map_results])
            # Execute one round of the merge stage.
            (
                prev_merge_metadata,
                last_merge_metadata_results,
                merge_results,
            ) = _execute_pipelined_stage(
                submit_merge_task,
                last_merge_metadata_results,
                merge_args,
            )
            shuffle_merge_metadata += prev_merge_metadata
            for merge_idx, merge_result in enumerate(merge_results):
                all_merge_results[merge_idx].append(merge_result)
            del merge_results

        # Wait for last map and merge tasks to finish.
        prev_map_metadata, _, _ = _execute_pipelined_stage(
            None, last_map_metadata_results, [], progress_bar=map_bar
        )
        shuffle_map_metadata += prev_map_metadata
        map_bar.close()
        prev_merge_metadata, _, _ = _execute_pipelined_stage(
            None, last_merge_metadata_results, []
        )
        shuffle_merge_metadata += prev_merge_metadata

        # Execute and wait for the reduce stage.
        new_metadata, new_blocks = self._execute_reduce_stage(
            output_num_blocks, schedule, reduce_ray_remote_args, all_merge_results
        )

        stats = {
            "map": shuffle_map_metadata,
            "merge": shuffle_merge_metadata,
            "reduce": new_metadata,
        }

        return BlockList(list(new_blocks), list(new_metadata)), stats
Example #5
0
    def _apply(
        self,
        block_fn: BlockTransform,
        remote_args: dict,
        block_list: BlockList,
        clear_input_blocks: bool,
        name: Optional[str] = None,
        fn: Optional[UDF] = None,
        fn_args: Optional[Iterable[Any]] = None,
        fn_kwargs: Optional[Dict[str, Any]] = None,
        fn_constructor_args: Optional[Iterable[Any]] = None,
        fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> BlockList:
        """Apply ``block_fn`` to every block using an autoscaling actor pool.

        Note: this is not part of the Dataset public API.

        ``self.min_size`` actors are created up front. Whenever no task
        becomes ready within a short wait and the fraction of ready actors
        exceeds ``self.ready_to_total_workers_ratio``, one more actor is
        started, up to ``self.max_size``. Each actor is given at most
        ``self.max_tasks_in_flight_per_actor`` blocks at a time.

        Args:
            block_fn: Block-level transform run inside the actors.
            remote_args: Options passed to ``ray.remote`` when creating the
                actor class. The dict is copied, so the caller's object is
                never modified.
            block_list: Input blocks.
            clear_input_blocks: If True, ``block_list.clear()`` is called
                right after the block references have been read.
            name: Progress bar label; defaults to ``"map"`` and is
                title-cased.
            fn: User function; instantiated once per actor with the
                constructor arguments when it is a ``CallableClass``.
            fn_args: Positional arguments forwarded with every block.
            fn_kwargs: Keyword arguments forwarded with every block.
            fn_constructor_args: Positional arguments for constructing
                ``fn`` inside each actor.
            fn_constructor_kwargs: Keyword arguments for constructing ``fn``
                inside each actor.

        Returns:
            A ``BlockList`` with the outputs arranged in input order.

        Raises:
            Exception: Any failure during execution is re-raised after an
                attempt to kill all actors created by this call.
        """
        if fn_args is None:
            fn_args = tuple()
        if fn_kwargs is None:
            fn_kwargs = {}
        if fn_constructor_args is None:
            fn_constructor_args = tuple()
        if fn_constructor_kwargs is None:
            fn_constructor_kwargs = {}

        context = DatasetContext.get_current()

        blocks_in = block_list.get_blocks_with_metadata()

        # Early release block references.
        if clear_input_blocks:
            block_list.clear()

        orig_num_blocks = len(blocks_in)
        results = []
        if name is None:
            name = "map"
        name = name.title()
        map_bar = ProgressBar(name, total=orig_num_blocks)

        class BlockWorker:
            def __init__(
                self,
                *fn_constructor_args: Any,
                **fn_constructor_kwargs: Any,
            ):
                if not isinstance(fn, CallableClass):
                    if fn_constructor_args or fn_constructor_kwargs:
                        raise ValueError(
                            "fn_constructor_{kw}args only valid for CallableClass "
                            f"UDFs, but got: {fn}"
                        )
                    self.fn = fn
                else:
                    self.fn = fn(*fn_constructor_args, **fn_constructor_kwargs)

            def ready(self):
                return "ok"

            def map_block_split(
                self,
                block: Block,
                input_files: List[str],
                *fn_args,
                **fn_kwargs,
            ) -> BlockPartition:
                return _map_block_split(
                    block, block_fn, input_files, self.fn, *fn_args, **fn_kwargs
                )

            @ray.method(num_returns=2)
            def map_block_nosplit(
                self,
                block: Block,
                input_files: List[str],
                *fn_args,
                **fn_kwargs,
            ) -> Tuple[Block, BlockMetadata]:
                return _map_block_nosplit(
                    block, block_fn, input_files, self.fn, *fn_args, **fn_kwargs
                )

        # Work on a private copy: the defaults filled in below must not leak
        # back into the dict owned by the caller.
        remote_args = remote_args.copy()

        if "num_cpus" not in remote_args:
            remote_args["num_cpus"] = 1

        if "scheduling_strategy" not in remote_args:
            if context.scheduling_strategy == DEFAULT_SCHEDULING_STRATEGY:
                remote_args["scheduling_strategy"] = "SPREAD"
            else:
                remote_args["scheduling_strategy"] = context.scheduling_strategy

        BlockWorker = ray.remote(**remote_args)(BlockWorker)

        workers = [
            BlockWorker.remote(*fn_constructor_args, **fn_constructor_kwargs)
            for _ in range(self.min_size)
        ]
        # Maps every outstanding ref (readiness probe or map task) to the
        # actor that will produce it.
        tasks = {w.ready.remote(): w for w in workers}
        tasks_in_flight = collections.defaultdict(int)
        metadata_mapping = {}
        block_indices = {}
        ready_workers = set()

        try:
            while len(results) < orig_num_blocks:
                ready, _ = ray.wait(
                    list(tasks.keys()), timeout=0.01, num_returns=1, fetch_local=False
                )
                if not ready:
                    # Nothing finished within the timeout: scale the pool up
                    # if allowed and enough existing actors are ready.
                    if (
                        len(workers) < self.max_size
                        and len(ready_workers) / len(workers)
                        > self.ready_to_total_workers_ratio
                    ):
                        w = BlockWorker.remote(
                            *fn_constructor_args, **fn_constructor_kwargs
                        )
                        workers.append(w)
                        tasks[w.ready.remote()] = w
                        map_bar.set_description(
                            "Map Progress ({} actors {} pending)".format(
                                len(ready_workers), len(workers) - len(ready_workers)
                            )
                        )
                    continue

                [obj_id] = ready
                worker = tasks.pop(obj_id)

                # Process task result.
                if worker in ready_workers:
                    results.append(obj_id)
                    tasks_in_flight[worker] -= 1
                    map_bar.update(1)
                else:
                    # First ref seen from this actor is its readiness probe.
                    ready_workers.add(worker)
                    map_bar.set_description(
                        "Map Progress ({} actors {} pending)".format(
                            len(ready_workers), len(workers) - len(ready_workers)
                        )
                    )

                # Schedule a new task.
                while (
                    blocks_in
                    and tasks_in_flight[worker] < self.max_tasks_in_flight_per_actor
                ):
                    block, meta = blocks_in.pop()
                    if context.block_splitting_enabled:
                        ref = worker.map_block_split.remote(
                            block,
                            meta.input_files,
                            *fn_args,
                            **fn_kwargs,
                        )
                    else:
                        ref, meta_ref = worker.map_block_nosplit.remote(
                            block,
                            meta.input_files,
                            *fn_args,
                            **fn_kwargs,
                        )
                        metadata_mapping[ref] = meta_ref
                    tasks[ref] = worker
                    # Blocks are popped from the tail, so the remaining
                    # length equals the original position of the block just
                    # taken.
                    block_indices[ref] = len(blocks_in)
                    tasks_in_flight[worker] += 1

            map_bar.close()
            self.num_workers += len(workers)
            new_blocks, new_metadata = [], []
            # Put blocks in input order.
            results.sort(key=block_indices.get)
            if context.block_splitting_enabled:
                for result in ray.get(results):
                    for block, metadata in result:
                        new_blocks.append(block)
                        new_metadata.append(metadata)
            else:
                for block in results:
                    new_blocks.append(block)
                    new_metadata.append(metadata_mapping[block])
                new_metadata = ray.get(new_metadata)
            return BlockList(new_blocks, new_metadata)

        except Exception as e:
            # Best-effort cleanup: a failure while killing actors is logged
            # and the original error is still the one propagated.
            try:
                for worker in workers:
                    ray.kill(worker)
            except Exception as err:
                logger.exception(f"Error killing workers: {err}")
            finally:
                raise e
Example #6
0
    def execute(
        self,
        input_blocks: BlockList,
        output_num_blocks: int,
        clear_input_blocks: bool,
        *,
        map_ray_remote_args: Optional[Dict[str, Any]] = None,
        reduce_ray_remote_args: Optional[Dict[str, Any]] = None,
        merge_factor: int = 2,
    ) -> Tuple[BlockList, Dict[str, List[BlockMetadata]]]:
        """Run the experimental push-based shuffle over ``input_blocks``.

        The map and merge stages are driven in lock-step, one round of each
        per loop iteration, through ``_PipelinedStageExecutor`` objects. The
        reduce stage is then driven to completion and its results are
        sorted by the first element of each result before being returned.

        Args:
            input_blocks: Blocks to shuffle.
            output_num_blocks: Expected number of output blocks; asserted
                against the number of reduce results.
            clear_input_blocks: If True, ``input_blocks.clear()`` is called
                right after the block references have been read.
            map_ray_remote_args: Extra options applied to the remote map
                function.
            reduce_ray_remote_args: Extra options handed to the reduce stage
                iterator. Any ``"scheduling_strategy"`` entry is dropped
                from a copy; the caller's dict is not modified.
            merge_factor: Forwarded to ``self._compute_shuffle_schedule``.

        Returns:
            The output ``BlockList`` plus a stats dict holding the metadata
            of the ``"map"``, ``"merge"`` and ``"reduce"`` stages.
        """
        logger.info("Using experimental push-based shuffle.")
        # TODO(swang): For jobs whose reduce work is heavier than the map work,
        # we should support fractional merge factors.
        # TODO(swang): For large jobs, we should try to choose the merge factor
        # automatically, e.g., by running one test round of map and merge tasks
        # and comparing their run times.
        # TODO(swang): Add option to automatically reduce write amplification
        # during map-merge stage, by limiting how many partitions can be
        # processed concurrently.
        input_blocks_list = input_blocks.get_blocks()
        # Preemptively clear the blocks list since we will incrementally delete
        # the last remaining references as we submit the dependent map tasks
        # during the map-merge stage.
        if clear_input_blocks:
            input_blocks.clear()

        if map_ray_remote_args is None:
            map_ray_remote_args = {}
        if reduce_ray_remote_args is None:
            reduce_ray_remote_args = {}
        # The placement strategy for reduce tasks is overwritten to colocate
        # them with their inputs from the merge stage, so remove any
        # pre-specified scheduling strategy here.
        reduce_ray_remote_args = reduce_ray_remote_args.copy()
        reduce_ray_remote_args.pop("scheduling_strategy", None)

        # Compute all constants used for task scheduling.
        num_cpus_per_node_map = _get_num_cpus_per_node_map()
        stage = self._compute_shuffle_schedule(
            num_cpus_per_node_map,
            len(input_blocks_list),
            merge_factor,
            output_num_blocks,
        )

        map_fn = self._map_partition
        merge_fn = self._merge

        def map_partition(*args, **kwargs):
            return map_fn(self.map, *args, **kwargs)

        def merge(*args, **kwargs):
            return merge_fn(self.reduce, *args, **kwargs)

        shuffle_map = cached_remote_fn(map_partition)
        # One extra return on top of the per-merge-task outputs.
        # NOTE(review): presumably the extra return is the task metadata --
        # confirm against _map_partition.
        shuffle_map = shuffle_map.options(
            **map_ray_remote_args,
            num_returns=1 + stage.num_merge_tasks_per_round,
        )

        map_stage_iter = _MapStageIterator(
            input_blocks_list,
            shuffle_map,
            [output_num_blocks, stage.merge_schedule, *self._map_args],
        )
        map_bar = ProgressBar("Shuffle Map",
                              position=0,
                              total=len(input_blocks_list))
        map_stage_executor = _PipelinedStageExecutor(
            map_stage_iter,
            stage.num_map_tasks_per_round,
            progress_bar=map_bar)

        shuffle_merge = cached_remote_fn(merge)
        # The merge iterator is constructed on top of the map iterator, so
        # the two stages share state and must be advanced in the order used
        # below.
        merge_stage_iter = _MergeStageIterator(map_stage_iter, shuffle_merge,
                                               stage, self._reduce_args)
        merge_stage_executor = _PipelinedStageExecutor(
            merge_stage_iter,
            stage.num_merge_tasks_per_round,
            max_concurrent_rounds=2)

        # Execute the map-merge stage. This submits tasks in rounds of M map
        # tasks and N merge tasks each. Task execution between map and merge is
        # pipelined, so that while executing merge for one round of inputs, we
        # also execute the map tasks for the following round.
        # NOTE(review): each StopIteration handler breaks immediately, so the
        # loop ends as soon as either executor is exhausted and the loop
        # condition never observes both flags set.
        map_done = False
        merge_done = False
        map_stage_metadata = []
        merge_stage_metadata = []
        while not (map_done and merge_done):
            try:
                map_stage_metadata += next(map_stage_executor)
            except StopIteration:
                map_done = True
                break

            try:
                merge_stage_metadata += next(merge_stage_executor)
            except StopIteration:
                merge_done = True
                break

        map_bar.close()
        all_merge_results = merge_stage_iter.pop_merge_results()

        # Execute and wait for the reduce stage.
        reduce_bar = ProgressBar("Shuffle Reduce", total=output_num_blocks)
        shuffle_reduce = cached_remote_fn(self.reduce)
        reduce_stage_iter = _ReduceStageIterator(
            stage,
            shuffle_reduce,
            all_merge_results,
            reduce_ray_remote_args,
            self._reduce_args,
        )

        max_reduce_tasks_in_flight = output_num_blocks
        ctx = DatasetContext.get_current()
        if ctx.pipeline_push_based_shuffle_reduce_tasks:
            # If pipelining is enabled, we should still try to utilize all
            # cores.
            max_reduce_tasks_in_flight = min(
                max_reduce_tasks_in_flight,
                sum(num_cpus_per_node_map.values()))

        reduce_stage_executor = _PipelinedStageExecutor(
            reduce_stage_iter,
            max_reduce_tasks_in_flight,
            max_concurrent_rounds=2,
            progress_bar=reduce_bar,
        )
        # Drain the reduce executor, accumulating the metadata it yields
        # per round.
        reduce_stage_metadata = []
        while True:
            try:
                reduce_stage_metadata += next(reduce_stage_executor)
            except StopIteration:
                break

        # Pair each reduce result with the metadata at the same position,
        # then reorder by the result's first element.
        # NOTE(review): presumably block[0] is the output partition index
        # and block[1] the block reference, and the metadata list is in the
        # same order as pop_reduce_results() -- confirm against
        # _ReduceStageIterator.
        new_blocks = reduce_stage_iter.pop_reduce_results()
        sorted_blocks = [(block[0], block[1], reduce_stage_metadata[i])
                         for i, block in enumerate(new_blocks)]
        sorted_blocks.sort(key=lambda x: x[0])
        _, new_blocks, reduce_stage_metadata = zip(*sorted_blocks)
        del sorted_blocks

        assert (
            len(new_blocks) == output_num_blocks
        ), f"Expected {output_num_blocks} outputs, produced {len(new_blocks)}"
        reduce_bar.close()

        stats = {
            "map": map_stage_metadata,
            "merge": merge_stage_metadata,
            "reduce": reduce_stage_metadata,
        }

        return BlockList(list(new_blocks), list(reduce_stage_metadata)), stats