def _store_known_sized_ds(self, fs: fsspec.AbstractFileSystem, path: str) -> int:
    client = get_client()
    worker_count = sum(client.ncores().values())
    # min() over a generator handles both the single- and multi-tensor cases
    # (min(*list) raises a TypeError for a single-element list)
    chunksize = min(t.chunksize for t in self._tensors.values())
    cnt = len(self)
    collected = {el: None for el in self._tensors}
    collected_offset = {el: 0 for el in collected}
    step = worker_count * chunksize
    for i in range(0, cnt, step):
        batch_count = min(step, cnt - i)
        lasttime = i + step >= cnt
        # materialize one batch of every tensor on the cluster
        persisted = client.persist(
            [self._tensors[key]._array[i : i + batch_count] for key in collected]
        )
        persisted = {key: persisted[j] for j, key in enumerate(collected)}
        tasks = []
        for el, arr in persisted.items():
            # append the new batch to whatever was left over from earlier batches
            if collected[el] is None:
                collected[el] = arr
            else:
                collected[el] = _dask_concat([collected[el], arr])
            c = collected[el]
            chunksize_ = self._tensors[el].chunksize
            codec = codec_from_name(self._tensors[el].dcompress)
            if len(c) >= chunksize_ or lasttime:
                # flush only whole chunks, except on the last pass where the
                # trailing partial chunk is written out as well
                jcnt = len(c) - len(c) % chunksize_ if not lasttime else len(c)
                for j in range(0, jcnt, chunksize_):
                    tasks.append(
                        dask.delayed(_numpy_saver)(
                            fs,
                            f"{path}/{el}/{collected_offset[el] + j}.npy",
                            c[j : j + chunksize_],
                            codec,
                        )
                    )
                collected_offset[el] += jcnt
                collected[el] = collected[el][jcnt:]
        client.gather(client.compute(tasks))
    count = set(collected_offset.values())
    assert (
        len(count) == 1
    ), "All tensors should be the same size to be stored in the same dataset"
    return next(iter(count))
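# NOTE: `_numpy_saver` and `_dask_concat` are module-level helpers defined
# elsewhere in this file. A minimal sketch of what they are assumed to do,
# assuming chunks are written as compressed .npy blobs through fsspec and
# that `codec` follows the numcodecs `encode()` interface (`io` and
# `numpy as np` imported at module level):
#
#   def _numpy_saver(fs, path, arr, codec):
#       buf = io.BytesIO()
#       np.save(buf, np.asarray(arr), allow_pickle=False)
#       with fs.open(path, "wb") as f:
#           f.write(codec.encode(buf.getvalue()))
#
#   def _dask_concat(arrays):
#       return dask.array.concatenate(arrays, axis=0)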
def _store_unknown_sized_ds(self, fs: fsspec.AbstractFileSystem, path: str) -> int:
    client = get_client()
    worker_count = sum(client.ncores().values())
    chunks = {key: t._delayed_objs for key, t in self._tensors.items()}
    chunk_count = [len(items) for items in chunks.values()]
    assert (
        len(set(chunk_count)) == 1
    ), "Number of chunks in each tensor should be the same to be able to store dataset"
    chunk_count = chunk_count[0]
    collected = {el: None for el in self._tensors}
    collected_offset = {el: 0 for el in collected}
    for i in range(0, chunk_count, worker_count):
        batch_count = min(i + worker_count, chunk_count) - i
        lasttime = i + worker_count >= chunk_count
        tasks = {
            key: delayed_objs[i : i + batch_count]
            for key, delayed_objs in chunks.items()
        }
        # flatten to a tuple so a single persist() call covers all tensors
        tasks, keys = _dict_to_tuple(tasks)
        persisted = client.persist(tasks)
        persisted = _tuple_to_dict(persisted, keys)
        # the delayed objects have unknown length, so compute the length of
        # every chunk up front to be able to wrap them as dask arrays below
        lens = {
            key: [dask.delayed(len)(objs[j]) for j in range(batch_count)]
            for key, objs in persisted.items()
        }
        lens, keys = _dict_to_tuple(lens)
        lens = client.gather(client.compute(lens))
        lens = _tuple_to_dict(lens, keys)
        for key, objs in persisted.items():
            arr = _dask_concat(
                [
                    dask.array.from_delayed(
                        obj,
                        dtype=self._tensors[key].dtype,
                        shape=(lens[key][j],) + tuple(self._tensors[key].shape[1:]),
                    )
                    for j, obj in enumerate(objs)
                ]
            )
            if collected[key] is None:
                collected[key] = arr
            else:
                collected[key] = _dask_concat([collected[key], arr])
        tasks = []
        for key in collected:
            c = collected[key]
            chunksize = self._tensors[key].chunksize
            codec = codec_from_name(self._tensors[key].dcompress)
            # flush only whole chunks, except on the last pass where the
            # trailing partial chunk is written out as well
            cnt = len(c) - len(c) % chunksize if not lasttime else len(c)
            for j in range(0, cnt, chunksize):
                tasks.append(
                    dask.delayed(_numpy_saver)(
                        fs,
                        f"{path}/{key}/{collected_offset[key] + j}.npy",
                        c[j : j + chunksize],
                        codec,
                    )
                )
            collected_offset[key] += cnt
            collected[key] = collected[key][cnt:]
        client.gather(client.compute(tasks))
    count = set(collected_offset.values())
    assert (
        len(count) == 1
    ), "All tensors should be the same size to be stored in the same dataset"
    return next(iter(count))
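# NOTE: `_dict_to_tuple` / `_tuple_to_dict` are also defined elsewhere. They
# are assumed to flatten a dict of equal-length lists into one tuple (so a
# single client.persist()/compute() call can cover every tensor) and to
# restore the dict afterwards. A minimal sketch under that assumption:
#
#   def _dict_to_tuple(d):
#       keys = list(d)
#       return tuple(item for k in keys for item in d[k]), keys
#
#   def _tuple_to_dict(t, keys):
#       n = len(t) // len(keys)
#       return {k: list(t[i * n : (i + 1) * n]) for i, k in enumerate(keys)}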