def _get_read_tasks(
    ds: Datasource,
    ctx: DatasetContext,
    cur_pg: Optional[PlacementGroup],
    parallelism: int,
    kwargs: dict,
) -> "Tuple[int, int, List[ReadTask]]":
    """Generates read tasks.

    Args:
        ds: Datasource to read from.
        ctx: Dataset config to use.
        cur_pg: The current placement group, if any.
        parallelism: The user-requested parallelism, or -1 for autodetection.
        kwargs: Additional kwargs to pass to the reader.

    Returns:
        A 3-tuple of (requested parallelism from the datasource, the min
        safe parallelism to avoid OOM, the list of read tasks generated).
    """
    # NOTE: the return annotation is a proper typing construct (the previous
    # `(int, int, List[ReadTask])` was a tuple literal, which type checkers
    # reject). It is quoted so it is never evaluated at import time.
    kwargs = _unwrap_arrow_serialization_workaround(kwargs)
    # Install the caller's context before creating the reader; the context
    # is then read back via DatasetContext.get_current() for autodetection.
    DatasetContext._set_current(ctx)
    reader = ds.create_reader(**kwargs)
    requested_parallelism, min_safe_parallelism = _autodetect_parallelism(
        parallelism, cur_pg, DatasetContext.get_current(), reader
    )
    return (
        requested_parallelism,
        min_safe_parallelism,
        reader.get_read_tasks(requested_parallelism),
    )
def _prepare_read(
    ds: Datasource, ctx: DatasetContext, parallelism: int, kwargs: dict
) -> List[ReadTask]:
    """Prepares read tasks using the datasource's ``prepare_read`` API.

    Args:
        ds: Datasource to read from.
        ctx: Dataset config to install as the current context.
        parallelism: Parallelism forwarded to ``ds.prepare_read``.
        kwargs: Additional kwargs forwarded to ``ds.prepare_read`` after
            being passed through ``_unwrap_s3_filesystem_workaround``.

    Returns:
        Whatever ``ds.prepare_read`` returns (annotated as a list of
        read tasks).
    """
    unwrapped_kwargs = _unwrap_s3_filesystem_workaround(kwargs)
    # Install ctx as the current context before calling into the datasource,
    # presumably so datasource code observes the caller's config — confirm.
    DatasetContext._set_current(ctx)
    return ds.prepare_read(parallelism, **unwrapped_kwargs)