def test_add_cell_metrics():
    path = tempfile.mkdtemp()
    store = zarr.DirectoryStore(path)
    root = zarr.group(store, overwrite=True)

    test_data_location = "test/data/"
    test_metrics_file = test_data_location + "merged-cell-metrics.csv.gz"
    test_emptydrops_file = test_data_location + "empty_drops_result.csv"

    metrics_df = pd.read_csv(test_metrics_file, dtype=str)
    metrics_df = metrics_df.rename(columns={"Unnamed: 0": "cell_id"})
    sample_cell_ids = metrics_df['cell_id'][0:10]

    target_test_code.add_cell_metrics(data_group=root,
                                      metrics_file=test_metrics_file,
                                      cell_ids=sample_cell_ids,
                                      emptydrops_file=test_emptydrops_file,
                                      verbose=True)

    # Read the results back
    store_read = zarr.open(zarr.DirectoryStore(path))
    cell_metadata_float = store_read['cell_metadata_float'][:]
    cell_metadata_float_name = store_read['cell_metadata_float_name'][:]
    cell_metadata_bool = store_read['cell_metadata_bool'][:]
    cell_metadata_bool_name = store_read['cell_metadata_bool_name'][:]

    assert cell_metadata_float.shape == (10, 13)
    assert cell_metadata_float_name.shape == (13, )
    assert cell_metadata_bool.shape == (10, 2)
    assert cell_metadata_bool_name.shape == (2, )
def random_v_random_capture(platform, key_len=16, N=10000):  # may not be working
    scope, target = setup_device(platform)
    ktp = FixedVRandomText(key_len)

    store = zarr.DirectoryStore('data/{}-{}-{}.zarr'.format(platform, N, key_len))
    root = zarr.group(store=store, overwrite=True)
    zwaves = root.zeros('traces/waves',
                        shape=(2 * N, scope.adc.samples),
                        chunks=(2500, None),
                        dtype='float64')
    ztextins = root.zeros('traces/textins',
                          shape=(2 * N, 16),
                          chunks=(2500, None),
                          dtype='uint8')
    waves = zwaves[:, :]
    textins = ztextins[:, :]

    for i in trange(2 * N):
        key, text = ktp.next_group_B()
        trace = cw.capture_trace(scope, target, text, key)
        while trace is None:
            trace = cw.capture_trace(scope, target, text, key)
        if not verify_AES(text, key, trace.textout):
            raise ValueError("Encryption failed")
        # project.traces.append(trace)
        waves[i, :] = trace.wave
        textins[i, :] = np.array(text)

    zwaves[:, :] = waves[:, :]
    ztextins[:, :] = textins[:, :]
    print("Last encryption took {} samples".format(scope.adc.trig_count))
def netcdf_to_zarr(src, dst, axis=None, mode='serial', nested=False):
    """Convert one or more NetCDF sources into a single zarr store.

    Args:
        src (iterable): NetCDF source datasets to convert.
        dst (str or zarr store): Destination zarr store, or a path at which to create one.
        axis (None, optional): Axis along which variables from sources after the
            first are appended.
        mode (str, optional): Processing mode (default 'serial').
        nested (bool, optional): If dst is a path, use a NestedDirectoryStore
            instead of a DirectoryStore.
    """
    if isinstance(dst, str):
        if nested:
            local_store = zarr.NestedDirectoryStore(dst)
        else:
            local_store = zarr.DirectoryStore(dst)
    else:
        local_store = dst

    root = zarr.group(store=local_store, overwrite=True)

    for i, dname in enumerate(src):
        # cycling over groups, the first one is the root.
        for j, gname in enumerate(__get_groups(dname)):
            if j == 0:
                group = root
                ds = ''
            else:
                group = __set_group(gname, root)
                ds = dname
            if i == 0:
                __set_meta(ds + gname, group)
                __set_vars(ds + gname, group, mode)
            else:
                __append_vars(gname, group, axis, mode)
def load(cls, path: PathType):
    """Load existing DirectoryStore."""
    zarr_store = zarr.DirectoryStore(path)
    group = zarr.group(store=zarr_store)
    xshape = cls._extract_xshape_from_zarr_group(group)
    zdim = cls._extract_zdim_from_zarr_group(group)
    return DirectoryStore(zdim=zdim, xshape=xshape, path=path)
def init_zarr(sample_id, path, file_format, schema_version):
    """Initializes the zarr output.

    Args:
        sample_id (str): sample or cell id
        path (str): path to the zarr output
        file_format (str): zarr file format [DirectoryStore, ZipStore]
        schema_version (str): version string of this output to allow for parsing of future changes

    Returns:
        root (zarr.hierarchy.Group): initialized zarr group
    """
    store = None
    if file_format == "DirectoryStore":
        store = zarr.DirectoryStore(path)
    if file_format == "ZipStore":
        store = zarr.ZipStore(path, mode='w')

    # create the root group
    root = zarr.group(store, overwrite=True)

    # root.attrs['README'] = "The schema adopted in this zarr store may undergo changes in the future"
    root.attrs['sample_id'] = sample_id
    root.attrs['optimus_output_schema_version'] = schema_version

    # Create the expression_matrix group
    # root.create_group("expression_matrix", overwrite=True);

    return root
def open(self, mode: str = "r", cached: bool = True, cache_size_bytes: int = int(1e9)) -> "ChunkedDataset":
    """Opens a zarr dataset from disk from the path supplied in the constructor.

    Keyword Arguments:
        mode (str): Mode to open dataset in, default to read-only (default: {"r"})
        cached (bool): Whether to cache files read from disk using a LRU cache. (default: {True})
        cache_size_bytes (int): Size of cache in bytes (default: {1e9} (1GB))

    Raises:
        Exception: When any of the expected arrays (frames, agents, scenes) is missing
            or the store couldn't be opened.
    """
    if cached:
        self.root = zarr.open_group(
            store=zarr.LRUStoreCache(zarr.DirectoryStore(self.path), max_size=cache_size_bytes), mode=mode
        )
    else:
        self.root = zarr.open_group(self.path, mode=mode)
    self.frames = self.root[FRAME_ARRAY_KEY]
    self.agents = self.root[AGENT_ARRAY_KEY]
    self.scenes = self.root[SCENE_ARRAY_KEY]
    try:
        self.tl_faces = self.root[TL_FACE_ARRAY_KEY]
    except KeyError:
        warnings.warn(
            f"{TL_FACE_ARRAY_KEY} not found in {self.path}! Traffic lights will be disabled",
            RuntimeWarning,
            stacklevel=2,
        )
        self.tl_faces = np.empty((0,), dtype=TL_FACE_DTYPE)
    return self
def main(cfg):
    dat = HubmapDataset(cfg["data_dir"], cfg["out_dir"])

    store = zarr.DirectoryStore(dat.path.out / "zarr" / cfg["version"] / "db.zarr")
    database = zarr.group(store=store, overwrite=False)

    for id_, _ in tqdm(list(dat.get_inf("train").iterrows())):
        database.create_group(id_)

        img = dat.get_img(id_)
        msk = dat.get_msk(id_, "target")
        shape = dat.get_shape(id_)

        rescale = A.Resize(height=int(shape[0] / cfg["scale"]),
                           width=int(shape[1] / cfg["scale"]),
                           p=1.0)
        transformed = rescale(image=img, mask=msk)
        img, msk = transformed["image"], transformed["mask"]
        del transformed
        gc.collect()

        database[id_]["img"] = zarr.array(img, chunks=(cfg["chunk_size"], cfg["chunk_size"], 3))
        database[id_]["target"] = zarr.array(msk, chunks=(cfg["chunk_size"], cfg["chunk_size"]))
        del img, msk
        gc.collect()
def open(
        self, mode: str = "r", cached: bool = True, cache_size_bytes: int = int(1e9)) -> "ChunkedDataset":
    """Opens a zarr dataset from disk from the path supplied in the constructor.

    :param mode: Mode to open dataset in, default to read-only (default: {"r"})
    :param cached: Whether to cache files read from disk using a LRU cache. (default: {True})
    :param cache_size_bytes: Size of cache in bytes (default: {1e9} (1GB))
    """
    if cached:
        self.root = zarr.open_group(store=zarr.LRUStoreCache(
            zarr.DirectoryStore(self.path), max_size=cache_size_bytes), mode=mode)
    else:
        self.root = zarr.open_group(self.path, mode=mode)
    self.frames = self.root[FRAME_ARRAY_KEY]
    self.agents = self.root[AGENT_ARRAY_KEY]
    self.scenes = self.root[SCENE_ARRAY_KEY]
    try:
        self.tl_faces = self.root[TL_FACE_ARRAY_KEY]
    except KeyError:
        # the real issue here is that frame doesn't have traffic_light_faces_index_interval
        warnings.warn(
            f"{TL_FACE_ARRAY_KEY} not found in {self.path}! "
            f"You won't be able to use this zarr with an Ego/AgentDataset",
            RuntimeWarning,
            stacklevel=2,
        )
        self.tl_faces = np.empty((0, ), dtype=TL_FACE_DTYPE)
    return self
def prepare_zarr_group(dataset_id, dataset, store, table="MAIN"):
    dir_store = zarr.DirectoryStore(store)

    try:
        # Open in read/write, must exist
        group = zarr.open_group(store=dir_store, mode="r+")
    except zarr.errors.GroupNotFoundError:
        # Create, must not exist
        group = zarr.open_group(store=dir_store, mode="w-")

    group_name = f"{table}_{dataset_id}"
    ds_group = group.require_group(table).require_group(group_name)

    schema = DatasetSchema.from_dataset(dataset)

    for column, column_schema in schema.data_vars.items():
        create_array(ds_group, column, column_schema, False)

    for column, column_schema in schema.coords.items():
        create_array(ds_group, column, column_schema, True)

    ds_group.attrs.update({
        **schema.attrs,
        DASKMS_ATTR_KEY: {
            "chunks": dict(dataset.chunks)
        }
    })

    return ds_group
def init_zarr(sample_id, path, file_format):
    """Initializes the zarr output.

    Args:
        sample_id (str): sample or cell id
        path (str): path to the zarr output
        file_format (str): zarr file format [DirectoryStore, ZipStore]

    Returns:
        root (zarr.hierarchy.Group): initialized zarr group
    """
    store = None
    if file_format == "DirectoryStore":
        store = zarr.DirectoryStore(path)
    if file_format == "ZipStore":
        store = zarr.ZipStore(path, mode='w')

    # create the root group
    root = zarr.group(store, overwrite=True)

    # add some readme for the user
    root.attrs['README'] = "The schema adopted in this zarr store may undergo changes in the future"
    root.attrs['sample_id'] = sample_id

    # now iterate through list of expected groups and create them
    for dataset in ZARR_GROUP:
        root.create_group(dataset, overwrite=True)

    return root
def create_zarr(self):
    """
    Create a zarr for the contents of `self.data_model`. The grouped
    structure of this zarr is:

        root (filename)
          | - phenom_0
          | - phenom_1
          | - ...
          | - dimension_0
          | - dimension_1
          | - ...
          | - phenom_n
          | - ...

    TODO: add global NetCDF attributes to outermost zarr structure?
    """
    store = zarr.DirectoryStore(self.array_filename)
    self.group = zarr.group(store=store)

    # Write zarr datasets for data variables.
    for domain in self.data_model.domains:
        domain_vars = self.data_model.domain_varname_mapping[domain]
        self.create_variable_datasets(domain_vars)

    # Write zarr datasets for dimension variables.
    keys = list(self.data_model.domain_varname_mapping.keys())
    unique_flat_keys = {k for domain in keys for k in domain}
    self.create_variable_datasets(unique_flat_keys)
def hdf2zarr(source, target):
    opened = False
    if isinstance(source, (bytes, str)):
        hdf5_filename = source
        hdf5_file = h5py.File(source, "r")
        opened = True
    else:
        hdf5_file = source
        hdf5_filename = hdf5_file.filename

    store = zarr.DirectoryStore(target)
    root = zarr.group(store=store, overwrite=True)

    def copy(name, obj):
        if isinstance(obj, h5py.Group):
            zarr_obj = root.create_group(name)
        elif isinstance(obj, h5py.Dataset):
            zarr_obj = root.create_dataset(name, data=obj, chunks=obj.chunks)
        else:
            assert False, "Unsupported HDF5 type."
        zarr_obj.attrs.update(obj.attrs)

    hdf5_file.visititems(copy)

    if opened:
        hdf5_file.close()

    return root
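# A minimal usage sketch for hdf2zarr above; the file names are placeholders
# (assumptions), not taken from the original source.
if __name__ == "__main__":
    mirrored_root = hdf2zarr("example_input.h5", "example_output.zarr")
    print(mirrored_root.tree())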
def main(cfg):
    # data_preprocessing
    data_dir = cfg["data_dir"]
    out_dir = cfg["out_dir"]
    dat = HubmapDataset(data_dir, out_dir)

    # db_zarr
    db = zarr.open_group(store=zarr.DirectoryStore(cfg["zarr_db_dir"]), mode="r")

    # results dir
    results_dir = dat.path.out / "normalization" / cfg["version"]
    if cfg["version"] == "debug":
        os.makedirs(results_dir, exist_ok=True)
    else:
        try:
            os.makedirs(results_dir, exist_ok=False)
        except FileExistsError:
            raise Exception(f"Version {cfg['version']} exists!")

    # save config
    dat.jsn_dump(cfg, results_dir / "config.json")

    # tiles
    tile_dct = dat.pkl_load(dat.path.out / "tiles" / cfg['tiles_version'] / "tile_dct.pkl")
    tiles = tile_dct["train_df"]

    # mean
    channel_sum = np.zeros(3)
    N = (tiles["tile"].apply(lambda x: x[1] - x[0]) * tiles["tile"].apply(lambda x: x[3] - x[2])).sum()
    for _, row in tqdm(tiles.iterrows(), total=len(tiles)):
        id_ = row["id"]
        c = row["tile"]
        slc = np.s_[c[0]:c[1], c[2]:c[3]]
        img = db[id_]["img"][slc] / 255
        channel_sum += img.sum(axis=(0, 1))
    mean = channel_sum / N
    dat.pkl_dump(mean, results_dir / "mean.pkl")

    # std
    diff_squared = np.zeros(3)
    for _, row in tqdm(tiles.iterrows(), total=len(tiles)):
        id_ = row["id"]
        c = row["tile"]
        slc = np.s_[c[0]:c[1], c[2]:c[3]]
        img = db[id_]["img"][slc] / 255
        diff_squared += ((img - mean) ** 2).sum(axis=(0, 1))
    std = np.sqrt(diff_squared / N)
    dat.pkl_dump(std, results_dir / "std.pkl")

    print(f"MEAN: {mean}")
    print(f"STD: {std}")
def init_zarr(sample_id, path, file_format):
    """Initializes the zarr output

    Args:
        sample_id (str): sample or cell id
        path (str): path to the zarr output
        file_format (str): zarr file format [DirectoryStore, ZipStore]

    Returns:
        zarr.hierarchy.Group
    """
    store = None
    if file_format == "DirectoryStore":
        store = zarr.DirectoryStore(path)
    if file_format == "ZipStore":
        store = zarr.ZipStore(path, mode='w')

    # create the root group
    root = zarr.group(store, overwrite=True)

    # add some readme for the user
    root.attrs['README'] = (
        "The schema adopted in this zarr store may undergo "
        "changes in the future")
    root.attrs['sample_id'] = sample_id

    return root
def create_zarr(path: str, image: np.ndarray, chunk_size: int = 512) -> None:
    print("Computing pyramid...")
    pyramid = pyramid_gaussian(image, downscale=2, max_layer=4, multichannel=True)
    # Printing here seems to break writing the zarr file!?
    # _dump_pyramid(pyramid)
    store = zarr.DirectoryStore(path)
    with zarr.group(store, overwrite=True) as group:
        series = []
        for i, layer in enumerate(pyramid):
            if UINT8:
                layer = (255 * layer).astype(np.uint8)
            max_val = np.amax(layer)
            print(f"Layer {i} -> {layer.shape} -> {layer.dtype} -> max {max_val}")
            layer_path = "base" if i == 0 else f"L{i}"
            group.create_dataset(layer_path, data=layer, chunks=(chunk_size, chunk_size, 3))
            series.append({"path": layer_path})

        multiscales = [{
            "name": "pyramid",
            "datasets": series,
            "type": "pyramid"
        }]
        group.attrs["multiscales"] = multiscales
def _get_zarr_group(store):
    if store is None:
        # memory store
        return None, zarr.group()
    elif isinstance(store, str):
        store = zarr.DirectoryStore(store)
    return store, zarr.open_group(store=store, mode="a")
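# A minimal sketch of the three dispatch paths in _get_zarr_group above; the
# path string is a placeholder (assumption), not taken from the original source.
if __name__ == "__main__":
    _, in_memory_group = _get_zarr_group(None)                 # fresh in-memory group
    disk_store, disk_group = _get_zarr_group("example.zarr")   # path becomes a DirectoryStore
    same_store, reopened_group = _get_zarr_group(disk_store)   # an existing store is passed through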
def __init__(
    self,
    params,
    path: PathType,
    sync_path: Optional[PathType] = None,
    simulator=None,
):
    """Instantiate an iP3 store stored in a directory.

    Args:
        params (list of strings or int): List of parameter names. If int use ['z0', 'z1', ...].
        path: path to storage directory
        sync_path: path for synchronization via file locks (files will be stored in the given path).
            It must differ from path, it must be accessible to all processes working on the store,
            and the underlying filesystem must support file locking.
        simulator: simulator object.
    """
    zarr_store = zarr.DirectoryStore(path)
    sync_path = sync_path or os.path.splitext(path)[0] + ".sync"
    super().__init__(
        params=params,
        zarr_store=zarr_store,
        sync_path=sync_path,
        simulator=simulator,
    )
def enumerate(self, fasta_file_list):
    """Create the kmer index and one-hot zarr array and fill them with kmer data.

    Overwrites any existing files.

    Args:
        fasta_file_list (str): Path to a file listing fasta file paths, one per line.

    Returns:
        None
    """
    with open(fasta_file_list, 'r') as infh:
        fasta_files = infh.read().splitlines()

    self.n = len(fasta_files)

    # Kmer index lookup
    if os.path.exists(self.leveldb_filepath):
        shutil.rmtree(self.leveldb_filepath)
    db = plyvel.DB(self.leveldb_filepath, create_if_missing=True)

    # Zarr chunked dataframe
    store = zarr.DirectoryStore(self.zarr_filepath)
    group = zarr.hierarchy.group(store=store, overwrite=True)
    za = group.zeros('ohe',
                     shape=(self.n, self.buffer),
                     dtype='u1',
                     chunks=(self.n, self.chunk))

    for f in fasta_files:
        self.count(f, db, za)

    db.close()
def create_ome_zarr(zarr_directory, dtype="f4"):
    base = np.tile(data.astronaut(), (4, 4, 1))
    gaussian = list(pyramid_gaussian(base, downscale=2, max_layer=3, multichannel=True))

    pyramid = []
    # convert each level of pyramid into 5D image (t, c, z, y, x)
    for pixels in gaussian:
        red = pixels[:, :, 0]
        green = pixels[:, :, 1]
        blue = pixels[:, :, 2]
        # wrap to make 5D: (t, c, z, y, x)
        pixels = np.array([np.array([red]), np.array([green]), np.array([blue])])
        pixels = np.array([pixels]).astype(dtype)
        pyramid.append(pixels)

    store = zarr.DirectoryStore(zarr_directory)
    grp = zarr.group(store, overwrite=True)
    paths = []
    for path, dataset in enumerate(pyramid):
        grp.create_dataset(str(path), data=dataset)
        paths.append({"path": str(path)})

    image_data = {
        "id": 1,
        "channels": [
            {
                "color": "FF0000",
                "window": {"start": 0, "end": 1},
                "label": "Red",
                "active": True,
            },
            {
                "color": "00FF00",
                "window": {"start": 0, "end": 1},
                "label": "Green",
                "active": True,
            },
            {
                "color": "0000FF",
                "window": {"start": 0, "end": 1},
                "label": "Blue",
                "active": True,
            },
        ],
        "rdefs": {
            "model": "color",
        },
    }

    multiscales = [
        {
            "version": "0.1",
            "datasets": paths,
        }
    ]
    grp.attrs["multiscales"] = multiscales
    grp.attrs["omero"] = image_data
def create_zarr_store(ds, rootdir, ignore_vars=[], storetype='directory', consolidated=True):
    """
    OBSOLETE

    Write each variable from a xarray Dataset ds into a new zarr store under the
    root directory rootdir, excluding optional variables from ignore_vars.

    PARAMETERS:
    ===========

    ds: xarray.Dataset
        input dataset
    rootdir: str
        root path to the zarr stores
    ignore_vars: list
        variables to ignore
    storetype: str
        zarr store type (directory, zip)
    consolidated: bool
        whether to write consolidated zarr metadata (default = True)

    RETURNS:
    ========

    None
    """
    for variable in ds.variables:
        if variable not in ignore_vars:
            print(f'writing {variable}')
            # create a bogus dataset to copy a single variable
            tmp = _xr.Dataset()
            tmp[variable] = ds[variable]
            # update output directory with variable name
            outputdir = rootdir.replace('<VARNAME>', variable)
            # create the output directory
            check = subprocess.check_call(f'mkdir -p {outputdir}', shell=True)
            exit_code(check)
            # create a zarr store in write mode
            store_exists = os.path.exists(f'{outputdir}/{variable}')
            if storetype == 'directory' and not store_exists:
                store = _zarr.DirectoryStore(f'{outputdir}/{variable}')
                # then copy to zarr
                tmp.to_zarr(store, consolidated=consolidated)
            elif storetype == 'zip':
                store = _zarr.ZipStore(f'{outputdir}/{variable}.zip', mode='w')
                # then copy to zarr
                tmp.to_zarr(store, mode='w', consolidated=consolidated)
            # and close store
            if storetype == 'zip':
                store.close()
            tmp.close()
    return None
def convert(input, output, chunk_size=16 * 1024 * 1024, genome=None, overwrite=False):
    input_path, input_ext = splitext(input)
    output_path, output_ext = splitext(output)
    print('converting: %s to %s' % (input, output))
    if input_ext == '.h5' or input_ext == '.loom':
        if output_ext == '.zarr':
            # Convert 10x (HDF5) to Zarr
            source = h5py.File(input, 'r')
            zarr.tree(source)
            store = zarr.DirectoryStore(output)
            dest = zarr.group(store=store, overwrite=overwrite)
            # following fails if without_attrs=False (the default), possibly related to
            # https://github.com/h5py/h5py/issues/973
            zarr.copy_all(source, dest, log=sys.stdout, without_attrs=True)
            zarr.tree(dest)
        elif output_ext == '.h5ad':
            if not genome:
                keys = list(h5py.File(input, 'r').keys())
                if len(keys) == 1:
                    genome = keys[0]
                else:
                    raise Exception(
                        'Set --genome flag when converting from 10x HDF5 (.h5) to Anndata HDF5 (.h5ad); '
                        'top-level groups in file %s: %s'
                        % (input, ','.join(keys))
                    )
            adata = read_10x_h5(input, genome=genome)
            # TODO: respect overwrite flag
            adata.write(output)
    elif input_ext == '.h5ad':
        adata = read_h5ad(input, backed='r')
        (r, c) = adata.shape
        chunks = (getsize(input) - 1) // chunk_size + 1
        chunk_size = (r - 1) // chunks + 1
        if output_ext == '.zarr':
            print('converting %s (%dx%d) to %s in %d chunks (%d rows each)'
                  % (input, r, c, output, chunks, chunk_size))
            # TODO: respect overwrite flag
            adata.write_zarr(make_store(output), chunks=(chunk_size, c))
        else:
            raise Exception('Unrecognized output extension: %s' % output_ext)
    else:
        raise Exception('Unrecognized input extension: %s' % input_ext)
def load(cls, path: PathType):
    """Load existing DirectoryStore state into a MemoryStore object."""
    memory_store = zarr.MemoryStore()
    directory_store = zarr.DirectoryStore(path)
    zarr.convenience.copy_store(source=directory_store, dest=memory_store)
    group = zarr.group(store=memory_store)
    xshape = cls._extract_xshape_from_zarr_group(group)
    zdim = cls._extract_zdim_from_zarr_group(group)
    return MemoryStore(zdim=zdim, xshape=xshape, store=memory_store)
def make_store(path):
    m = re.match("^gc?s://", path)
    if m:
        return GCSMap(path[len(m.group(0)):], gcs=gcsFileSystem())
    if path.startswith("s3://"):
        s3 = S3FileSystem()
        return S3Map(path[len("s3://"):], s3=s3)
    return zarr.DirectoryStore(path)
def _benchmark_load_zarr_datasets(self, zarr_paths):
    callsets = []
    self.benchmark_profiler.start_benchmark(operation_name="Load Zarr Dataset")
    for zarr_path in zarr_paths:
        store = zarr.DirectoryStore(zarr_path)
        callset = zarr.Group(store=store, read_only=True)
        callsets.append(callset)
    self.benchmark_profiler.end_benchmark()
    return callsets
def make_store(path):
    m = re.match('^gc?s://', path)
    if m:
        gcs = gcsfs.GCSFileSystem()
        return gcsfs.mapping.GCSMap(path[len(m.group(0)):], gcs=gcs)
    if path.startswith('s3://'):
        s3 = s3fs.S3FileSystem()
        return s3fs.mapping.S3Map(path[len('s3://'):], s3=s3)
    return zarr.DirectoryStore(path)
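# Hypothetical calls illustrating make_store's dispatch above; the bucket and
# path names are placeholders, and the cloud branches need valid credentials.
if __name__ == "__main__":
    local_store = make_store("data/example.zarr")          # zarr.DirectoryStore on disk
    # gcs_store = make_store("gs://example-bucket/data")   # gcsfs mapping
    # s3_store = make_store("s3://example-bucket/data")    # s3fs mapping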
def save(self, path: PathType) -> None:
    """Save the current state of the MemoryStore to a directory."""
    path = Path(path)
    if path.exists() and not path.is_dir():
        raise NotADirectoryError(f"{path} should be a directory")
    elif path.exists() and not is_empty(path):
        raise FileExistsError(f"{path} is not empty")
    else:
        path.mkdir(parents=True, exist_ok=True)
        zarr_store = zarr.DirectoryStore(path)
        zarr.convenience.copy_store(source=self.store, dest=zarr_store)
        return None
def __init__(self, filename, cache_size=512 * (1024 ** 2)):
    """
    An object for accessing rss data from a zarr DirectoryStore on disk.

    Parameters
    ----------
    filename : str
        Path to the rss data object on disk.
    """
    store = zarr.DirectoryStore(f"{filename}")
    root = zarr.open(store, mode="r")
    super().__init__(store, cache_size=cache_size)
def save_multiple_images(array: da.Array, output_file: Path, write_mode: str = "x") -> None:
    """
    Calculate and store a Dask array in an HDF5 file without exceeding available memory.

    Use the Dask distributed scheduler to compute a Dask array and store the resulting
    values to a data set 'data' in the root group of an HDF5 file. The distributed
    scheduler is capable of managing worker memory better than the default scheduler.
    In the latter case, the workers can sometimes demand more than the available amount
    of memory. Using the distributed scheduler avoids this problem.

    The distributed scheduler cannot write directly to HDF5 files because h5py.File
    objects are not serialisable. To work around this issue, the data are first stored
    to a Zarr DirectoryStore, then copied to the final HDF5 file and the Zarr store
    deleted. Multithreading is used, as the calculation is assumed to be I/O bound.

    Args:
        array: A Dask array to be calculated and stored.
        output_file: Path to the output HDF5 file.
        write_mode: HDF5 file opening mode. See :class:`h5py.File`.
    """
    # Set a more generous connection timeout than the default 30s.
    with dask.config.set(
        {
            "distributed.comm.timeouts.connect": "60s",
            "distributed.comm.timeouts.tcp": "60s",
            "distributed.deploy.lost-worker-timeout": "60s",
            "distributed.scheduler.idle-timeout": "600s",
            "distributed.scheduler.locks.lease-timeout": "60s",
        }
    ):
        intermediate = str(output_file.with_suffix(".zarr"))

        # Overwrite any pre-existing Zarr storage.  Don't compute immediately but
        # return the Array object so we can compute it with a progress bar.
        method = {"overwrite": True, "compute": False, "return_stored": True}
        # Prepare to save the calculated images to the intermediate Zarr store.
        array = array.to_zarr(intermediate, component="data", **method)
        # Compute the Array and store the values, using a progress bar.
        progress(array.persist())

        print("\nTransferring the images to the output file.")
        store = zarr.DirectoryStore(intermediate)
        with h5py.File(output_file, write_mode) as f:
            zarr.copy_all(zarr.open(store), f, **Bitshuffle())

        # Delete the Zarr store.
        store.clear()
def write_dataset_zarr(dataset, path, key='images'):
    """
    Given a PyTorch Dataset or array_like, write a Zarr dataset.

    We assume that the dataset returns either a single image, or a tuple whose
    first entry is an image. For example, in order to return both an image and
    a set of labels, the dataset can return those as a pair of torch Tensors.
    Note that the names of the members of the tuple can be overridden by
    passing a tuple of names as the `key` argument.
    """
    try:
        import zarr, lmdb
    except ImportError:
        print('Please install the zarr and lmdb libraries to use write_dataset_zarr.')
        raise
    from .utils import tqdm

    if not isinstance(key, tuple):
        # make key a tuple if it's not already
        key = (key, )

    store = zarr.DirectoryStore(path)
    root = zarr.group(store=store, overwrite=True)

    # determine size needed for the zarr datasets
    ds0 = dataset[0]
    if not isinstance(ds0, tuple):
        ds0 = (ds0, )
    # check that the length of the tuple matches args
    if len(ds0) != len(key):
        raise Exception(f"Dataset returns tuple with {len(ds0)} entries, "
                        f"but only {len(key)} keys given")
    ds = []
    for d, k in zip(ds0, key):
        dtype = d.dtype
        if isinstance(d, torch.Tensor):
            # need a numpy dtype for zarr
            dtype = d.view(-1)[0].cpu().numpy().dtype
        sh = d.shape
        ds.append(
            root.zeros('/' + k,
                       shape=(len(dataset), *sh),
                       chunks=(1, *sh),
                       dtype=dtype))

    for i, di in enumerate(tqdm(dataset)):
        if not isinstance(di, (tuple, list)):
            di = [di]
        for I, dsi in zip(di, ds):
            if isinstance(I, torch.Tensor):
                I = I.cpu().numpy()
            dsi[i, ...] = I
def test_local(self):
    cube = new_cube(time_periods=10, time_start='2019-01-01',
                    variables=dict(precipitation=0.1,
                                   temperature=270.5,
                                   soil_moisture=0.2))
    cube = chunk_dataset(cube, dict(time=1, lat=90, lon=90), format_name='zarr')
    cube.to_zarr(self.CUBE_PATH)
    cube.close()
    diagnostic_store = DiagnosticStore(zarr.DirectoryStore(self.CUBE_PATH),
                                       logging_observer(log_path='local-cube.log'))
    xr.open_zarr(diagnostic_store)