Example #1
0
    def _split(self, n: int, splitter: Callable[[Dataset],
                                                "DatasetPipeline[T]"]):
        """Split this pipeline into ``n`` shard pipelines.

        Spawns a ``PipelineSplitExecutorCoordinator`` actor which applies
        ``splitter`` to the pipeline's datasets and serves shard ``i`` to
        reader ``i``. All returned shards must be read concurrently,
        otherwise readers block waiting for the slowest one (see the
        warning emitted by ``SplitIterator.__next__``).

        Args:
            n: Number of shard pipelines to produce.
            splitter: Callable the coordinator uses to split each dataset
                (presumably into ``n`` pieces — confirm against the
                coordinator implementation).

        Returns:
            A list of ``n`` ``DatasetPipeline`` shards with progress bars
            disabled.

        Raises:
            RuntimeError: If this pipeline has already been read.
        """
        # Validate *before* spawning the coordinator so a failed call does
        # not leak an actor (previously this check ran after creation).
        if self._executed[0]:
            raise RuntimeError("Pipeline cannot be read multiple times.")
        self._executed[0] = True

        resources = {}
        if not ray.util.client.ray.is_connected():
            # Pin the coordinator (and any child actors) to the local node to
            # avoid errors during node failures. If the local node dies, then
            # the driver will fate-share with the coordinator anyway.
            resources["node:{}".format(
                ray.util.get_node_ip_address())] = 0.0001

        coordinator = PipelineSplitExecutorCoordinator.options(
            resources=resources,
            placement_group=None,
        ).remote(self, n, splitter, DatasetContext.get_current())

        class SplitIterator:
            """Blocking iterator over the datasets assigned to one shard."""

            def __init__(self, split_index, coordinator):
                self.split_index = split_index
                self.coordinator = coordinator
                # First warning after ~warn_threshold * wait_delay_s seconds
                # of waiting; the threshold doubles after each warning to
                # avoid flooding the console.
                self.warn_threshold = 100
                self.wait_delay_s = 0.1

            def __iter__(self):
                return self

            def __next__(self):
                ds = None
                tries = 0
                while ds is None:
                    ds = ray.get(
                        self.coordinator.next_dataset_if_ready.remote(
                            self.split_index))
                    # Wait for other shards to catch up reading.
                    if not ds:
                        time.sleep(self.wait_delay_s)
                        tries += 1
                    if tries > self.warn_threshold:
                        print("Warning: reader on shard {} of the pipeline "
                              "has been blocked more than {}s waiting for "
                              "other readers to catch up. All pipeline shards "
                              "must be read from concurrently.".format(
                                  self.split_index,
                                  self.wait_delay_s * self.warn_threshold,
                              ))
                        self.warn_threshold *= 2
                # Return a zero-arg thunk producing the dataset; presumably
                # DatasetPipeline invokes it lazily — confirm with its
                # iterator contract.
                return lambda: ds

        return [
            # Disable progress bars for the split readers since they would
            # overwhelm the console.
            DatasetPipeline(
                SplitIterator(idx, coordinator),
                length=self._length,
                progress_bars=False,
            ) for idx in range(n)
        ]
Example #2
0
    def _split(self, n: int, splitter: Callable[[Dataset],
                                                "DatasetPipeline[T]"]):
        """Split this pipeline into ``n`` shard pipelines.

        Spawns a ``PipelineSplitExecutorCoordinator`` actor which applies
        ``splitter`` to the pipeline's datasets and serves shard ``i`` to
        reader ``i``. All returned shards must be read concurrently,
        otherwise readers block waiting for the slowest one (see the
        warning emitted by ``SplitIterator.__next__``).

        Args:
            n: Number of shard pipelines to produce.
            splitter: Callable the coordinator uses to split each dataset
                (presumably into ``n`` pieces — confirm against the
                coordinator implementation).

        Returns:
            A list of ``n`` ``DatasetPipeline`` shards with progress bars
            disabled.

        Raises:
            RuntimeError: If this pipeline has already been read.
        """
        # Validate *before* spawning the coordinator so a failed call does
        # not leak an actor (previously this check ran after creation).
        if self._executed[0]:
            raise RuntimeError("Pipeline cannot be read multiple times.")
        self._executed[0] = True

        coordinator = PipelineSplitExecutorCoordinator.remote(
            self, n, splitter, DatasetContext.get_current())

        class SplitIterator:
            """Blocking iterator over the datasets assigned to one shard."""

            def __init__(self, split_index, coordinator):
                self.split_index = split_index
                self.coordinator = coordinator
                # First warning after ~warn_threshold * wait_delay_s seconds
                # of waiting; the threshold doubles after each warning to
                # avoid flooding the console.
                self.warn_threshold = 100
                self.wait_delay_s = 0.1

            def __iter__(self):
                return self

            def __next__(self):
                ds = None
                tries = 0
                while ds is None:
                    ds = ray.get(
                        self.coordinator.next_dataset_if_ready.remote(
                            self.split_index))
                    # Wait for other shards to catch up reading.
                    if not ds:
                        time.sleep(self.wait_delay_s)
                        tries += 1
                    if tries > self.warn_threshold:
                        print("Warning: reader on shard {} of the pipeline "
                              "has been blocked more than {}s waiting for "
                              "other readers to catch up. All pipeline shards "
                              "must be read from concurrently.".format(
                                  self.split_index,
                                  self.wait_delay_s * self.warn_threshold,
                              ))
                        self.warn_threshold *= 2
                # Return a zero-arg thunk producing the dataset; presumably
                # DatasetPipeline invokes it lazily — confirm with its
                # iterator contract.
                return lambda: ds

        return [
            # Disable progress bars for the split readers since they would
            # overwhelm the console.
            DatasetPipeline(
                SplitIterator(idx, coordinator),
                length=self._length,
                progress_bars=False,
            ) for idx in range(n)
        ]
Example #3
0
    def _split(self, n: int,
               splitter: Callable[[Dataset], "DatasetPipeline[T]"]):
        """Fan this pipeline out into ``n`` shard pipelines.

        A ``PipelineSplitExecutorCoordinator`` actor drives the split;
        each shard pipeline pulls its datasets from that actor and blocks
        until the other shards have caught up.
        """

        coord = PipelineSplitExecutorCoordinator.remote(self, n, splitter)

        class SplitIterator:
            """Blocking iterator over the datasets assigned to one shard."""

            def __init__(self, split_index, coordinator):
                self.split_index = split_index
                self.coordinator = coordinator
                # Warn after roughly warn_threshold * wait_delay_s seconds
                # of stalling; doubled after each warning to limit spam.
                self.warn_threshold = 100
                self.wait_delay_s = 0.1

            def __iter__(self):
                return self

            def __next__(self):
                stalls = 0
                ds = None
                while ds is None:
                    ds = ray.get(
                        self.coordinator.next_dataset_if_ready.remote(
                            self.split_index))
                    if ds:
                        break
                    # Other shards are behind; back off and retry.
                    time.sleep(self.wait_delay_s)
                    stalls += 1
                    if stalls > self.warn_threshold:
                        print("Warning: shard {} of the pipeline has been "
                              "stalled more than {}s waiting for other shards "
                              "to catch up.".format(
                                  self.split_index,
                                  self.wait_delay_s * self.warn_threshold))
                        self.warn_threshold *= 2
                return lambda: ds

        # Progress bars are disabled for the split readers since they would
        # overwhelm the console.
        return [
            DatasetPipeline(SplitIterator(shard_id, coord),
                            progress_bars=False)
            for shard_id in range(n)
        ]