import random
from itertools import repeat

import dask
import dask.bag as db


def test_map_partitions_args_kwargs():
    x = [random.randint(-100, 100) for i in range(100)]
    y = [random.randint(-100, 100) for i in range(100)]

    dx = db.from_sequence(x, npartitions=10)
    dy = db.from_sequence(y, npartitions=10)

    def maximum(x, y=0):
        y = repeat(y) if isinstance(y, int) else y
        return [max(a, b) for (a, b) in zip(x, y)]

    # Scalar arguments are broadcast to every partition,
    # whether passed positionally or by keyword.
    sol = maximum(x, y=10)
    assert db.map_partitions(maximum, dx, y=10).compute() == sol
    assert dx.map_partitions(maximum, y=10).compute() == sol
    assert dx.map_partitions(maximum, 10).compute() == sol

    # Bag arguments with matching npartitions are aligned partition-wise.
    sol = maximum(x, y)
    assert db.map_partitions(maximum, dx, dy).compute() == sol
    assert dx.map_partitions(maximum, y=dy).compute() == sol
    assert dx.map_partitions(maximum, dy).compute() == sol

    # Item and Delayed arguments are computed once and broadcast.
    dy_mean = dy.mean().apply(int)
    sol = maximum(x, int(sum(y) / len(y)))
    assert dx.map_partitions(maximum, y=dy_mean).compute() == sol
    assert dx.map_partitions(maximum, dy_mean).compute() == sol

    dy_mean = dask.delayed(dy_mean)
    assert dx.map_partitions(maximum, y=dy_mean).compute() == sol
    assert dx.map_partitions(maximum, dy_mean).compute() == sol
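# For orientation, a minimal sketch (toy data, not taken from the test above)
# of the two broadcasting rules the test exercises:
import dask.bag as db

b = db.from_sequence(range(4), npartitions=2)             # partitions: [0, 1], [2, 3]
other = db.from_sequence([10, 10, 20, 20], npartitions=2)

# A plain scalar is handed to every partition unchanged...
assert b.map_partitions(lambda part, n: [v + n for v in part], 5).compute() == [5, 6, 7, 8]

# ...while a bag with matching npartitions is consumed partition-by-partition.
assert db.map_partitions(
    lambda p, q: [v + w for v, w in zip(p, q)], b, other
).compute() == [10, 11, 22, 23]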
import dask.bag


def drop_empty_bag_partitions(bag):
    """
    When a bag is created by filtering or grouping from a different bag,
    it retains the original bag's partition count, even if many of the
    partitions become empty.
    Those extra partitions add overhead, so it's nice to discard them.
    This function drops the empty partitions.

    Inspired by: https://stackoverflow.com/questions/47812785/remove-empty-partitions-in-dask
    """
    bag = bag.persist()

    def get_len(partition):
        # If the bag is the result of bag.filter(),
        # then each partition is actually a 'filter' object,
        # which has no __len__.
        # In that case, we must convert it to a list first.
        if hasattr(partition, '__len__'):
            return len(partition)
        return len(list(partition))

    partition_lengths = bag.map_partitions(get_len).compute()

    # Convert bag partitions into a list of 'delayed' objects
    lengths_and_partitions = zip(partition_lengths, bag.to_delayed())

    # Drop the ones with empty partitions
    partitions = (p for l, p in lengths_and_partitions if l > 0)

    # Convert from the delayed objects back into a Bag.
    return dask.bag.from_delayed(partitions)
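# A quick usage sketch for drop_empty_bag_partitions, assuming a filter that
# empties all but the first partition:
import dask.bag as db

b = db.from_sequence(range(100), npartitions=10)
filtered = b.filter(lambda x: x < 5)    # survivors all live in the first partition
trimmed = drop_empty_bag_partitions(filtered)

assert filtered.npartitions == 10
assert trimmed.npartitions == 1
assert trimmed.compute() == [0, 1, 2, 3, 4]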
from collections import defaultdict

from dask.bag import Bag


def persist_and_execute(bag, description=None, logger=None, optimize_graph=True):
    """
    Persist and execute the given dask.Bag.
    The persisted Bag is returned.
    """
    assert isinstance(bag, Bag)
    if logger and description:
        logger.info(f"{description}...")

    # Timer is a project-local context manager (not part of dask)
    # that exposes the elapsed time as a .timedelta attribute.
    with Timer() as timer:
        bag = bag.persist(optimize_graph=optimize_graph)
        count = bag.count().compute()  # force eval
        parts = bag.npartitions

        # Histogram of partition sizes: {partition_length: num_partitions}
        partition_counts = bag.map_partitions(
            lambda part: [sum(1 for _ in part)]).compute()
        histogram = defaultdict(lambda: 0)
        for c in partition_counts:
            histogram[c] += 1
        histogram = dict(histogram)

    if logger and description:
        logger.info(
            f"{description} (N={count}, P={parts}, P_hist={histogram}) took {timer.timedelta}"
        )
    return bag
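# Timer above is a project-local helper, not a dask or standard-library class.
# A minimal stand-in (an assumption, matching only the interface used here:
# a context manager exposing a .timedelta attribute) could look like:
import datetime
import time


class Timer:
    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        # Elapsed wall-clock time for the 'with' block, as a timedelta.
        self.timedelta = datetime.timedelta(seconds=time.perf_counter() - self._start)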
from functools import partial
from typing import Sequence

import dask.bag as bag


def prepare_transfer(self, sources: Sequence[str], overwrite: bool):
    # iterput is a project-local helper (not part of dask); map_partitions
    # calls it once per aligned pair of source/destination partitions.
    source_bag = bag.from_sequence(sources, partition_size=self.partition_size)
    dest_bag = source_bag.map(self.prepare_source)
    result = bag.map_partitions(
        partial(iterput, profile=self.profile, overwrite=overwrite),
        source_bag,
        dest_bag,
    )
    return result
from pathlib import Path
from typing import Dict, Optional

import dask.bag as bag


def s3put(
    dest_root: str,
    source_path: str,
    endswith: Optional[str] = "",
    tags: Optional[Dict] = None,
    partition_size: Optional[int] = None,
):
    # list_files_parallel and iterput are project-local helpers, not part
    # of dask. Note that endswith is accepted but unused in this body.
    sources = tuple(map(str, list_files_parallel(source_path)))
    dests = tuple(
        str(Path(dest_root) / Path(f).relative_to(source_path)) for f in sources
    )
    source_bag = bag.from_sequence(sources, partition_size=partition_size)
    dest_bag = bag.from_sequence(dests, partition_size=partition_size)
    tag_bag = bag.from_sequence((tags,) * len(sources), partition_size=partition_size)
    return bag.map_partitions(iterput, source_bag, dest_bag, tag_bag)
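# The core pattern in the two snippets above is passing several equal-length
# bags to map_partitions, so the function receives one aligned chunk of each.
# A toy sketch with a hypothetical stand-in for the project-local iterput:
import dask.bag as bag


def describe_chunk(sources, dests, tags):
    # Hypothetical stand-in for iterput: receives one partition from each
    # input bag and reports the pairings instead of uploading anything.
    return [f"{s} -> {d} (tags={t})" for s, d, t in zip(sources, dests, tags)]


srcs = bag.from_sequence(["a.txt", "b.txt"], partition_size=1)
dsts = bag.from_sequence(["s3://bucket/a.txt", "s3://bucket/b.txt"], partition_size=1)
tgs = bag.from_sequence([None, None], partition_size=1)

print(bag.map_partitions(describe_chunk, srcs, dsts, tgs).compute())
# ['a.txt -> s3://bucket/a.txt (tags=None)', 'b.txt -> s3://bucket/b.txt (tags=None)']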
def glom(bag):
    # Collapse each partition into a single list element, turning a bag
    # of items into a bag of per-partition lists.
    return bag.map_partitions(lambda i: [list(i)])
import dask.bag


def bag_glom(bag):
    # Like glom, but wrap each partition in its own locally constructed
    # Bag instead of a plain list.
    return bag.map_partitions(lambda i: [dask.bag.from_sequence(i)])
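# A quick check of the difference between the two helpers, assuming two items
# per partition:
import dask.bag as db

b = db.from_sequence(range(6), npartitions=3)

# glom yields one plain list per partition...
assert glom(b).compute() == [[0, 1], [2, 3], [4, 5]]

# ...while bag_glom yields one Bag per partition, each needing its own compute().
inner_bags = bag_glom(b).compute()
assert [ib.compute() for ib in inner_bags] == [[0, 1], [2, 3], [4, 5]]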