def read_datasource(
    datasource: Datasource[T],
    *,
    parallelism: int = 200,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    **read_args,
) -> Dataset[T]:
    """Read a dataset from a custom data source.

    The read is planned either locally (when forced via env var or when the
    provided pyarrow partitioning is not serializable) or inside a remote
    task, then wrapped in a lazy block list so blocks are materialized on
    demand.

    Args:
        datasource: The datasource to read data from.
        parallelism: The requested parallelism of the read. Parallelism may
            be limited by the available partitioning of the datasource.
        read_args: Additional kwargs to pass to the datasource impl.
        ray_remote_args: kwargs passed to ray.remote in the read tasks.

    Returns:
        Dataset holding the data read from the datasource.
    """
    # Capture the current dataset context up front so the same context is
    # forwarded to the remote metadata-resolution task below.
    ctx = DatasetContext.get_current()
    # TODO(ekl) remove this feature flag.
    # Escape hatch: force metadata resolution on this process instead of in
    # a remote task.
    force_local = "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ
    pa_ds = _lazy_import_pyarrow_dataset()
    if pa_ds:
        partitioning = read_args.get("dataset_kwargs", {}).get("partitioning", None)
        if isinstance(partitioning, pa_ds.Partitioning):
            # A pyarrow Partitioning object cannot be shipped to a remote
            # task, so fall back to local planning.
            logger.info(
                "Forcing local metadata resolution since the provided partitioning "
                f"{partitioning} is not serializable."
            )
            force_local = True

    if force_local:
        read_tasks = datasource.prepare_read(parallelism, **read_args)
    else:
        # Prepare read in a remote task so that in Ray client mode, we aren't
        # attempting metadata resolution from the client machine.
        prepare_read = cached_remote_fn(
            _prepare_read, retry_exceptions=False, num_cpus=0
        )
        read_tasks = ray.get(
            prepare_read.remote(
                datasource,
                ctx,
                parallelism,
                _wrap_arrow_serialization_workaround(read_args),
            )
        )

    # Warn when the datasource produced far fewer blocks than both the
    # requested parallelism and (roughly) half the cluster's CPU slots.
    if len(read_tasks) < parallelism and (
        len(read_tasks) < ray.available_resources().get("CPU", parallelism) // 2
    ):
        logger.warning(
            "The number of blocks in this dataset ({}) limits its parallelism to {} "
            "concurrent tasks. This is much less than the number of available "
            "CPU slots in the cluster. Use `.repartition(n)` to increase the number of "
            "dataset blocks.".format(len(read_tasks), len(read_tasks))
        )

    if ray_remote_args is None:
        ray_remote_args = {}
    # Default read tasks to SPREAD scheduling, but only when the caller did
    # not specify a strategy and the context is still on the default one.
    if (
        "scheduling_strategy" not in ray_remote_args
        and ctx.scheduling_strategy == DEFAULT_SCHEDULING_STRATEGY
    ):
        ray_remote_args["scheduling_strategy"] = "SPREAD"

    block_list = LazyBlockList(read_tasks, ray_remote_args=ray_remote_args)
    # Eagerly kick off the first block — presumably to surface read errors
    # early and populate metadata for schema access; confirm against
    # LazyBlockList's implementation.
    block_list.compute_first_block()
    block_list.ensure_metadata_for_first_block()

    return Dataset(
        ExecutionPlan(block_list, block_list.stats()),
        0,
        False,
    )
def read_datasource(
    datasource: Datasource[T],
    *,
    parallelism: int = 200,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    _spread_resource_prefix: Optional[str] = None,
    **read_args,
) -> Dataset[T]:
    """Read a dataset from a custom data source.

    Builds one deferred remote call per read task, records per-task
    execution stats via a stats actor, and returns a Dataset backed by a
    lazy block list.

    Args:
        datasource: The datasource to read data from.
        parallelism: The requested parallelism of the read. Parallelism may
            be limited by the available partitioning of the datasource.
        read_args: Additional kwargs to pass to the datasource impl.
        ray_remote_args: kwargs passed to ray.remote in the read tasks.

    Returns:
        Dataset holding the data read from the datasource.
    """
    # TODO(ekl) remove this feature flag.
    # Escape hatch: force metadata resolution on this process instead of in
    # a remote task.
    force_local = "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ
    pa_ds = _lazy_import_pyarrow_dataset()
    if pa_ds:
        partitioning = read_args.get("dataset_kwargs", {}).get("partitioning", None)
        if isinstance(partitioning, pa_ds.Partitioning):
            # A pyarrow Partitioning object cannot be shipped to a remote
            # task, so fall back to local planning.
            logger.info(
                "Forcing local metadata resolution since the provided partitioning "
                f"{partitioning} is not serializable."
            )
            force_local = True

    if force_local:
        read_tasks = datasource.prepare_read(parallelism, **read_args)
    else:
        # Prepare read in a remote task so that in Ray client mode, we aren't
        # attempting metadata resolution from the client machine.
        ctx = DatasetContext.get_current()
        prepare_read = cached_remote_fn(
            _prepare_read, retry_exceptions=False, num_cpus=0
        )
        read_tasks = ray.get(
            prepare_read.remote(
                datasource,
                ctx,
                parallelism,
                _wrap_arrow_serialization_workaround(read_args),
            )
        )

    context = DatasetContext.get_current()
    stats_actor = get_or_create_stats_actor()
    # One UUID per read so the stats actor can correlate the start record
    # with per-task records emitted from inside remote_read.
    stats_uuid = uuid.uuid4()
    stats_actor.record_start.remote(stats_uuid)

    def remote_read(i: int, task: ReadTask, stats_actor) -> MaybeBlockPartition:
        # Runs as a Ray remote function (wrapped via cached_remote_fn below).
        # Propagate the driver's DatasetContext into the worker process.
        DatasetContext._set_current(context)
        stats = BlockExecStats.builder()

        # Execute the read task.
        block = task()

        if context.block_splitting_enabled:
            # When splitting is on, the task's own metadata already describes
            # the (possibly multi-block) result; just attach exec stats.
            metadata = task.get_metadata()
            metadata.exec_stats = stats.build()
        else:
            metadata = BlockAccessor.for_block(block).get_metadata(
                input_files=task.get_metadata().input_files, exec_stats=stats.build()
            )
        # Fire-and-forget stats report keyed by (stats_uuid, task index).
        stats_actor.record_task.remote(stats_uuid, i, metadata)
        return block

    if ray_remote_args is None:
        ray_remote_args = {}
    if "scheduling_strategy" not in ray_remote_args:
        ray_remote_args["scheduling_strategy"] = "SPREAD"
    remote_read = cached_remote_fn(remote_read)

    if _spread_resource_prefix is not None:
        if context.optimize_fuse_stages:
            logger.warning(
                "_spread_resource_prefix has no effect when optimize_fuse_stages "
                "is enabled. Tasks are spread by default."
            )
        # Use given spread resource prefix for round-robin resource-based
        # scheduling.
        nodes = ray.nodes()
        resource_iter = _get_spread_resources_iter(
            nodes, _spread_resource_prefix, ray_remote_args
        )
    else:
        # If no spread resource prefix given, yield an empty dictionary.
        resource_iter = itertools.repeat({})

    calls: List[Callable[[], ObjectRef[MaybeBlockPartition]]] = []
    metadata: List[BlockPartitionMetadata] = []

    for i, task in enumerate(read_tasks):
        # NOTE: the lambda default arguments are deliberate — defaults are
        # evaluated when the lambda is *created*, which (a) binds the current
        # i/task instead of the loop's final values and (b) advances
        # resource_iter exactly once per task, round-robin-ing resources now
        # even though the call itself is deferred.
        calls.append(
            lambda i=i, task=task, resources=next(resource_iter): remote_read.options(
                **ray_remote_args, resources=resources
            ).remote(i, task, stats_actor)
        )
        metadata.append(task.get_metadata())

    block_list = LazyBlockList(calls, metadata)
    # TODO(ekl) consider refactoring LazyBlockList to take read_tasks explicitly.
    # Stash the tasks/args on private attributes for downstream consumers.
    block_list._read_tasks = read_tasks
    block_list._read_remote_args = ray_remote_args

    # Get the schema from the first block synchronously.
    if metadata and metadata[0].schema is None:
        block_list.ensure_schema_for_first_block()

    stats = DatasetStats(
        stages={"read": metadata},
        parent=None,
        stats_actor=stats_actor,
        stats_uuid=stats_uuid,
    )
    return Dataset(
        ExecutionPlan(block_list, stats),
        0,
        False,
    )