Example #1
import numpy as np
import ray
from ray.data.dataset_pipeline import DatasetPipeline


def create_dataset_pipeline(files, epochs, num_windows):
    if num_windows > 1:
        # Split the file list into windows so only one window's worth of
        # data needs to be loaded and shuffled at a time.
        file_splits = np.array_split(files, num_windows)

        class Windower:
            """Yields one read task per window, for every epoch."""

            def __init__(self):
                self.i = 0
                self.iterations = epochs * num_windows

            def __iter__(self):
                return self

            def __next__(self):
                if self.i >= self.iterations:
                    raise StopIteration
                split = file_splits[self.i % num_windows]
                self.i += 1
                # Return a zero-arg callable; the pipeline runs it lazily
                # when the window is actually needed.
                return lambda: ray.data.read_parquet(
                    list(split), _spread_resource_prefix="node:"
                )

        pipe = DatasetPipeline.from_iterable(Windower())
        pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
    else:
        # Small enough to skip windowing: read once and repeat per epoch.
        ds = ray.data.read_parquet(files, _spread_resource_prefix="node:")
        pipe = ds.repeat(epochs)
        pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
    return pipe
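
The returned object can be consumed like any `DatasetPipeline`. A minimal usage sketch: the file paths below are hypothetical (not part of the original example), while `iter_batches()` is a standard `DatasetPipeline` consumption method.

files = [f"s3://bucket/data-{i}.parquet" for i in range(100)]  # hypothetical paths
pipe = create_dataset_pipeline(files, epochs=10, num_windows=5)
for batch in pipe.iter_batches():
    ...  # feed each batch to the training loop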
Example #2
import numpy as np
import ray
from ray.data.dataset_pipeline import DatasetPipeline


def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
    if num_windows > 1:
        # This should only read Parquet metadata, not the row data itself.
        num_rows = ray.data.read_parquet(files).count()
        file_splits = np.array_split(files, num_windows)

        class Windower:
            """Yields one read task per window, for every epoch."""

            def __init__(self):
                self.i = 0
                self.iterations = epochs * num_windows

            def __iter__(self):
                return self

            def __next__(self):
                if self.i >= self.iterations:
                    raise StopIteration
                split = file_splits[self.i % num_windows]
                self.i += 1
                return lambda: ray.data.read_parquet(list(split))

        pipe = DatasetPipeline.from_iterable(Windower())
        # Row boundaries that split each window evenly across the workers.
        split_indices = [
            i * num_rows // num_windows // num_workers for i in range(1, num_workers)
        ]
        pipe = pipe.random_shuffle_each_window()
        pipe_shards = pipe.split_at_indices(split_indices)
    else:
        # No windowing: read the full dataset once and repeat it per epoch.
        ds = ray.data.read_parquet(files)
        pipe = ds.repeat(epochs)
        pipe = pipe.random_shuffle_each_window()
        pipe_shards = pipe.split(num_workers, equal=True)
    return pipe_shards
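
Each returned shard is itself a `DatasetPipeline`, intended to be handed to one training worker. A minimal consumption sketch, assuming a hypothetical `files` list; the `consume` task body is illustrative, not part of the original example.

import ray

@ray.remote
def consume(shard):
    # Count the rows this worker sees across all epochs and windows.
    return sum(1 for _ in shard.iter_rows())

files = ["data-0.parquet", "data-1.parquet"]  # hypothetical paths
shards = create_dataset(files, num_workers=4, epochs=2, num_windows=2)
print(ray.get([consume.remote(shard) for shard in shards]))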
Example #3
import ray
from ray.data.dataset_pipeline import DatasetPipeline

def test_from_iterable(ray_start_regular_shared):  # Ray test-cluster fixture
    pipe = DatasetPipeline.from_iterable(
        [lambda: ray.data.range(3), lambda: ray.data.range(2)])
    assert pipe.take() == [0, 1, 2, 0, 1]  # windows are concatenated in order